codegen.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. // Package codegen contains shared utilities for generating code.
  4. package codegen
import (
	"bytes"
	"cmp"
	"flag"
	"fmt"
	"go/ast"
	"go/token"
	"go/types"
	"io"
	"os"
	"reflect"
	"slices"
	"strings"

	"golang.org/x/tools/go/packages"
	"golang.org/x/tools/imports"
	"tailscale.com/util/mak"
)
  20. var flagCopyright = flag.Bool("copyright", true, "add Tailscale copyright to generated file headers")
  21. // LoadTypes returns all named types in pkgName, keyed by their type name.
  22. func LoadTypes(buildTags string, pkgName string) (*packages.Package, map[string]types.Type, error) {
  23. cfg := &packages.Config{
  24. Mode: packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedName,
  25. Tests: buildTags == "test",
  26. }
  27. if buildTags != "" && !cfg.Tests {
  28. cfg.BuildFlags = []string{"-tags=" + buildTags}
  29. }
  30. pkgs, err := packages.Load(cfg, pkgName)
  31. if err != nil {
  32. return nil, nil, err
  33. }
  34. if cfg.Tests {
  35. pkgs = testPackages(pkgs)
  36. }
  37. if len(pkgs) != 1 {
  38. return nil, nil, fmt.Errorf("wrong number of packages: %d", len(pkgs))
  39. }
  40. pkg := pkgs[0]
  41. return pkg, namedTypes(pkg), nil
  42. }
  43. func testPackages(pkgs []*packages.Package) []*packages.Package {
  44. var testPackages []*packages.Package
  45. for _, pkg := range pkgs {
  46. testPackageID := fmt.Sprintf("%[1]s [%[1]s.test]", pkg.PkgPath)
  47. if pkg.ID == testPackageID {
  48. testPackages = append(testPackages, pkg)
  49. }
  50. }
  51. return testPackages
  52. }
  53. // HasNoClone reports whether the provided tag has `codegen:noclone`.
  54. func HasNoClone(structTag string) bool {
  55. val := reflect.StructTag(structTag).Get("codegen")
  56. for _, v := range strings.Split(val, ",") {
  57. if v == "noclone" {
  58. return true
  59. }
  60. }
  61. return false
  62. }
  63. const copyrightHeader = `// Copyright (c) Tailscale Inc & AUTHORS
  64. // SPDX-License-Identifier: BSD-3-Clause
  65. `
  66. const genAndPackageHeader = `// Code generated by %v; DO NOT EDIT.
  67. package %s
  68. `
  69. func NewImportTracker(thisPkg *types.Package) *ImportTracker {
  70. return &ImportTracker{
  71. thisPkg: thisPkg,
  72. }
  73. }
  74. type namePkgPath struct {
  75. name string // optional import name
  76. pkgPath string
  77. }
  78. // ImportTracker provides a mechanism to track and build import paths.
  79. type ImportTracker struct {
  80. thisPkg *types.Package
  81. packages map[namePkgPath]bool
  82. }
  83. // Import imports pkgPath under an optional import name.
  84. func (it *ImportTracker) Import(name, pkgPath string) {
  85. if pkgPath != "" && !it.packages[namePkgPath{name, pkgPath}] {
  86. mak.Set(&it.packages, namePkgPath{name, pkgPath}, true)
  87. }
  88. }
  89. // Has reports whether the specified package path has been imported
  90. // under the particular import name.
  91. func (it *ImportTracker) Has(name, pkgPath string) bool {
  92. return it.packages[namePkgPath{name, pkgPath}]
  93. }
  94. func (it *ImportTracker) qualifier(pkg *types.Package) string {
  95. if it.thisPkg == pkg {
  96. return ""
  97. }
  98. it.Import("", pkg.Path())
  99. // TODO(maisem): handle conflicts?
  100. return pkg.Name()
  101. }
  102. // QualifiedName returns the string representation of t in the package.
  103. func (it *ImportTracker) QualifiedName(t types.Type) string {
  104. return types.TypeString(t, it.qualifier)
  105. }
  106. // PackagePrefix returns the prefix to be used when referencing named objects from pkg.
  107. func (it *ImportTracker) PackagePrefix(pkg *types.Package) string {
  108. if s := it.qualifier(pkg); s != "" {
  109. return s + "."
  110. }
  111. return ""
  112. }
  113. // Write prints all the tracked imports in a single import block to w.
  114. func (it *ImportTracker) Write(w io.Writer) {
  115. fmt.Fprintf(w, "import (\n")
  116. for s := range it.packages {
  117. if s.name == "" {
  118. fmt.Fprintf(w, "\t%q\n", s.pkgPath)
  119. } else {
  120. fmt.Fprintf(w, "\t%s %q\n", s.name, s.pkgPath)
  121. }
  122. }
  123. fmt.Fprintf(w, ")\n\n")
  124. }
  125. func writeHeader(w io.Writer, tool, pkg string) {
  126. if *flagCopyright {
  127. fmt.Fprint(w, copyrightHeader)
  128. }
  129. fmt.Fprintf(w, genAndPackageHeader, tool, pkg)
  130. }
  131. // WritePackageFile adds a file with the provided imports and contents to package.
  132. // The tool param is used to identify the tool that generated package file.
  133. func WritePackageFile(tool string, pkg *packages.Package, path string, it *ImportTracker, contents *bytes.Buffer) error {
  134. buf := new(bytes.Buffer)
  135. writeHeader(buf, tool, pkg.Name)
  136. it.Write(buf)
  137. if _, err := buf.Write(contents.Bytes()); err != nil {
  138. return err
  139. }
  140. return writeFormatted(buf.Bytes(), path)
  141. }
  142. // writeFormatted writes code to path.
  143. // It runs gofmt on it before writing;
  144. // if gofmt fails, it writes code unchanged.
  145. // Errors can include I/O errors and gofmt errors.
  146. //
  147. // The advantage of always writing code to path,
  148. // even if gofmt fails, is that it makes debugging easier.
  149. // The code can be long, but you need it in order to debug.
  150. // It is nicer to work with it in a file than a terminal.
  151. // It is also easier to interpret gofmt errors
  152. // with an editor providing file and line numbers.
  153. func writeFormatted(code []byte, path string) error {
  154. out, fmterr := imports.Process(path, code, &imports.Options{
  155. Comments: true,
  156. TabIndent: true,
  157. TabWidth: 8,
  158. FormatOnly: true, // fancy gofmt only
  159. })
  160. if fmterr != nil {
  161. out = code
  162. }
  163. ioerr := os.WriteFile(path, out, 0644)
  164. // Prefer I/O errors. They're usually easier to fix,
  165. // and until they're fixed you can't do much else.
  166. if ioerr != nil {
  167. return ioerr
  168. }
  169. if fmterr != nil {
  170. return fmt.Errorf("%s:%v", path, fmterr)
  171. }
  172. return nil
  173. }
  174. // namedTypes returns all named types in pkg, keyed by their type name.
  175. func namedTypes(pkg *packages.Package) map[string]types.Type {
  176. nt := make(map[string]types.Type)
  177. for _, file := range pkg.Syntax {
  178. for _, d := range file.Decls {
  179. decl, ok := d.(*ast.GenDecl)
  180. if !ok || decl.Tok != token.TYPE {
  181. continue
  182. }
  183. for _, s := range decl.Specs {
  184. spec, ok := s.(*ast.TypeSpec)
  185. if !ok {
  186. continue
  187. }
  188. typeNameObj, ok := pkg.TypesInfo.Defs[spec.Name]
  189. if !ok {
  190. continue
  191. }
  192. switch typ := typeNameObj.Type(); typ.(type) {
  193. case *types.Alias, *types.Named:
  194. nt[spec.Name.Name] = typ
  195. }
  196. }
  197. }
  198. }
  199. return nt
  200. }
  201. // AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
  202. // thisPkg is the package containing t.
  203. // tname is the named type corresponding to t.
  204. // ctx is a single-word context for this assertion, such as "Clone".
  205. // If non-nil, AssertStructUnchanged will add elements to imports
  206. // for each package path that the caller must import for the returned code to compile.
  207. func AssertStructUnchanged(t *types.Struct, tname string, params *types.TypeParamList, ctx string, it *ImportTracker) []byte {
  208. buf := new(bytes.Buffer)
  209. w := func(format string, args ...any) {
  210. fmt.Fprintf(buf, format+"\n", args...)
  211. }
  212. w("// A compilation failure here means this code must be regenerated, with the command at the top of this file.")
  213. hasTypeParams := params != nil && params.Len() > 0
  214. if hasTypeParams {
  215. constraints, identifiers := FormatTypeParams(params, it)
  216. w("func _%s%sNeedsRegeneration%s (%s%s) {", tname, ctx, constraints, tname, identifiers)
  217. w("_%s%sNeedsRegeneration(struct {", tname, ctx)
  218. } else {
  219. w("var _%s%sNeedsRegeneration = %s(struct {", tname, ctx, tname)
  220. }
  221. for i := range t.NumFields() {
  222. st := t.Field(i)
  223. fname := st.Name()
  224. ft := t.Field(i).Type()
  225. if IsInvalid(ft) {
  226. continue
  227. }
  228. qname := it.QualifiedName(ft)
  229. var tag string
  230. if hasTypeParams {
  231. tag = t.Tag(i)
  232. if tag != "" {
  233. tag = "`" + tag + "`"
  234. }
  235. }
  236. if st.Anonymous() {
  237. w("\t%s %s", qname, tag)
  238. } else {
  239. w("\t%s %s %s", fname, qname, tag)
  240. }
  241. }
  242. if hasTypeParams {
  243. w("}{})\n}")
  244. } else {
  245. w("}{})")
  246. }
  247. return buf.Bytes()
  248. }
  249. // IsInvalid reports whether the provided type is invalid. It is used to allow
  250. // codegeneration to run even when the target files have build errors or are
  251. // missing views.
  252. func IsInvalid(t types.Type) bool {
  253. return t.String() == "invalid type"
  254. }
  255. // ContainsPointers reports whether typ contains any pointers,
  256. // either explicitly or implicitly.
  257. // It has special handling for some types that contain pointers
  258. // that we know are free from memory aliasing/mutation concerns.
  259. func ContainsPointers(typ types.Type) bool {
  260. s := typ.String()
  261. switch s {
  262. case "time.Time":
  263. // time.Time contains a pointer that does not need cloning.
  264. return false
  265. case "inet.af/netip.Addr":
  266. return false
  267. }
  268. if strings.HasPrefix(s, "unique.Handle[") {
  269. // unique.Handle contains a pointer that does not need cloning.
  270. return false
  271. }
  272. switch ft := typ.Underlying().(type) {
  273. case *types.Array:
  274. return ContainsPointers(ft.Elem())
  275. case *types.Basic:
  276. if ft.Kind() == types.UnsafePointer {
  277. return true
  278. }
  279. case *types.Chan:
  280. return true
  281. case *types.Interface:
  282. if ft.Empty() || ft.IsMethodSet() {
  283. return true
  284. }
  285. for i := 0; i < ft.NumEmbeddeds(); i++ {
  286. if ContainsPointers(ft.EmbeddedType(i)) {
  287. return true
  288. }
  289. }
  290. case *types.Map:
  291. return true
  292. case *types.Pointer:
  293. return true
  294. case *types.Slice:
  295. return true
  296. case *types.Struct:
  297. for i := range ft.NumFields() {
  298. if ContainsPointers(ft.Field(i).Type()) {
  299. return true
  300. }
  301. }
  302. case *types.Union:
  303. for i := range ft.Len() {
  304. if ContainsPointers(ft.Term(i).Type()) {
  305. return true
  306. }
  307. }
  308. }
  309. return false
  310. }
  311. // IsViewType reports whether the provided typ is a View.
  312. func IsViewType(typ types.Type) bool {
  313. t, ok := typ.Underlying().(*types.Struct)
  314. if !ok {
  315. return false
  316. }
  317. if t.NumFields() != 1 {
  318. return false
  319. }
  320. return t.Field(0).Name() == "ж"
  321. }
  322. // FormatTypeParams formats the specified params and returns two strings:
  323. // - constraints are comma-separated type parameters and their constraints in square brackets (e.g. [T any, V constraints.Integer])
  324. // - names are comma-separated type parameter names in square brackets (e.g. [T, V])
  325. //
  326. // If params is nil or empty, both return values are empty strings.
  327. func FormatTypeParams(params *types.TypeParamList, it *ImportTracker) (constraints, names string) {
  328. if params == nil || params.Len() == 0 {
  329. return "", ""
  330. }
  331. var constraintList, nameList []string
  332. for i := range params.Len() {
  333. param := params.At(i)
  334. name := param.Obj().Name()
  335. constraint := it.QualifiedName(param.Constraint())
  336. nameList = append(nameList, name)
  337. constraintList = append(constraintList, name+" "+constraint)
  338. }
  339. constraints = "[" + strings.Join(constraintList, ", ") + "]"
  340. names = "[" + strings.Join(nameList, ", ") + "]"
  341. return constraints, names
  342. }
  343. // LookupMethod returns the method with the specified name in t, or nil if the method does not exist.
  344. func LookupMethod(t types.Type, name string) *types.Func {
  345. switch t := t.(type) {
  346. case *types.Alias:
  347. return LookupMethod(t.Rhs(), name)
  348. case *types.TypeParam:
  349. return LookupMethod(t.Constraint(), name)
  350. case *types.Pointer:
  351. return LookupMethod(t.Elem(), name)
  352. case *types.Named:
  353. switch u := t.Underlying().(type) {
  354. case *types.Interface:
  355. return LookupMethod(u, name)
  356. default:
  357. for i := 0; i < t.NumMethods(); i++ {
  358. if method := t.Method(i); method.Name() == name {
  359. return method
  360. }
  361. }
  362. }
  363. case *types.Interface:
  364. for i := 0; i < t.NumMethods(); i++ {
  365. if method := t.Method(i); method.Name() == name {
  366. return method
  367. }
  368. }
  369. }
  370. return nil
  371. }
  372. // NamedTypeOf is like t.(*types.Named), but also works with type aliases.
  373. func NamedTypeOf(t types.Type) (named *types.Named, ok bool) {
  374. if a, ok := t.(*types.Alias); ok {
  375. return NamedTypeOf(types.Unalias(a))
  376. }
  377. named, ok = t.(*types.Named)
  378. return
  379. }