| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- // Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- // Package codegen contains shared utilities for generating code.
- package codegen
- import (
- "bytes"
- "fmt"
- "go/ast"
- "go/format"
- "go/token"
- "go/types"
- "os"
- "golang.org/x/tools/go/packages"
- )
// WriteFormatted writes code to path after running it through gofmt
// (go/format.Source).
//
// If formatting fails, the unformatted code is written instead, so the
// file on disk always reflects what was generated. Having the raw output
// in a file makes it much easier to debug than scrolling a terminal, and
// the file/line positions in gofmt's message can be followed in an editor.
//
// The returned error is, in order of preference:
//   - the I/O error from writing the file, if any;
//   - otherwise the gofmt error, prefixed with path and wrapped so that
//     callers can inspect it with errors.Is, errors.As or errors.Unwrap;
//   - otherwise nil.
func WriteFormatted(code []byte, path string) error {
	out, fmterr := format.Source(code)
	if fmterr != nil {
		// Fall back to the raw input so there is something to inspect.
		out = code
	}
	ioerr := os.WriteFile(path, out, 0644)
	// Prefer I/O errors. They're usually easier to fix,
	// and until they're fixed you can't do much else.
	if ioerr != nil {
		return ioerr
	}
	if fmterr != nil {
		// No space after the colon: gofmt errors start with "line:col:",
		// so the result reads as "path:line:col: message".
		return fmt.Errorf("%s:%w", path, fmterr)
	}
	return nil
}
- // NamedTypes returns all named types in pkg, keyed by their type name.
- func NamedTypes(pkg *packages.Package) map[string]*types.Named {
- nt := make(map[string]*types.Named)
- for _, file := range pkg.Syntax {
- for _, d := range file.Decls {
- decl, ok := d.(*ast.GenDecl)
- if !ok || decl.Tok != token.TYPE {
- continue
- }
- for _, s := range decl.Specs {
- spec, ok := s.(*ast.TypeSpec)
- if !ok {
- continue
- }
- typeNameObj := pkg.TypesInfo.Defs[spec.Name]
- typ, ok := typeNameObj.Type().(*types.Named)
- if !ok {
- continue
- }
- nt[spec.Name.Name] = typ
- }
- }
- }
- return nt
- }
// AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
// thisPkg is the package containing t.
// tname is the named type corresponding to t.
// ctx is a single-word context for this assertion, such as "Clone".
// If non-nil, AssertStructUnchanged will add elements to imports
// for each package path that the caller must import for the returned code to compile.
//
// The generated code converts an anonymous struct literal, with the same
// fields as t, to tname. The conversion stops compiling as soon as the
// fields of tname diverge from those seen at generation time.
func AssertStructUnchanged(t *types.Struct, thisPkg *types.Package, tname, ctx string, imports map[string]struct{}) []byte {
	buf := new(bytes.Buffer)
	w := func(format string, args ...interface{}) {
		fmt.Fprintf(buf, format+"\n", args...)
	}

	// qual qualifies type names relative to thisPkg and records every
	// foreign package it is asked about. A single field type such as
	// map[a.K]b.V references several packages, all of which are needed.
	qual := func(pkg *types.Package) string {
		if pkg == thisPkg {
			return ""
		}
		if path := pkg.Path(); path != "" && imports != nil {
			imports[path] = struct{}{}
		}
		return pkg.Name()
	}

	w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
	w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
	for i := 0; i < t.NumFields(); i++ {
		field := t.Field(i)
		qname := types.TypeString(field.Type(), qual)
		if field.Embedded() {
			// An embedded field must stay embedded: struct types that
			// differ in embeddedness are not convertible to each other.
			w("\t%s", qname)
			continue
		}
		w("\t%s %s", field.Name(), qname)
	}
	w("}{})\n")
	return buf.Bytes()
}
// importedName returns the source-code spelling of t as seen from thisPkg,
// together with the path of a package that must be imported to use it.
//
// Types belonging to thisPkg are left unqualified; types from any other
// package are qualified with that package's name. importPkg is empty when
// no foreign package is referenced. If t references several foreign
// packages, only the one qualified last is reported.
func importedName(t types.Type, thisPkg *types.Package) (qualifiedName, importPkg string) {
	qual := func(pkg *types.Package) string {
		if thisPkg == pkg {
			return ""
		}
		importPkg = pkg.Path()
		return pkg.Name()
	}
	// Run TypeString to completion before returning. qual assigns
	// importPkg as a side effect, and the order of that assignment relative
	// to a read of importPkg in the same expression list is unspecified.
	qualifiedName = types.TypeString(t, qual)
	return qualifiedName, importPkg
}
// ContainsPointers reports whether a value of type typ holds pointers,
// whether spelled out (pointer types) or implied by the kind of type
// (channels, interfaces, maps, slices).
//
// Arrays and structs are inspected recursively through their element
// and field types. A small set of types is exempted by name because
// the pointers they hold raise no memory aliasing/mutation concerns.
func ContainsPointers(typ types.Type) bool {
	// Exemptions, matched on the fully qualified type name.
	// time.Time contains a pointer that does not need copying.
	if name := typ.String(); name == "time.Time" || name == "inet.af/netaddr.IP" {
		return false
	}

	switch u := typ.Underlying().(type) {
	case *types.Chan, *types.Map, *types.Pointer, *types.Slice:
		return true
	case *types.Interface:
		// A little too broad: not every dynamic value holds pointers.
		return true
	case *types.Array:
		return ContainsPointers(u.Elem())
	case *types.Struct:
		for i, n := 0, u.NumFields(); i < n; i++ {
			if ContainsPointers(u.Field(i).Type()) {
				return true
			}
		}
		return false
	default:
		return false
	}
}
|