| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414 |
- // Copyright (c) Tailscale Inc & AUTHORS
- // SPDX-License-Identifier: BSD-3-Clause
- // Package codegen contains shared utilities for generating code.
- package codegen
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"io"
	"os"
	"reflect"
	"slices"
	"strings"

	"golang.org/x/tools/go/packages"
	"golang.org/x/tools/imports"
	"tailscale.com/util/mak"
)
var (
	// flagCopyright is registered on the default flag set as -copyright
	// (default true). When set, writeHeader prefixes generated files with the
	// Tailscale copyright and license lines.
	flagCopyright = flag.Bool("copyright", true, "add Tailscale copyright to generated file headers")
)
- // LoadTypes returns all named types in pkgName, keyed by their type name.
- func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]types.Type, error) {
- cfg := &packages.Config{
- Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
- Tests: buildTags == "test",
- }
- if buildTags != "" && !cfg.Tests {
- cfg.BuildFlags = []string{"-tags=" + buildTags}
- }
- pkgs, err := packages.Load(cfg, pkgName)
- if err != nil {
- return nil, nil, err
- }
- if cfg.Tests {
- pkgs = testPackages(pkgs)
- }
- if len(pkgs) != 1 {
- return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs))
- }
- pkg := pkgs[0]
- return pkg, namedTypes(pkg), nil
- }
- func testPackages(pkgs []*packages.Package) []*packages.Package {
- var testPackages []*packages.Package
- for _, pkg := range pkgs {
- testPackageID := fmt.Sprintf("%[1]s [%[1]s.test]", pkg.PkgPath)
- if pkg.ID == testPackageID {
- testPackages = append(testPackages, pkg)
- }
- }
- return testPackages
- }
// HasNoClone reports whether the struct tag structTag has a "codegen" key
// whose comma-separated option list contains "noclone", e.g.
// `codegen:"noclone"`.
func HasNoClone(structTag string) bool {
	rest := reflect.StructTag(structTag).Get("codegen")
	for rest != "" {
		var opt string
		opt, rest, _ = strings.Cut(rest, ",")
		if opt == "noclone" {
			return true
		}
	}
	return false
}
// copyrightHeader is the license preamble that writeHeader emits when the
// -copyright flag is set. The trailing blank line keeps it from running into
// the following comment and package clause, where it would otherwise become
// part of the package doc comment.
const copyrightHeader = `// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

`

// genAndPackageHeader is a fmt format string taking the generating tool's
// name (%v) and the package name (%s). The blank line separates the
// "Code generated ... DO NOT EDIT." marker from the package clause so the
// marker is not treated as the package's doc comment.
const genAndPackageHeader = `// Code generated by %v; DO NOT EDIT.

package %s
`
- func NewImportTracker(thisPkg *types.Package) *ImportTracker {
- return &ImportTracker{
- thisPkg: thisPkg,
- }
- }
// namePkgPath identifies one import: the import name (empty means the
// package's default name) followed by the package path.
type namePkgPath struct {
	name, pkgPath string
}

