Browse Source

cmd/viewer, types/views, util/codegen: add viewer support for custom container types

This adds support for container-like types such as Container[T] that
don't explicitly specify a view type for T. Instead, a package implementing
a container type should also implement and export a ContainerView[T, V] type
and a ContainerViewOf(*Container[T]) ContainerView[T, V] function, which
returns a view for the specified container, inferring the element view type V
from the element type T.

Updates #12736

Signed-off-by: Nick Khyl <[email protected]>
Nick Khyl 1 year ago
parent
commit
20562a4fb9

+ 49 - 1
cmd/viewer/tests/tests.go

@@ -9,10 +9,11 @@ import (
 	"net/netip"
 
 	"golang.org/x/exp/constraints"
+	"tailscale.com/types/ptr"
 	"tailscale.com/types/views"
 )
 
-//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct --clone-only-type=OnlyGetClone
+//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers --clone-only-type=OnlyGetClone
 
 type StructWithoutPtrs struct {
 	Int int
@@ -114,3 +115,50 @@ type GenericCloneableStruct[T views.ViewCloner[T, V], V views.StructView[T]] str
 	PtrValueMap map[string]*T
 	SliceMap    map[string][]T
 }
+
+// Container is a pre-defined container type, such as a collection, an optional
+// value or a generic wrapper.
+type Container[T any] struct {
+	Item T
+}
+
+func (c *Container[T]) Clone() *Container[T] {
+	if c == nil {
+		return nil
+	}
+	if cloner, ok := any(c.Item).(views.Cloner[T]); ok {
+		return &Container[T]{cloner.Clone()}
+	}
+	if !views.ContainsPointers[T]() {
+		return ptr.To(*c)
+	}
+	panic(fmt.Errorf("%T contains pointers, but is not cloneable", c.Item))
+}
+
+// ContainerView is a pre-defined readonly view of a Container[T].
+type ContainerView[T views.ViewCloner[T, V], V views.StructView[T]] struct {
+	// ж is the underlying mutable value, named with a hard-to-type
+	// character that looks pointy like a pointer.
+	// It is named distinctively to make you think of how dangerous it is to escape
+	// to callers. You must not let callers be able to mutate it.
+	ж *Container[T]
+}
+
+func (cv ContainerView[T, V]) Item() V {
+	return cv.ж.Item.View()
+}
+
+func ContainerViewOf[T views.ViewCloner[T, V], V views.StructView[T]](c *Container[T]) ContainerView[T, V] {
+	return ContainerView[T, V]{c}
+}
+
+type GenericBasicStruct[T BasicType] struct {
+	Value T
+}
+
+type StructWithContainers struct {
+	IntContainer             Container[int]
+	CloneableContainer       Container[*StructWithPtrs]
+	BasicGenericContainer    Container[GenericBasicStruct[int]]
+	ClonableGenericContainer Container[*GenericNoPtrsStruct[int]]
+}

+ 21 - 0
cmd/viewer/tests/tests_clone.go

@@ -416,3 +416,24 @@ func _GenericCloneableStructCloneNeedsRegeneration[T views.ViewCloner[T, V], V v
 		SliceMap    map[string][]T
 	}{})
 }
+
+// Clone makes a deep copy of StructWithContainers.
+// The result aliases no memory with the original.
+func (src *StructWithContainers) Clone() *StructWithContainers {
+	if src == nil {
+		return nil
+	}
+	dst := new(StructWithContainers)
+	*dst = *src
+	dst.CloneableContainer = *src.CloneableContainer.Clone()
+	dst.ClonableGenericContainer = *src.ClonableGenericContainer.Clone()
+	return dst
+}
+
+// A compilation failure here means this code must be regenerated, with the command at the top of this file.
+var _StructWithContainersCloneNeedsRegeneration = StructWithContainers(struct {
+	IntContainer             Container[int]
+	CloneableContainer       Container[*StructWithPtrs]
+	BasicGenericContainer    Container[GenericBasicStruct[int]]
+	ClonableGenericContainer Container[*GenericNoPtrsStruct[int]]
+}{})

+ 65 - 1
cmd/viewer/tests/tests_view.go

