2
0

codegen.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. // Package codegen contains shared utilities for generating code.
  4. package codegen
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"io"
	"os"
	"reflect"
	"sort"
	"strings"

	"golang.org/x/tools/go/packages"
	"golang.org/x/tools/imports"
	"tailscale.com/util/mak"
)
  20. var flagCopyright = flag.Bool("copyright", true, "add Tailscale copyright to generated file headers")
  21. // LoadTypes returns all named types in pkgName, keyed by their type name.
  22. func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]*types.Named, error) {
  23. cfg := &packages.Config{
  24. Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
  25. Tests: false,
  26. }
  27. if buildTags != "" {
  28. cfg.BuildFlags = []string{"-tags=" + buildTags}
  29. }
  30. pkgs, err := packages.Load(cfg, pkgName)
  31. if err != nil {
  32. return nil, nil, err
  33. }
  34. if len(pkgs) != 1 {
  35. return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs))
  36. }
  37. pkg := pkgs[0]
  38. return pkg, namedTypes(pkg), nil
  39. }
  40. // HasNoClone reports whether the provided tag has `codegen:noclone`.
  41. func HasNoClone(structTag string) bool {
  42. val := reflect.StructTag(structTag).Get("codegen")
  43. for _, v := range strings.Split(val, ",") {
  44. if v == "noclone" {
  45. return true
  46. }
  47. }
  48. return false
  49. }
  50. const copyrightHeader = `// Copyright (c) Tailscale Inc & AUTHORS
  51. // SPDX-License-Identifier: BSD-3-Clause
  52. `
  53. const genAndPackageHeader = `// Code generated by %v; DO NOT EDIT.
  54. package %s
  55. `
  56. func NewImportTracker(thisPkg *types.Package) *ImportTracker {
  57. return &ImportTracker{
  58. thisPkg: thisPkg,
  59. }
  60. }
  61. // ImportTracker provides a mechanism to track and build import paths.
  62. type ImportTracker struct {
  63. thisPkg *types.Package
  64. packages map[string]bool
  65. }
  66. func (it *ImportTracker) Import(pkg string) {
  67. if pkg != "" && !it.packages[pkg] {
  68. mak.Set(&it.packages, pkg, true)
  69. }
  70. }
  71. func (it *ImportTracker) qualifier(pkg *types.Package) string {
  72. if it.thisPkg == pkg {
  73. return ""
  74. }
  75. it.Import(pkg.Path())
  76. // TODO(maisem): handle conflicts?
  77. return pkg.Name()
  78. }
  79. // QualifiedName returns the string representation of t in the package.
  80. func (it *ImportTracker) QualifiedName(t types.Type) string {
  81. return types.TypeString(t, it.qualifier)
  82. }
  83. // Write prints all the tracked imports in a single import block to w.
  84. func (it *ImportTracker) Write(w io.Writer) {
  85. fmt.Fprintf(w, "import (\n")
  86. for s := range it.packages {
  87. fmt.Fprintf(w, "\t%q\n", s)
  88. }
  89. fmt.Fprintf(w, ")\n\n")
  90. }
  91. func writeHeader(w io.Writer, tool, pkg string) {
  92. if *flagCopyright {
  93. fmt.Fprint(w, copyrightHeader)
  94. }
  95. fmt.Fprintf(w, genAndPackageHeader, tool, pkg)
  96. }
  97. // WritePackageFile adds a file with the provided imports and contents to package.
  98. // The tool param is used to identify the tool that generated package file.
  99. func WritePackageFile(tool string, pkg *packages.Package, path string, it *ImportTracker, contents *bytes.Buffer) error {
  100. buf := new(bytes.Buffer)
  101. writeHeader(buf, tool, pkg.Name)
  102. it.Write(buf)
  103. if _, err := buf.Write(contents.Bytes()); err != nil {
  104. return err
  105. }
  106. return writeFormatted(buf.Bytes(), path)
  107. }
  108. // writeFormatted writes code to path.
  109. // It runs gofmt on it before writing;
  110. // if gofmt fails, it writes code unchanged.
  111. // Errors can include I/O errors and gofmt errors.
  112. //
  113. // The advantage of always writing code to path,
  114. // even if gofmt fails, is that it makes debugging easier.
  115. // The code can be long, but you need it in order to debug.
  116. // It is nicer to work with it in a file than a terminal.
  117. // It is also easier to interpret gofmt errors
  118. // with an editor providing file and line numbers.
  119. func writeFormatted(code []byte, path string) error {
  120. out, fmterr := imports.Process(path, code, &imports.Options{
  121. Comments: true,
  122. TabIndent: true,
  123. TabWidth: 8,
  124. FormatOnly: true, // fancy gofmt only
  125. })
  126. if fmterr != nil {
  127. out = code
  128. }
  129. ioerr := os.WriteFile(path, out, 0644)
  130. // Prefer I/O errors. They're usually easier to fix,
  131. // and until they're fixed you can't do much else.
  132. if ioerr != nil {
  133. return ioerr
  134. }
  135. if fmterr != nil {
  136. return fmt.Errorf("%s:%v", path, fmterr)
  137. }
  138. return nil
  139. }
  140. // namedTypes returns all named types in pkg, keyed by their type name.
  141. func namedTypes(pkg *packages.Package) map[string]*types.Named {
  142. nt := make(map[string]*types.Named)
  143. for _, file := range pkg.Syntax {
  144. for _, d := range file.Decls {
  145. decl, ok := d.(*ast.GenDecl)
  146. if !ok || decl.Tok != token.TYPE {
  147. continue
  148. }
  149. for _, s := range decl.Specs {
  150. spec, ok := s.(*ast.TypeSpec)
  151. if !ok {
  152. continue
  153. }
  154. typeNameObj, ok := pkg.TypesInfo.Defs[spec.Name]
  155. if !ok {
  156. continue
  157. }
  158. typ, ok := typeNameObj.Type().(*types.Named)
  159. if !ok {
  160. continue
  161. }
  162. nt[spec.Name.Name] = typ
  163. }
  164. }
  165. }
  166. return nt
  167. }
  168. // AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
  169. // thisPkg is the package containing t.
  170. // tname is the named type corresponding to t.
  171. // ctx is a single-word context for this assertion, such as "Clone".
  172. // If non-nil, AssertStructUnchanged will add elements to imports
  173. // for each package path that the caller must import for the returned code to compile.
  174. func AssertStructUnchanged(t *types.Struct, tname, ctx string, it *ImportTracker) []byte {
  175. buf := new(bytes.Buffer)
  176. w := func(format string, args ...any) {
  177. fmt.Fprintf(buf, format+"\n", args...)
  178. }
  179. w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
  180. w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
  181. for i := 0; i < t.NumFields(); i++ {
  182. st := t.Field(i)
  183. fname := st.Name()
  184. ft := t.Field(i).Type()
  185. if IsInvalid(ft) {
  186. continue
  187. }
  188. qname := it.QualifiedName(ft)
  189. if st.Anonymous() {
  190. w("\t%s ", fname)
  191. } else {
  192. w("\t%s %s", fname, qname)
  193. }
  194. }
  195. w("}{})\n")
  196. return buf.Bytes()
  197. }
  198. // IsInvalid reports whether the provided type is invalid. It is used to allow
  199. // codegeneration to run even when the target files have build errors or are
  200. // missing views.
  201. func IsInvalid(t types.Type) bool {
  202. return t.String() == "invalid type"
  203. }
  204. // ContainsPointers reports whether typ contains any pointers,
  205. // either explicitly or implicitly.
  206. // It has special handling for some types that contain pointers
  207. // that we know are free from memory aliasing/mutation concerns.
  208. func ContainsPointers(typ types.Type) bool {
  209. switch typ.String() {
  210. case "time.Time":
  211. // time.Time contains a pointer that does not need copying
  212. return false
  213. case "inet.af/netip.Addr", "net/netip.Addr", "net/netip.Prefix", "net/netip.AddrPort":
  214. return false
  215. }
  216. switch ft := typ.Underlying().(type) {
  217. case *types.Array:
  218. return ContainsPointers(ft.Elem())
  219. case *types.Chan:
  220. return true
  221. case *types.Interface:
  222. return true // a little too broad
  223. case *types.Map:
  224. return true
  225. case *types.Pointer:
  226. return true
  227. case *types.Slice:
  228. return true
  229. case *types.Struct:
  230. for i := 0; i < ft.NumFields(); i++ {
  231. if ContainsPointers(ft.Field(i).Type()) {
  232. return true
  233. }
  234. }
  235. }
  236. return false
  237. }
  238. // IsViewType reports whether the provided typ is a View.
  239. func IsViewType(typ types.Type) bool {
  240. t, ok := typ.Underlying().(*types.Struct)
  241. if !ok {
  242. return false
  243. }
  244. if t.NumFields() != 1 {
  245. return false
  246. }
  247. return t.Field(0).Name() == "ж"
  248. }