// ImportTracker provides a mechanism to track and build import paths.
// It accumulates the set of packages referenced while rendering types so
// that Write can later emit a matching import block.
type ImportTracker struct {
	// thisPkg is the package the generated code belongs to; references to
	// it are left unqualified.
	thisPkg *types.Package
	// packages is the set of recorded imports. It starts out nil and is
	// populated by Import.
	packages map[namePkgPath]bool
}
- // Import imports pkgPath under an optional import name.
- func (it *ImportTracker) Import(name, pkgPath string) {
- if pkgPath != "" && !it.packages[namePkgPath{name, pkgPath}] {
- mak.Set(&it.packages, namePkgPath{name, pkgPath}, true)
- }
- }
- // Has reports whether the specified package path has been imported
- // under the particular import name.
- func (it *ImportTracker) Has(name, pkgPath string) bool {
- return it.packages[namePkgPath{name, pkgPath}]
- }
- func (it *ImportTracker) qualifier(pkg *types.Package) string {
- if it.thisPkg == pkg {
- return ""
- }
- it.Import("", pkg.Path())
- // TODO(maisem): handle conflicts?
- return pkg.Name()
- }
- // QualifiedName returns the string representation of t in the package.
- func (it *ImportTracker) QualifiedName(t types.Type) string {
- return types.TypeString(t, it.qualifier)
- }
- // PackagePrefix returns the prefix to be used when referencing named objects from pkg.
- func (it *ImportTracker) PackagePrefix(pkg *types.Package) string {
- if s := it.qualifier(pkg); s != "" {
- return s + "."
- }
- return ""
- }
- // Write prints all the tracked imports in a single import block to w.
- func (it *ImportTracker) Write(w io.Writer) {
- fmt.Fprintf(w, "import (\n")
- for s := range it.packages {
- if s.name == "" {
- fmt.Fprintf(w, "\t%q\n", s.pkgPath)
- } else {
- fmt.Fprintf(w, "\t%s %q\n", s.name, s.pkgPath)
- }
- }
- fmt.Fprintf(w, ")\n\n")
- }
- func writeHeader(w io.Writer, tool, pkg string) {
- if *flagCopyright {
- fmt.Fprint(w, copyrightHeader)
- }
- fmt.Fprintf(w, genAndPackageHeader, tool, pkg)
- }
- // WritePackageFile adds a file with the provided imports and contents to package.
- // The tool param is used to identify the tool that generated package file.
- func WritePackageFile(tool string, pkg *packages.Package, path string, it *ImportTracker, contents *bytes.Buffer) error {
- buf := new(bytes.Buffer)
- writeHeader(buf, tool, pkg.Name)
- it.Write(buf)
- if _, err := buf.Write(contents.Bytes()); err != nil {
- return err
- }
- return writeFormatted(buf.Bytes(), path)
- }
- // writeFormatted writes code to path.
- // It runs gofmt on it before writing;
- // if gofmt fails, it writes code unchanged.
- // Errors can include I/O errors and gofmt errors.
- //
- // The advantage of always writing code to path,
- // even if gofmt fails, is that it makes debugging easier.
- // The code can be long, but you need it in order to debug.
- // It is nicer to work with it in a file than a terminal.
- // It is also easier to interpret gofmt errors
- // with an editor providing file and line numbers.
- func writeFormatted(code []byte, path string) error {
- out, fmterr := imports.Process(path, code, &imports.Options{
- Comments: true,
- TabIndent: true,
- TabWidth: 8,
- FormatOnly: true, // fancy gofmt only
- })
- if fmterr != nil {
- out = code
- }
- ioerr := os.WriteFile(path, out, 0644)
- // Prefer I/O errors. They're usually easier to fix,
- // and until they're fixed you can't do much else.
- if ioerr != nil {
- return ioerr
- }
- if fmterr != nil {
- return fmt.Errorf("%s:%v", path, fmterr)
- }
- return nil
- }
- // namedTypes returns all named types in pkg, keyed by their type name.
- func namedTypes(pkg *packages.Package) map[string]types.Type {
- nt := make(map[string]types.Type)
- for _, file := range pkg.Syntax {
- for _, d := range file.Decls {
- decl, ok := d.(*ast.GenDecl)
- if !ok || decl.Tok != token.TYPE {
- continue
- }
- for _, s := range decl.Specs {
- spec, ok := s.(*ast.TypeSpec)
- if !ok {
- continue
- }
- typeNameObj, ok := pkg.TypesInfo.Defs[spec.Name]
- if !ok {
- continue
- }
- switch typ := typeNameObj.Type(); typ.(type) {
- case *types.Alias, *types.Named:
- nt[spec.Name.Name] = typ
- }
- }
- }
- }
- return nt
- }
- // AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
- // thisPkg is the package containing t.
- // tname is the named type corresponding to t.
- // ctx is a single-word context for this assertion, such as "Clone".
- // If non-nil, AssertStructUnchanged will add elements to imports
- // for each package path that the caller must import for the returned code to compile.
- func AssertStructUnchanged(t *types.Struct, tname string, params *types.TypeParamList, ctx string, it *ImportTracker) []byte {
- buf := new(bytes.Buffer)
- w := func(format string, args ...any) {
- fmt.Fprintf(buf, format+"\n", args...)
- }
- w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
- hasTypeParams := params != nil && params.Len() > 0
- if hasTypeParams {
- constraints, identifiers := FormatTypeParams(params, it)
- w("func _%s%sNeedsRegeneration%s (%s%s) {", tname, ctx, constraints, tname, identifiers)
- w("_%s%sNeedsRegeneration(struct {", tname, ctx)
- } else {
- w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
- }
- for i := range t.NumFields() {
- st := t.Field(i)
- fname := st.Name()
- ft := t.Field(i).Type()
- if IsInvalid(ft) {
- continue
- }
- qname := it.QualifiedName(ft)
- var tag string
- if hasTypeParams {
- tag = t.Tag(i)
- if tag != "" {
- tag = "`" + tag + "`"
- }
- }
- if st.Anonymous() {
- w("\t%s %s", qname, tag)
- } else {
- w("\t%s %s %s", fname, qname, tag)
- }
- }
- if hasTypeParams {
- w("}{})\n}")
- } else {
- w("}{})")
- }
- return buf.Bytes()
- }
- // IsInvalid reports whether the provided type is invalid. It is used to allow
- // codegeneration to run even when the target files have build errors or are
- // missing views.
- func IsInvalid(t types.Type) bool {
- return t.String() == "invalid type"
- }
- // ContainsPointers reports whether typ contains any pointers,
- // either explicitly or implicitly.
- // It has special handling for some types that contain pointers
- // that we know are free from memory aliasing/mutation concerns.
- func ContainsPointers(typ types.Type) bool {
- s := typ.String()
- switch s {
- case "time.Time":
- // time.Time contains a pointer that does not need cloning.
- return false
- case "inet.af/netip.Addr":
- return false
- }
- if strings.HasPrefix(s, "unique.Handle[") {
- // unique.Handle contains a pointer that does not need cloning.
- return false
- }
- switch ft := typ.Underlying().(type) {
- case *types.Array:
- return ContainsPointers(ft.Elem())
- case *types.Basic:
- if ft.Kind() == types.UnsafePointer {
- return true
- }
- case *types.Chan:
- return true
- case *types.Interface:
- if ft.Empty() || ft.IsMethodSet() {
- return true
- }
- for i := 0; i < ft.NumEmbeddeds(); i++ {
- if ContainsPointers(ft.EmbeddedType(i)) {
- return true
- }
- }
- case *types.Map:
- return true
- case *types.Pointer:
- return true
- case *types.Slice:
- return true
- case *types.Struct:
- for i := range ft.NumFields() {
- if ContainsPointers(ft.Field(i).Type()) {
- return true
- }
- }
- case *types.Union:
- for i := range ft.Len() {
- if ContainsPointers(ft.Term(i).Type()) {
- return true
- }
- }
- }
- return false
- }
- // IsViewType reports whether the provided typ is a View.
- func IsViewType(typ types.Type) bool {
- t, ok := typ.Underlying().(*types.Struct)
- if !ok {
- return false
- }
- if t.NumFields() != 1 {
- return false
- }
- return t.Field(0).Name() == "ж"
- }
- // FormatTypeParams formats the specified params and returns two strings:
- // - constraints are comma-separated type parameters and their constraints in square brackets (e.g. [T any, V constraints.Integer])
- // - names are comma-separated type parameter names in square brackets (e.g. [T, V])
- //
- // If params is nil or empty, both return values are empty strings.
- func FormatTypeParams(params *types.TypeParamList, it *ImportTracker) (constraints, names string) {
- if params == nil || params.Len() == 0 {
- return "", ""
- }
- var constraintList, nameList []string
- for i := range params.Len() {
- param := params.At(i)
- name := param.Obj().Name()
- constraint := it.QualifiedName(param.Constraint())
- nameList = append(nameList, name)
- constraintList = append(constraintList, name+" "+constraint)
- }
- constraints = "[" + strings.Join(constraintList, ", ") + "]"
- names = "[" + strings.Join(nameList, ", ") + "]"
- return constraints, names
- }
- // LookupMethod returns the method with the specified name in t, or nil if the method does not exist.
- func LookupMethod(t types.Type, name string) *types.Func {
- switch t := t.(type) {
- case *types.Alias:
- return LookupMethod(t.Rhs(), name)
- case *types.TypeParam:
- return LookupMethod(t.Constraint(), name)
- case *types.Pointer:
- return LookupMethod(t.Elem(), name)
- case *types.Named:
- switch u := t.Underlying().(type) {
- case *types.Interface:
- return LookupMethod(u, name)
- default:
- for i := 0; i < t.NumMethods(); i++ {
- if method := t.Method(i); method.Name() == name {
- return method
- }
- }
- }
- case *types.Interface:
- for i := 0; i < t.NumMethods(); i++ {
- if method := t.Method(i); method.Name() == name {
- return method
- }
- }
- }
- return nil
- }
- // NamedTypeOf is like t.(*types.Named), but also works with type aliases.
- func NamedTypeOf(t types.Type) (named *types.Named, ok bool) {
- if a, ok := t.(*types.Alias); ok {
- return NamedTypeOf(types.Unalias(a))
- }
- named, ok = t.(*types.Named)
- return
- }
|