@@ -14,7 +14,7 @@ import (
 	"tailscale.com/types/views"
 )
 
-//go:generate go run tailscale.com/cmd/cloner  -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct
+//go:generate go run tailscale.com/cmd/cloner  -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers
 
 // View returns a readonly view of StructWithPtrs.
 func (p *StructWithPtrs) View() StructWithPtrsView {
@@ -604,3 +604,67 @@ func _GenericCloneableStructViewNeedsRegeneration[T views.ViewCloner[T, V], V vi
 		SliceMap    map[string][]T
 	}{})
 }
+
+// View returns a readonly view of StructWithContainers.
+func (p *StructWithContainers) View() StructWithContainersView {
+	return StructWithContainersView{ж: p}
+}
+
+// StructWithContainersView provides a read-only view over StructWithContainers.
+//
+// Its methods should only be called if `Valid()` returns true.
+type StructWithContainersView struct {
+	// ж is the underlying mutable value, named with a hard-to-type
+	// character that looks pointy like a pointer.
+	// It is named distinctively to make you think of how dangerous it is to escape
+	// to callers. You must not let callers be able to mutate it.
+	ж *StructWithContainers
+}
+
+// Valid reports whether underlying value is non-nil.
+func (v StructWithContainersView) Valid() bool { return v.ж != nil }
+
+// AsStruct returns a clone of the underlying value which aliases no memory with
+// the original.
+func (v StructWithContainersView) AsStruct() *StructWithContainers {
+	if v.ж == nil {
+		return nil
+	}
+	return v.ж.Clone()
+}
+
+func (v StructWithContainersView) MarshalJSON() ([]byte, error) { return json.Marshal(v.ж) }
+
+func (v *StructWithContainersView) UnmarshalJSON(b []byte) error {
+	if v.ж != nil {
+		return errors.New("already initialized")
+	}
+	if len(b) == 0 {
+		return nil
+	}
+	var x StructWithContainers
+	if err := json.Unmarshal(b, &x); err != nil {
+		return err
+	}
+	v.ж = &x
+	return nil
+}
+
+func (v StructWithContainersView) IntContainer() Container[int] { return v.ж.IntContainer }
+func (v StructWithContainersView) CloneableContainer() ContainerView[*StructWithPtrs, StructWithPtrsView] {
+	return ContainerViewOf(&v.ж.CloneableContainer)
+}
+func (v StructWithContainersView) BasicGenericContainer() Container[GenericBasicStruct[int]] {
+	return v.ж.BasicGenericContainer
+}
+func (v StructWithContainersView) ClonableGenericContainer() ContainerView[*GenericNoPtrsStruct[int], GenericNoPtrsStructView[int]] {
+	return ContainerViewOf(&v.ж.ClonableGenericContainer)
+}
+
+// A compilation failure here means this code must be regenerated, with the command at the top of this file.
+var _StructWithContainersViewNeedsRegeneration = StructWithContainers(struct {
+	IntContainer             Container[int]
+	CloneableContainer       Container[*StructWithPtrs]
+	BasicGenericContainer    Container[GenericBasicStruct[int]]
+	ClonableGenericContainer Container[*GenericNoPtrsStruct[int]]
+}{})

+ 127 - 3
cmd/viewer/viewer.go

