codegen.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. // Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Package codegen contains shared utilities for generating code.
  5. package codegen
  6. import (
  7. "bytes"
  8. "fmt"
  9. "go/ast"
  10. "go/format"
  11. "go/token"
  12. "go/types"
  13. "os"
  14. "golang.org/x/tools/go/packages"
  15. )
  // WriteFormatted writes code to path.
// It runs gofmt on it before writing;
// if gofmt fails, it writes code unchanged.
// Errors can include I/O errors and gofmt errors.
// If writing the file fails, that I/O error is returned even when gofmt
// also failed. A gofmt error is returned as "path:<gofmt error>" and is
// wrapped, so callers can inspect it with errors.Is / errors.As.
//
// The advantage of always writing code to path,
// even if gofmt fails, is that it makes debugging easier.
// The code can be long, but you need it in order to debug.
// It is nicer to work with it in a file than a terminal.
// It is also easier to interpret gofmt errors
// with an editor providing file and line numbers.
func WriteFormatted(code []byte, path string) error {
	out, fmterr := format.Source(code)
	if fmterr != nil {
		// Fall back to the raw input so the broken code can still be inspected.
		out = code
	}
	ioerr := os.WriteFile(path, out, 0644)
	// Prefer I/O errors. They're usually easier to fix,
	// and until they're fixed you can't do much else.
	if ioerr != nil {
		return ioerr
	}
	if fmterr != nil {
		// No space after the colon: gofmt errors usually start with
		// "line:col:", so the message reads like an editor-friendly
		// "path:line:col: msg". %w keeps the underlying error reachable.
		return fmt.Errorf("%s:%w", path, fmterr)
	}
	return nil
}
  43. // NamedTypes returns all named types in pkg, keyed by their type name.
  44. func NamedTypes(pkg *packages.Package) map[string]*types.Named {
  45. nt := make(map[string]*types.Named)
  46. for _, file := range pkg.Syntax {
  47. for _, d := range file.Decls {
  48. decl, ok := d.(*ast.GenDecl)
  49. if !ok || decl.Tok != token.TYPE {
  50. continue
  51. }
  52. for _, s := range decl.Specs {
  53. spec, ok := s.(*ast.TypeSpec)
  54. if !ok {
  55. continue
  56. }
  57. typeNameObj := pkg.TypesInfo.Defs[spec.Name]
  58. typ, ok := typeNameObj.Type().(*types.Named)
  59. if !ok {
  60. continue
  61. }
  62. nt[spec.Name.Name] = typ
  63. }
  64. }
  65. }
  66. return nt
  67. }
  // AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
// thisPkg is the package containing t.
// tname is the named type corresponding to t.
// ctx is a single-word context for this assertion, such as "Clone".
// If non-nil, AssertStructUnchanged will add elements to imports
// for each package path that the caller must import for the returned code to compile.
//
// The generated code converts a struct literal that mirrors t's fields to
// tname. The conversion only compiles while the field names, types and
// embedding match tname's underlying struct. Embedded fields are therefore
// emitted as embedded (e.g. "*pkg.T"), not as "T *pkg.T".
func AssertStructUnchanged(t *types.Struct, thisPkg *types.Package, tname, ctx string, imports map[string]struct{}) []byte {
	buf := new(bytes.Buffer)
	w := func(format string, args ...interface{}) {
		fmt.Fprintf(buf, format+"\n", args...)
	}
	w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
	w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
	for i := 0; i < t.NumFields(); i++ {
		field := t.Field(i)
		qname, imppaths := importedNames(field.Type(), thisPkg)
		if imports != nil {
			// A single field type (e.g. map[a.K]b.V) can reference several packages.
			for _, p := range imppaths {
				imports[p] = struct{}{}
			}
		}
		if field.Embedded() {
			// Struct identity distinguishes embedded from named fields.
			w("\t%s", qname)
		} else {
			w("\t%s %s", field.Name(), qname)
		}
	}
	w("}{})\n")
	return buf.Bytes()
}

// importedName returns t's type string as written from within thisPkg, and
// the import path of the last package other than thisPkg that the string
// refers to ("" if none). Use importedNames when every path is needed.
func importedName(t types.Type, thisPkg *types.Package) (qualifiedName, importPkg string) {
	qualifiedName, paths := importedNames(t, thisPkg)
	if len(paths) > 0 {
		importPkg = paths[len(paths)-1]
	}
	return qualifiedName, importPkg
}

// importedNames returns t's type string as written from within thisPkg,
// plus the import path of every package other than thisPkg referenced by
// that string, in the order encountered (paths may repeat).
func importedNames(t types.Type, thisPkg *types.Package) (qualifiedName string, importPkgs []string) {
	qual := func(pkg *types.Package) string {
		if thisPkg == pkg {
			return ""
		}
		importPkgs = append(importPkgs, pkg.Path())
		return pkg.Name()
	}
	// Call TypeString before reading importPkgs: in a single return
	// statement the spec does not order the call relative to reading the
	// variable it mutates.
	qualifiedName = types.TypeString(t, qual)
	return qualifiedName, importPkgs
}
  // ContainsPointers reports whether typ contains any pointers,
// either explicitly or implicitly.
// It has special handling for some types that contain pointers
// that we know are free from memory aliasing/mutation concerns.
//
// Channels, interfaces, maps, pointers, slices and unsafe.Pointer count as
// pointers. Arrays and structs are examined element by element. Other basic
// types (including string) and function types report false. time.Time and
// inet.af/netaddr.IP are treated as pointer-free by name.
func ContainsPointers(typ types.Type) bool {
	switch typ.String() {
	case "time.Time":
		// time.Time contains a pointer that does not need copying
		return false
	case "inet.af/netaddr.IP":
		return false
	}
	switch ft := typ.Underlying().(type) {
	case *types.Array:
		return ContainsPointers(ft.Elem())
	case *types.Basic:
		// unsafe.Pointer is a raw pointer that can alias arbitrary memory.
		return ft.Kind() == types.UnsafePointer
	case *types.Chan, *types.Map, *types.Pointer, *types.Slice:
		return true
	case *types.Interface:
		return true // a little too broad
	case *types.Struct:
		for i := 0; i < ft.NumFields(); i++ {
			if ContainsPointers(ft.Field(i).Type()) {
				return true
			}
		}
	}
	return false
}