|
|
@@ -13,9 +13,11 @@ import (
|
|
|
"html/template"
|
|
|
"log"
|
|
|
"os"
|
|
|
+ "slices"
|
|
|
"strings"
|
|
|
|
|
|
"tailscale.com/util/codegen"
|
|
|
+ "tailscale.com/util/must"
|
|
|
)
|
|
|
|
|
|
const viewTemplateStr = `{{define "common"}}
|
|
|
@@ -75,6 +77,8 @@ func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error {
|
|
|
{{end}}
|
|
|
{{define "viewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return v.ж.{{.FieldName}}.View() }
|
|
|
{{end}}
|
|
|
+{{define "makeViewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return {{.MakeViewFnName}}(&v.ж.{{.FieldName}}) }
|
|
|
+{{end}}
|
|
|
{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {
|
|
|
if v.ж.{{.FieldName}} == nil {
|
|
|
return nil
|
|
|
@@ -144,6 +148,9 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
|
|
|
MapValueType string
|
|
|
MapValueView string
|
|
|
MapFn string
|
|
|
+
|
|
|
+ // MakeViewFnName is the name of the function that accepts a value and returns a readonly view of it.
|
|
|
+ MakeViewFnName string
|
|
|
}{
|
|
|
StructName: typ.Obj().Name(),
|
|
|
ViewName: typ.Origin().Obj().Name() + "View",
|
|
|
@@ -227,8 +234,18 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
|
|
|
strucT := underlying
|
|
|
args.FieldType = it.QualifiedName(fieldType)
|
|
|
if codegen.ContainsPointers(strucT) {
|
|
|
- args.FieldViewName = appendNameSuffix(args.FieldType, "View")
|
|
|
- writeTemplate("viewField")
|
|
|
+ if viewType := viewTypeForValueType(fieldType); viewType != nil {
|
|
|
+ args.FieldViewName = it.QualifiedName(viewType)
|
|
|
+ writeTemplate("viewField")
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if viewType, makeViewFn := viewTypeForContainerType(fieldType); viewType != nil {
|
|
|
+ args.FieldViewName = it.QualifiedName(viewType)
|
|
|
+ args.MakeViewFnName = it.PackagePrefix(makeViewFn.Pkg()) + makeViewFn.Name()
|
|
|
+ writeTemplate("makeViewField")
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ writeTemplate("unsupportedField")
|
|
|
continue
|
|
|
}
|
|
|
writeTemplate("valueField")
|
|
|
@@ -388,6 +405,9 @@ func appendNameSuffix(name, suffix string) string {
|
|
|
}
|
|
|
|
|
|
func viewTypeForValueType(typ types.Type) types.Type {
|
|
|
+ if ptr, ok := typ.(*types.Pointer); ok {
|
|
|
+ return viewTypeForValueType(ptr.Elem())
|
|
|
+ }
|
|
|
viewMethod := codegen.LookupMethod(typ, "View")
|
|
|
if viewMethod == nil {
|
|
|
return nil
|
|
|
@@ -399,12 +419,116 @@ func viewTypeForValueType(typ types.Type) types.Type {
|
|
|
return sig.Results().At(0).Type()
|
|
|
}
|
|
|
|
|
|
+func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) {
|
|
|
+ // The container type should be an instantiated generic type,
|
|
|
+ // with its first type parameter specifying the element type.
|
|
|
+ containerType, ok := typ.(*types.Named)
|
|
|
+ if !ok || containerType.TypeArgs().Len() == 0 {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // Look up the view type for the container type.
|
|
|
+ // It must include an additional type parameter specifying the element's view type.
|
|
|
+ // For example, Container[T] => ContainerView[T, V].
|
|
|
+ containerViewTypeName := containerType.Obj().Name() + "View"
|
|
|
+ containerViewTypeObj, ok := containerType.Obj().Pkg().Scope().Lookup(containerViewTypeName).(*types.TypeName)
|
|
|
+ if !ok {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+ containerViewGenericType, ok := containerViewTypeObj.Type().(*types.Named)
|
|
|
+ if !ok || containerViewGenericType.TypeParams().Len() != containerType.TypeArgs().Len()+1 {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // Create a list of type arguments for instantiating the container view type.
|
|
|
+ // Include all type arguments specified for the container type...
|
|
|
+ containerViewTypeArgs := make([]types.Type, containerViewGenericType.TypeParams().Len())
|
|
|
+ for i := range containerType.TypeArgs().Len() {
|
|
|
+ containerViewTypeArgs[i] = containerType.TypeArgs().At(i)
|
|
|
+ }
|
|
|
+ // ...and add the element view type.
|
|
|
+ // For that, we need to first determine the named elem type...
|
|
|
+ elemType, ok := baseType(containerType.TypeArgs().At(0)).(*types.Named)
|
|
|
+ if !ok {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+ // ...then infer the view type from it.
|
|
|
+ var elemViewType *types.Named
|
|
|
+ elemTypeName := elemType.Obj().Name()
|
|
|
+ elemViewTypeBaseName := elemType.Obj().Name() + "View"
|
|
|
+ if elemViewTypeName, ok := elemType.Obj().Pkg().Scope().Lookup(elemViewTypeBaseName).(*types.TypeName); ok {
|
|
|
+ // The elem's view type is already defined in the same package as the elem type.
|
|
|
+ elemViewType = elemViewTypeName.Type().(*types.Named)
|
|
|
+ } else if slices.Contains(typeNames, elemTypeName) {
|
|
|
+ // The elem's view type has not been generated yet, but we can define
|
|
|
+ // and use a blank type with the expected view type name.
|
|
|
+ elemViewTypeName = types.NewTypeName(0, elemType.Obj().Pkg(), elemViewTypeBaseName, nil)
|
|
|
+ elemViewType = types.NewNamed(elemViewTypeName, types.NewStruct(nil, nil), nil)
|
|
|
+ if elemTypeParams := elemType.TypeParams(); elemTypeParams != nil {
|
|
|
+ elemViewType.SetTypeParams(collectTypeParams(elemTypeParams))
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // The elem view type does not exist and won't be generated.
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+ // If elemType is an instantiated generic type, instantiate the elemViewType as well.
|
|
|
+ if elemTypeArgs := elemType.TypeArgs(); elemTypeArgs != nil {
|
|
|
+ elemViewType = must.Get(types.Instantiate(nil, elemViewType, collectTypes(elemTypeArgs), false)).(*types.Named)
|
|
|
+ }
|
|
|
+ // And finally set the elemViewType as the last type argument.
|
|
|
+ containerViewTypeArgs[len(containerViewTypeArgs)-1] = elemViewType
|
|
|
+
|
|
|
+ // Instantiate the container view type with the specified type arguments.
|
|
|
+ containerViewType := must.Get(types.Instantiate(nil, containerViewGenericType, containerViewTypeArgs, false))
|
|
|
+ // Look up a function to create a view of a container.
|
|
|
+ // It should be in the same package as the container type, named {ViewType}Of,
|
|
|
+ // and have a signature like {ViewType}Of(c *Container[T]) ContainerView[T, V].
|
|
|
+ makeContainerView, ok := containerType.Obj().Pkg().Scope().Lookup(containerViewTypeName + "Of").(*types.Func)
|
|
|
+ if !ok {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+ return containerViewType.(*types.Named), makeContainerView
|
|
|
+}
|
|
|
+
|
|
|
+func baseType(typ types.Type) types.Type {
|
|
|
+ if ptr, ok := typ.(*types.Pointer); ok {
|
|
|
+ return ptr.Elem()
|
|
|
+ }
|
|
|
+ return typ
|
|
|
+}
|
|
|
+
|
|
|
+func collectTypes(list *types.TypeList) []types.Type {
|
|
|
+ // TODO(nickkhyl): use slices.Collect in Go 1.23?
|
|
|
+ if list.Len() == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ res := make([]types.Type, list.Len())
|
|
|
+ for i := range res {
|
|
|
+ res[i] = list.At(i)
|
|
|
+ }
|
|
|
+ return res
|
|
|
+}
|
|
|
+
|
|
|
+func collectTypeParams(list *types.TypeParamList) []*types.TypeParam {
|
|
|
+ if list.Len() == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ res := make([]*types.TypeParam, list.Len())
|
|
|
+ for i := range res {
|
|
|
+ p := list.At(i)
|
|
|
+ res[i] = types.NewTypeParam(p.Obj(), p.Constraint())
|
|
|
+ }
|
|
|
+ return res
|
|
|
+}
|
|
|
+
|
|
|
var (
|
|
|
flagTypes = flag.String("type", "", "comma-separated list of types; required")
|
|
|
flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
|
|
|
flagCloneFunc = flag.Bool("clonefunc", false, "add a top-level Clone func")
|
|
|
|
|
|
flagCloneOnlyTypes = flag.String("clone-only-type", "", "comma-separated list of types (a subset of --type) that should only generate a go:generate clone line and not actual views")
|
|
|
+
|
|
|
+ typeNames []string
|
|
|
)
|
|
|
|
|
|
func main() {
|
|
|
@@ -415,7 +539,7 @@ func main() {
|
|
|
flag.Usage()
|
|
|
os.Exit(2)
|
|
|
}
|
|
|
- typeNames := strings.Split(*flagTypes, ",")
|
|
|
+ typeNames = strings.Split(*flagTypes, ",")
|
|
|
|
|
|
var flagArgs []string
|
|
|
flagArgs = append(flagArgs, fmt.Sprintf("-clonefunc=%v", *flagCloneFunc))
|