|
|
@@ -20,43 +20,43 @@ import (
|
|
|
|
|
|
const viewTemplateStr = `{{define "common"}}
|
|
|
// View returns a readonly view of {{.StructName}}.
|
|
|
-func (p *{{.StructName}}) View() {{.ViewName}} {
|
|
|
- return {{.ViewName}}{ж: p}
|
|
|
+func (p *{{.StructName}}{{.TypeParamNames}}) View() {{.ViewName}}{{.TypeParamNames}} {
|
|
|
+ return {{.ViewName}}{{.TypeParamNames}}{ж: p}
|
|
|
}
|
|
|
|
|
|
-// {{.ViewName}} provides a read-only view over {{.StructName}}.
|
|
|
+// {{.ViewName}}{{.TypeParamNames}} provides a read-only view over {{.StructName}}{{.TypeParamNames}}.
|
|
|
//
|
|
|
// Its methods should only be called if ` + "`Valid()`" + ` returns true.
|
|
|
-type {{.ViewName}} struct {
|
|
|
+type {{.ViewName}}{{.TypeParams}} struct {
|
|
|
// ж is the underlying mutable value, named with a hard-to-type
|
|
|
// character that looks pointy like a pointer.
|
|
|
// It is named distinctively to make you think of how dangerous it is to escape
|
|
|
// to callers. You must not let callers be able to mutate it.
|
|
|
- ж *{{.StructName}}
|
|
|
+ ж *{{.StructName}}{{.TypeParamNames}}
|
|
|
}
|
|
|
|
|
|
// Valid reports whether underlying value is non-nil.
|
|
|
-func (v {{.ViewName}}) Valid() bool { return v.ж != nil }
|
|
|
+func (v {{.ViewName}}{{.TypeParamNames}}) Valid() bool { return v.ж != nil }
|
|
|
|
|
|
// AsStruct returns a clone of the underlying value which aliases no memory with
|
|
|
// the original.
|
|
|
-func (v {{.ViewName}}) AsStruct() *{{.StructName}}{
|
|
|
+func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypeParamNames}}{
|
|
|
if v.ж == nil {
|
|
|
return nil
|
|
|
}
|
|
|
return v.ж.Clone()
|
|
|
}
|
|
|
|
|
|
-func (v {{.ViewName}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
|
|
|
+func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
|
|
|
|
|
|
-func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error {
|
|
|
+func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error {
|
|
|
if v.ж != nil {
|
|
|
return errors.New("already initialized")
|
|
|
}
|
|
|
if len(b) == 0 {
|
|
|
return nil
|
|
|
}
|
|
|
- var x {{.StructName}}
|
|
|
+ var x {{.StructName}}{{.TypeParamNames}}
|
|
|
if err := json.Unmarshal(b, &x); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
@@ -65,17 +65,17 @@ func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error {
|
|
|
}
|
|
|
|
|
|
{{end}}
|
|
|
-{{define "valueField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} { return v.ж.{{.FieldName}} }
|
|
|
+{{define "valueField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { return v.ж.{{.FieldName}} }
|
|
|
{{end}}
|
|
|
-{{define "byteSliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.ByteSlice[{{.FieldType}}] { return views.ByteSliceOf(v.ж.{{.FieldName}}) }
|
|
|
+{{define "byteSliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.ByteSlice[{{.FieldType}}] { return views.ByteSliceOf(v.ж.{{.FieldName}}) }
|
|
|
{{end}}
|
|
|
-{{define "sliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.Slice[{{.FieldType}}] { return views.SliceOf(v.ж.{{.FieldName}}) }
|
|
|
+{{define "sliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Slice[{{.FieldType}}] { return views.SliceOf(v.ж.{{.FieldName}}) }
|
|
|
{{end}}
|
|
|
-{{define "viewSliceField"}}func (v {{.ViewName}}) {{.FieldName}}() views.SliceView[{{.FieldType}},{{.FieldViewName}}] { return views.SliceOfViews[{{.FieldType}},{{.FieldViewName}}](v.ж.{{.FieldName}}) }
|
|
|
+{{define "viewSliceField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.SliceView[{{.FieldType}},{{.FieldViewName}}] { return views.SliceOfViews[{{.FieldType}},{{.FieldViewName}}](v.ж.{{.FieldName}}) }
|
|
|
{{end}}
|
|
|
-{{define "viewField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}}View { return v.ж.{{.FieldName}}.View() }
|
|
|
+{{define "viewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return v.ж.{{.FieldName}}.View() }
|
|
|
{{end}}
|
|
|
-{{define "valuePointerField"}}func (v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} {
|
|
|
+{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {
|
|
|
if v.ж.{{.FieldName}} == nil {
|
|
|
return nil
|
|
|
}
|
|
|
@@ -85,21 +85,21 @@ func (v *{{.ViewName}}) UnmarshalJSON(b []byte) error {
|
|
|
|
|
|
{{end}}
|
|
|
{{define "mapField"}}
|
|
|
-func(v {{.ViewName}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.ж.{{.FieldName}})}
|
|
|
+func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.ж.{{.FieldName}})}
|
|
|
{{end}}
|
|
|
{{define "mapFnField"}}
|
|
|
-func(v {{.ViewName}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.ж.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} {
|
|
|
+func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.ж.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} {
|
|
|
return {{.MapFn}}
|
|
|
})}
|
|
|
{{end}}
|
|
|
{{define "mapSliceField"}}
|
|
|
-func(v {{.ViewName}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.ж.{{.FieldName}}) }
|
|
|
+func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.ж.{{.FieldName}}) }
|
|
|
{{end}}
|
|
|
-{{define "unsupportedField"}}func(v {{.ViewName}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")}
|
|
|
+{{define "unsupportedField"}}func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")}
|
|
|
{{end}}
|
|
|
-{{define "stringFunc"}}func(v {{.ViewName}}) String() string { return v.ж.String() }
|
|
|
+{{define "stringFunc"}}func(v {{.ViewName}}{{.TypeParamNames}}) String() string { return v.ж.String() }
|
|
|
{{end}}
|
|
|
-{{define "equalFunc"}}func(v {{.ViewName}}) Equal(v2 {{.ViewName}}) bool { return v.ж.Equal(v2.ж) }
|
|
|
+{{define "equalFunc"}}func(v {{.ViewName}}{{.TypeParamNames}}) Equal(v2 {{.ViewName}}{{.TypeParamNames}}) bool { return v.ж.Equal(v2.ж) }
|
|
|
{{end}}
|
|
|
`
|
|
|
|
|
|
@@ -131,8 +131,11 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
|
|
|
it.Import("errors")
|
|
|
|
|
|
args := struct {
|
|
|
- StructName string
|
|
|
- ViewName string
|
|
|
+ StructName string
|
|
|
+ ViewName string
|
|
|
+ TypeParams string // e.g. [T constraints.Integer]
|
|
|
+ TypeParamNames string // e.g. [T]
|
|
|
+
|
|
|
FieldName string
|
|
|
FieldType string
|
|
|
FieldViewName string
|
|
|
@@ -143,9 +146,12 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
|
|
|
MapFn string
|
|
|
}{
|
|
|
StructName: typ.Obj().Name(),
|
|
|
- ViewName: typ.Obj().Name() + "View",
|
|
|
+ ViewName: typ.Origin().Obj().Name() + "View",
|
|
|
}
|
|
|
|
|
|
+ typeParams := typ.Origin().TypeParams()
|
|
|
+ args.TypeParams, args.TypeParamNames = codegen.FormatTypeParams(typeParams, it)
|
|
|
+
|
|
|
writeTemplate := func(name string) {
|
|
|
if err := viewTemplate.ExecuteTemplate(buf, name, args); err != nil {
|
|
|
log.Fatal(err)
|
|
|
@@ -182,19 +188,35 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
|
|
|
it.Import("tailscale.com/types/views")
|
|
|
shallow, deep, base := requiresCloning(elem)
|
|
|
if deep {
|
|
|
- if _, isPtr := elem.(*types.Pointer); isPtr {
|
|
|
- args.FieldViewName = it.QualifiedName(base) + "View"
|
|
|
- writeTemplate("viewSliceField")
|
|
|
- } else {
|
|
|
- writeTemplate("unsupportedField")
|
|
|
+ switch elem.Underlying().(type) {
|
|
|
+ case *types.Pointer:
|
|
|
+ if _, isIface := base.Underlying().(*types.Interface); !isIface {
|
|
|
+ args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View")
|
|
|
+ writeTemplate("viewSliceField")
|
|
|
+ } else {
|
|
|
+ writeTemplate("unsupportedField")
|
|
|
+ }
|
|
|
+ continue
|
|
|
+ case *types.Interface:
|
|
|
+ if viewType := viewTypeForValueType(elem); viewType != nil {
|
|
|
+ args.FieldViewName = it.QualifiedName(viewType)
|
|
|
+ writeTemplate("viewSliceField")
|
|
|
+ continue
|
|
|
+ }
|
|
|
}
|
|
|
+ writeTemplate("unsupportedField")
|
|
|
continue
|
|
|
} else if shallow {
|
|
|
- if _, isBasic := base.(*types.Basic); isBasic {
|
|
|
+ switch base.Underlying().(type) {
|
|
|
+ case *types.Basic, *types.Interface:
|
|
|
writeTemplate("unsupportedField")
|
|
|
- } else {
|
|
|
- args.FieldViewName = it.QualifiedName(base) + "View"
|
|
|
- writeTemplate("viewSliceField")
|
|
|
+ default:
|
|
|
+ if _, isIface := base.Underlying().(*types.Interface); !isIface {
|
|
|
+ args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View")
|
|
|
+ writeTemplate("viewSliceField")
|
|
|
+ } else {
|
|
|
+ writeTemplate("unsupportedField")
|
|
|
+ }
|
|
|
}
|
|
|
continue
|
|
|
}
|
|
|
@@ -205,6 +227,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
|
|
|
strucT := underlying
|
|
|
args.FieldType = it.QualifiedName(fieldType)
|
|
|
if codegen.ContainsPointers(strucT) {
|
|
|
+ args.FieldViewName = appendNameSuffix(args.FieldType, "View")
|
|
|
writeTemplate("viewField")
|
|
|
continue
|
|
|
}
|
|
|
@@ -229,7 +252,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
|
|
|
args.MapFn = "t.View()"
|
|
|
template = "mapFnField"
|
|
|
args.MapValueType = it.QualifiedName(mElem)
|
|
|
- args.MapValueView = args.MapValueType + "View"
|
|
|
+ args.MapValueView = appendNameSuffix(args.MapValueType, "View")
|
|
|
} else {
|
|
|
template = "mapField"
|
|
|
args.MapValueType = it.QualifiedName(mElem)
|
|
|
@@ -249,15 +272,20 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
|
|
|
case *types.Pointer:
|
|
|
ptr := x
|
|
|
pElem := ptr.Elem()
|
|
|
- switch pElem.(type) {
|
|
|
- case *types.Struct, *types.Named:
|
|
|
- ptrType := it.QualifiedName(ptr)
|
|
|
- viewType := it.QualifiedName(pElem) + "View"
|
|
|
- args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType)
|
|
|
- args.MapValueView = fmt.Sprintf("views.SliceView[%v,%v]", ptrType, viewType)
|
|
|
- args.MapValueType = "[]" + ptrType
|
|
|
- template = "mapFnField"
|
|
|
- default:
|
|
|
+ template = "unsupportedField"
|
|
|
+ if _, isIface := pElem.Underlying().(*types.Interface); !isIface {
|
|
|
+ switch pElem.(type) {
|
|
|
+ case *types.Struct, *types.Named:
|
|
|
+ ptrType := it.QualifiedName(ptr)
|
|
|
+ viewType := appendNameSuffix(it.QualifiedName(pElem), "View")
|
|
|
+ args.MapFn = fmt.Sprintf("views.SliceOfViews[%v,%v](t)", ptrType, viewType)
|
|
|
+ args.MapValueView = fmt.Sprintf("views.SliceView[%v,%v]", ptrType, viewType)
|
|
|
+ args.MapValueType = "[]" + ptrType
|
|
|
+ template = "mapFnField"
|
|
|
+ default:
|
|
|
+ template = "unsupportedField"
|
|
|
+ }
|
|
|
+ } else {
|
|
|
template = "unsupportedField"
|
|
|
}
|
|
|
default:
|
|
|
@@ -266,13 +294,29 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
|
|
|
case *types.Pointer:
|
|
|
ptr := u
|
|
|
pElem := ptr.Elem()
|
|
|
- switch pElem.(type) {
|
|
|
- case *types.Struct, *types.Named:
|
|
|
- args.MapValueType = it.QualifiedName(ptr)
|
|
|
- args.MapValueView = it.QualifiedName(pElem) + "View"
|
|
|
+ if _, isIface := pElem.Underlying().(*types.Interface); !isIface {
|
|
|
+ switch pElem.(type) {
|
|
|
+ case *types.Struct, *types.Named:
|
|
|
+ args.MapValueType = it.QualifiedName(ptr)
|
|
|
+ args.MapValueView = appendNameSuffix(it.QualifiedName(pElem), "View")
|
|
|
+ args.MapFn = "t.View()"
|
|
|
+ template = "mapFnField"
|
|
|
+ default:
|
|
|
+ template = "unsupportedField"
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ template = "unsupportedField"
|
|
|
+ }
|
|
|
+ case *types.Interface, *types.TypeParam:
|
|
|
+ if viewType := viewTypeForValueType(u); viewType != nil {
|
|
|
+ args.MapValueType = it.QualifiedName(u)
|
|
|
+ args.MapValueView = it.QualifiedName(viewType)
|
|
|
args.MapFn = "t.View()"
|
|
|
template = "mapFnField"
|
|
|
- default:
|
|
|
+ } else if !codegen.ContainsPointers(u) {
|
|
|
+ args.MapValueType = it.QualifiedName(mElem)
|
|
|
+ template = "mapField"
|
|
|
+ } else {
|
|
|
template = "unsupportedField"
|
|
|
}
|
|
|
default:
|
|
|
@@ -283,14 +327,28 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
|
|
|
case *types.Pointer:
|
|
|
ptr := underlying
|
|
|
_, deep, base := requiresCloning(ptr)
|
|
|
+
|
|
|
if deep {
|
|
|
- args.FieldType = it.QualifiedName(base)
|
|
|
- writeTemplate("viewField")
|
|
|
+ if _, isIface := base.Underlying().(*types.Interface); !isIface {
|
|
|
+ args.FieldType = it.QualifiedName(base)
|
|
|
+ args.FieldViewName = appendNameSuffix(args.FieldType, "View")
|
|
|
+ writeTemplate("viewField")
|
|
|
+ } else {
|
|
|
+ writeTemplate("unsupportedField")
|
|
|
+ }
|
|
|
} else {
|
|
|
args.FieldType = it.QualifiedName(ptr)
|
|
|
writeTemplate("valuePointerField")
|
|
|
}
|
|
|
continue
|
|
|
+ case *types.Interface:
|
|
|
+ // If fieldType is an interface with a "View() {ViewType}" method, it can be used to clone the field.
|
|
|
+ // This includes scenarios where fieldType is a constrained type parameter.
|
|
|
+ if viewType := viewTypeForValueType(underlying); viewType != nil {
|
|
|
+ args.FieldViewName = it.QualifiedName(viewType)
|
|
|
+ writeTemplate("viewField")
|
|
|
+ continue
|
|
|
+ }
|
|
|
}
|
|
|
writeTemplate("unsupportedField")
|
|
|
}
|
|
|
@@ -318,7 +376,27 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
|
|
|
}
|
|
|
}
|
|
|
fmt.Fprintf(buf, "\n")
|
|
|
- buf.Write(codegen.AssertStructUnchanged(t, args.StructName, "View", it))
|
|
|
+ buf.Write(codegen.AssertStructUnchanged(t, args.StructName, typeParams, "View", it))
|
|
|
+}
|
|
|
+
|
|
|
+func appendNameSuffix(name, suffix string) string {
|
|
|
+ if idx := strings.IndexRune(name, '['); idx != -1 {
|
|
|
+ // Insert suffix after the type name, but before type parameters.
|
|
|
+ return name[:idx] + suffix + name[idx:]
|
|
|
+ }
|
|
|
+ return name + suffix
|
|
|
+}
|
|
|
+
|
|
|
+func viewTypeForValueType(typ types.Type) types.Type {
|
|
|
+ viewMethod := codegen.LookupMethod(typ, "View")
|
|
|
+ if viewMethod == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ sig, ok := viewMethod.Type().(*types.Signature)
|
|
|
+ if !ok || sig.Results().Len() != 1 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return sig.Results().At(0).Type()
|
|
|
}
|
|
|
|
|
|
var (
|