@@ -13,9 +13,11 @@ import (
 	"html/template"
 	"log"
 	"os"
+	"slices"
 	"strings"
 
 	"tailscale.com/util/codegen"
+	"tailscale.com/util/must"
 )
 
 const viewTemplateStr = `{{define "common"}}
@@ -75,6 +77,8 @@ func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error {
 {{end}}
 {{define "viewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return v.ж.{{.FieldName}}.View() }
 {{end}}
+{{define "makeViewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return {{.MakeViewFnName}}(&v.ж.{{.FieldName}}) }
+{{end}}
 {{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {
 	if v.ж.{{.FieldName}} == nil {
 		return nil
@@ -144,6 +148,9 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
 		MapValueType string
 		MapValueView string
 		MapFn        string
+
+		// MakeViewFnName is the name of the function that accepts a value and returns a readonly view of it.
+		MakeViewFnName string
 	}{
 		StructName: typ.Obj().Name(),
 		ViewName:   typ.Origin().Obj().Name() + "View",
@@ -227,8 +234,18 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi
 			strucT := underlying
 			args.FieldType = it.QualifiedName(fieldType)
 			if codegen.ContainsPointers(strucT) {
-				args.FieldViewName = appendNameSuffix(args.FieldType, "View")
-				writeTemplate("viewField")
+				if viewType := viewTypeForValueType(fieldType); viewType != nil {
+					args.FieldViewName = it.QualifiedName(viewType)
+					writeTemplate("viewField")
+					continue
+				}
+				if viewType, makeViewFn := viewTypeForContainerType(fieldType); viewType != nil {
+					args.FieldViewName = it.QualifiedName(viewType)
+					args.MakeViewFnName = it.PackagePrefix(makeViewFn.Pkg()) + makeViewFn.Name()
+					writeTemplate("makeViewField")
+					continue
+				}
+				writeTemplate("unsupportedField")
 				continue
 			}
 			writeTemplate("valueField")
@@ -388,6 +405,9 @@ func appendNameSuffix(name, suffix string) string {
 }
 
 func viewTypeForValueType(typ types.Type) types.Type {
+	if ptr, ok := typ.(*types.Pointer); ok {
+		return viewTypeForValueType(ptr.Elem())
+	}
 	viewMethod := codegen.LookupMethod(typ, "View")
 	if viewMethod == nil {
 		return nil
@@ -399,12 +419,116 @@ func viewTypeForValueType(typ types.Type) types.Type {
 	return sig.Results().At(0).Type()
 }
 
+func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) {
+	// The container type should be an instantiated generic type,
+	// with its first type parameter specifying the element type.
+	containerType, ok := typ.(*types.Named)
+	if !ok || containerType.TypeArgs().Len() == 0 {
+		return nil, nil
+	}
+
+	// Look up the view type for the container type.
+	// It must include an additional type parameter specifying the element's view type.
+	// For example, Container[T] => ContainerView[T, V].
+	containerViewTypeName := containerType.Obj().Name() + "View"
+	containerViewTypeObj, ok := containerType.Obj().Pkg().Scope().Lookup(containerViewTypeName).(*types.TypeName)
+	if !ok {
+		return nil, nil
+	}
+	containerViewGenericType, ok := containerViewTypeObj.Type().(*types.Named)
+	if !ok || containerViewGenericType.TypeParams().Len() != containerType.TypeArgs().Len()+1 {
+		return nil, nil
+	}
+
+	// Create a list of type arguments for instantiating the container view type.
+	// Include all type arguments specified for the container type...
+	containerViewTypeArgs := make([]types.Type, containerViewGenericType.TypeParams().Len())
+	for i := range containerType.TypeArgs().Len() {
+		containerViewTypeArgs[i] = containerType.TypeArgs().At(i)
+	}
+	// ...and add the element view type.
+	// For that, we need to first determine the named elem type...
+	elemType, ok := baseType(containerType.TypeArgs().At(0)).(*types.Named)
+	if !ok {
+		return nil, nil
+	}
+	// ...then infer the view type from it.
+	var elemViewType *types.Named
+	elemTypeName := elemType.Obj().Name()
+	elemViewTypeBaseName := elemType.Obj().Name() + "View"
+	if elemViewTypeName, ok := elemType.Obj().Pkg().Scope().Lookup(elemViewTypeBaseName).(*types.TypeName); ok {
+		// The elem's view type is already defined in the same package as the elem type.
+		elemViewType = elemViewTypeName.Type().(*types.Named)
+	} else if slices.Contains(typeNames, elemTypeName) {
+		// The elem's view type has not been generated yet, but we can define
+		// and use a blank type with the expected view type name.
+		elemViewTypeName = types.NewTypeName(0, elemType.Obj().Pkg(), elemViewTypeBaseName, nil)
+		elemViewType = types.NewNamed(elemViewTypeName, types.NewStruct(nil, nil), nil)
+		if elemTypeParams := elemType.TypeParams(); elemTypeParams != nil {
+			elemViewType.SetTypeParams(collectTypeParams(elemTypeParams))
+		}
+	} else {
+		// The elem view type does not exist and won't be generated.
+		return nil, nil
+	}
+	// If elemType is an instantiated generic type, instantiate the elemViewType as well.
+	if elemTypeArgs := elemType.TypeArgs(); elemTypeArgs != nil {
+		elemViewType = must.Get(types.Instantiate(nil, elemViewType, collectTypes(elemTypeArgs), false)).(*types.Named)
+	}
+	// And finally set the elemViewType as the last type argument.
+	containerViewTypeArgs[len(containerViewTypeArgs)-1] = elemViewType
+
+	// Instantiate the container view type with the specified type arguments.
+	containerViewType := must.Get(types.Instantiate(nil, containerViewGenericType, containerViewTypeArgs, false))
+	// Look up a function to create a view of a container.
+	// It should be in the same package as the container type, named {ViewType}Of,
+	// and have a signature like {ViewType}Of(c *Container[T]) ContainerView[T, V].
+	makeContainerView, ok := containerType.Obj().Pkg().Scope().Lookup(containerViewTypeName + "Of").(*types.Func)
+	if !ok {
+		return nil, nil
+	}
+	return containerViewType.(*types.Named), makeContainerView
+}
+
+func baseType(typ types.Type) types.Type {
+	if ptr, ok := typ.(*types.Pointer); ok {
+		return ptr.Elem()
+	}
+	return typ
+}
+
+func collectTypes(list *types.TypeList) []types.Type {
+	// TODO(nickkhyl): use slices.Collect in Go 1.23?
+	if list.Len() == 0 {
+		return nil
+	}
+	res := make([]types.Type, list.Len())
+	for i := range res {
+		res[i] = list.At(i)
+	}
+	return res
+}
+
+func collectTypeParams(list *types.TypeParamList) []*types.TypeParam {
+	if list.Len() == 0 {
+		return nil
+	}
+	res := make([]*types.TypeParam, list.Len())
+	for i := range res {
+		p := list.At(i)
+		res[i] = types.NewTypeParam(p.Obj(), p.Constraint())
+	}
+	return res
+}
+
 var (
 	flagTypes     = flag.String("type", "", "comma-separated list of types; required")
 	flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
 	flagCloneFunc = flag.Bool("clonefunc", false, "add a top-level Clone func")
 
 	flagCloneOnlyTypes = flag.String("clone-only-type", "", "comma-separated list of types (a subset of --type) that should only generate a go:generate clone line and not actual views")
+
+	typeNames []string
 )
 
 func main() {
@@ -415,7 +539,7 @@ func main() {
 		flag.Usage()
 		os.Exit(2)
 	}
-	typeNames := strings.Split(*flagTypes, ",")
+	typeNames = strings.Split(*flagTypes, ",")
 
 	var flagArgs []string
 	flagArgs = append(flagArgs, fmt.Sprintf("-clonefunc=%v", *flagCloneFunc))

+ 51 - 0
types/views/views.go

@@ -11,6 +11,7 @@ import (
 	"errors"
 	"fmt"
 	"maps"
+	"reflect"
 	"slices"
 
 	"go4.org/mem"
@@ -111,6 +112,13 @@ type StructView[T any] interface {
 	AsStruct() T
 }
 
+// Cloner is any type that has a Clone function returning a deep-clone of the receiver.
+type Cloner[T any] interface {
+	// Clone returns a deep-clone of the receiver.
+	// It returns nil, when the receiver is nil.
+	Clone() T
+}
+
 // ViewCloner is any type that has had View and Clone funcs generated using
 // tailscale.com/cmd/viewer.
 type ViewCloner[T any, V StructView[T]] interface {
@@ -555,3 +563,46 @@ func (m MapFn[K, T, V]) Range(f MapRangeFn[K, V]) {
 		}
 	}
 }
+
+// ContainsPointers reports whether T contains any pointers,
+// either explicitly or implicitly.
+// It has special handling for some types that contain pointers
+// that we know are free from memory aliasing/mutation concerns.
+func ContainsPointers[T any]() bool {
+	return containsPointers(reflect.TypeFor[T]())
+}
+
+func containsPointers(typ reflect.Type) bool {
+	switch typ.Kind() {
+	case reflect.Pointer, reflect.UnsafePointer:
+		return true
+	case reflect.Chan, reflect.Map, reflect.Slice:
+		return true
+	case reflect.Array:
+		return containsPointers(typ.Elem())
+	case reflect.Interface, reflect.Func:
+		return true // err on the safe side.
+	case reflect.Struct:
+		if isWellKnownImmutableStruct(typ) {
+			return false
+		}
+		for i := range typ.NumField() {
+			if containsPointers(typ.Field(i).Type) {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+func isWellKnownImmutableStruct(typ reflect.Type) bool {
+	switch typ.String() {
+	case "time.Time":
+		// time.Time contains a pointer that does not need copying
+		return true
+	case "netip.Addr", "netip.Prefix", "netip.AddrPort":
+		return true
+	default:
+		return false
+	}
+}

+ 223 - 0
types/views/views_test.go

@@ -10,6 +10,7 @@ import (
 	"reflect"
 	"strings"
 	"testing"
+	"unsafe"
 
 	qt "github.com/frankban/quicktest"
 )
@@ -22,6 +23,16 @@ type viewStruct struct {
 	StringsPtr *Slice[string]       `json:",omitempty"`
 }
 
+type noPtrStruct struct {
+	Int int
+	Str string
+}
+
+type withPtrStruct struct {
+	Int    int
+	StrPtr *string
+}
+
 func BenchmarkSliceIteration(b *testing.B) {
 	var data []viewStruct
 	for i := range 10000 {
@@ -189,3 +200,215 @@ func TestSliceMapKey(t *testing.T) {
 		}
 	}
 }
+
+func TestContainsPointers(t *testing.T) {
+	tests := []struct {
+		name     string
+		typ      reflect.Type
+		wantPtrs bool
+	}{
+		{
+			name:     "bool",
+			typ:      reflect.TypeFor[bool](),
+			wantPtrs: false,
+		},
+		{
+			name:     "int",
+			typ:      reflect.TypeFor[int](),
+			wantPtrs: false,
+		},
+		{
+			name:     "int8",
+			typ:      reflect.TypeFor[int8](),
+			wantPtrs: false,
+		},
+		{
+			name:     "int16",
+			typ:      reflect.TypeFor[int16](),
+			wantPtrs: false,
+		},
+		{
+			name:     "int32",
+			typ:      reflect.TypeFor[int32](),
+			wantPtrs: false,
+		},
+		{
+			name:     "int64",
+			typ:      reflect.TypeFor[int64](),
+			wantPtrs: false,
+		},
+		{
+			name:     "uint",
+			typ:      reflect.TypeFor[uint](),
+			wantPtrs: false,
+		},
+		{
+			name:     "uint8",
+			typ:      reflect.TypeFor[uint8](),
+			wantPtrs: false,
+		},
+		{
+			name:     "uint16",
+			typ:      reflect.TypeFor[uint16](),
+			wantPtrs: false,
+		},
+		{
+			name:     "uint32",
+			typ:      reflect.TypeFor[uint32](),
+			wantPtrs: false,
+		},
+		{
+			name:     "uint64",
+			typ:      reflect.TypeFor[uint64](),
+			wantPtrs: false,
+		},
+		{
+			name:     "uintptr",
+			typ:      reflect.TypeFor[uintptr](),
+			wantPtrs: false,
+		},
+		{
+			name:     "string",
+			typ:      reflect.TypeFor[string](),
+			wantPtrs: false,
+		},
+		{
+			name:     "float32",
+			typ:      reflect.TypeFor[float32](),
+			wantPtrs: false,
+		},
+		{
+			name:     "float64",
+			typ:      reflect.TypeFor[float64](),
+			wantPtrs: false,
+		},
+		{
+			name:     "complex64",
+			typ:      reflect.TypeFor[complex64](),
+			wantPtrs: false,
+		},
+		{
+			name:     "complex128",
+			typ:      reflect.TypeFor[complex128](),
+			wantPtrs: false,
+		},
+		{
+			name:     "netip-Addr",
+			typ:      reflect.TypeFor[netip.Addr](),
+			wantPtrs: false,
+		},
+		{
+			name:     "netip-Prefix",
+			typ:      reflect.TypeFor[netip.Prefix](),
+			wantPtrs: false,
+		},
+		{
+			name:     "netip-AddrPort",
+			typ:      reflect.TypeFor[netip.AddrPort](),
+			wantPtrs: false,
+		},
+		{
+			name:     "bool-ptr",
+			typ:      reflect.TypeFor[*bool](),
+			wantPtrs: true,
+		},
+		{
+			name:     "string-ptr",
+			typ:      reflect.TypeFor[*string](),
+			wantPtrs: true,
+		},
+		{
+			name:     "netip-Addr-ptr",
+			typ:      reflect.TypeFor[*netip.Addr](),
+			wantPtrs: true,
+		},
+		{
+			name:     "unsafe-ptr",
+			typ:      reflect.TypeFor[unsafe.Pointer](),
+			wantPtrs: true,
+		},
+		{
+			name:     "no-ptr-struct",
+			typ:      reflect.TypeFor[noPtrStruct](),
+			wantPtrs: false,
+		},
+		{
+			name:     "ptr-struct",
+			typ:      reflect.TypeFor[withPtrStruct](),
+			wantPtrs: true,
+		},
+		{
+			name:     "string-array",
+			typ:      reflect.TypeFor[[5]string](),
+			wantPtrs: false,
+		},
+		{
+			name:     "int-ptr-array",
+			typ:      reflect.TypeFor[[5]*int](),
+			wantPtrs: true,
+		},
+		{
+			name:     "no-ptr-struct-array",
+			typ:      reflect.TypeFor[[5]noPtrStruct](),
+			wantPtrs: false,
+		},
+		{
+			name:     "with-ptr-struct-array",
+			typ:      reflect.TypeFor[[5]withPtrStruct](),
+			wantPtrs: true,
+		},
+		{
+			name:     "string-slice",
+			typ:      reflect.TypeFor[[]string](),
+			wantPtrs: true,
+		},
+		{
+			name:     "int-ptr-slice",
+			typ:      reflect.TypeFor[[]int](),
+			wantPtrs: true,
+		},
+		{
+			name:     "no-ptr-struct-slice",
+			typ:      reflect.TypeFor[[]noPtrStruct](),
+			wantPtrs: true,
+		},
+		{
+			name:     "string-map",
+			typ:      reflect.TypeFor[map[string]string](),
+			wantPtrs: true,
+		},
+		{
+			name:     "int-map",
+			typ:      reflect.TypeFor[map[int]int](),
+			wantPtrs: true,
+		},
+		{
+			name:     "no-ptr-struct-map",
+			typ:      reflect.TypeFor[map[string]noPtrStruct](),
+			wantPtrs: true,
+		},
+		{
+			name:     "chan",
+			typ:      reflect.TypeFor[chan int](),
+			wantPtrs: true,
+		},
+		{
+			name:     "func",
+			typ:      reflect.TypeFor[func()](),
+			wantPtrs: true,
+		},
+		{
+			name:     "interface",
+			typ:      reflect.TypeFor[any](),
+			wantPtrs: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if gotPtrs := containsPointers(tt.typ); gotPtrs != tt.wantPtrs {
+				t.Errorf("got %v; want %v", gotPtrs, tt.wantPtrs)
+			}
+		})
+	}
+}

+ 8 - 0
util/codegen/codegen.go

@@ -111,6 +111,14 @@ func (it *ImportTracker) QualifiedName(t types.Type) string {
 	return types.TypeString(t, it.qualifier)
 }
 
+// PackagePrefix returns the prefix to be used when referencing named objects from pkg.
+func (it *ImportTracker) PackagePrefix(pkg *types.Package) string {
+	if s := it.qualifier(pkg); s != "" {
+		return s + "."
+	}
+	return ""
+}
+
 // Write prints all the tracked imports in a single import block to w.
 func (it *ImportTracker) Write(w io.Writer) {
 	fmt.Fprintf(w, "import (\n")