format.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package main
  4. import (
  5. "bytes"
  6. "go/ast"
  7. "go/format"
  8. "go/parser"
  9. "go/token"
  10. "go/types"
  11. "path"
  12. "slices"
  13. "strconv"
  14. "strings"
  15. "tailscale.com/util/must"
  16. )
  17. // mustFormatFile formats a Go source file and adjust "json" imports.
  18. // It panics if there are any parsing errors.
  19. //
  20. // - "encoding/json" is imported under the name "jsonv1" or "jsonv1std"
  21. // - "encoding/json/v2" is rewritten to import "github.com/go-json-experiment/json" instead
  22. // - "encoding/json/jsontext" is rewritten to import "github.com/go-json-experiment/json/jsontext" instead
  23. // - "github.com/go-json-experiment/json" is imported under the name "jsonv2"
  24. // - "github.com/go-json-experiment/json/v1" is imported under the name "jsonv1"
  25. //
  26. // If no changes to the file is made, it returns input.
  27. func mustFormatFile(in []byte) (out []byte) {
  28. fset := token.NewFileSet()
  29. f := must.Get(parser.ParseFile(fset, "", in, parser.ParseComments))
  30. // Check for the existence of "json" imports.
  31. jsonImports := make(map[string][]*ast.ImportSpec)
  32. for _, imp := range f.Imports {
  33. switch pkgPath := must.Get(strconv.Unquote(imp.Path.Value)); pkgPath {
  34. case
  35. "encoding/json",
  36. "encoding/json/v2",
  37. "encoding/json/jsontext",
  38. "github.com/go-json-experiment/json",
  39. "github.com/go-json-experiment/json/v1",
  40. "github.com/go-json-experiment/json/jsontext":
  41. jsonImports[pkgPath] = append(jsonImports[pkgPath], imp)
  42. }
  43. }
  44. if len(jsonImports) == 0 {
  45. return in
  46. }
  47. // Best-effort local type-check of the file
  48. // to resolve local declarations to detect shadowed variables.
  49. typeInfo := &types.Info{Uses: make(map[*ast.Ident]types.Object)}
  50. (&types.Config{
  51. Error: func(err error) {},
  52. }).Check("", fset, []*ast.File{f}, typeInfo)
  53. // Rewrite imports to instead use "github.com/go-json-experiment/json".
  54. // This ensures that code continues to build even if
  55. // goexperiment.jsonv2 is *not* specified.
  56. // As of https://github.com/go-json-experiment/json/pull/186,
  57. // imports to "github.com/go-json-experiment/json" are identical
  58. // to the standard library if built with goexperiment.jsonv2.
  59. for fromPath, toPath := range map[string]string{
  60. "encoding/json/v2": "github.com/go-json-experiment/json",
  61. "encoding/json/jsontext": "github.com/go-json-experiment/json/jsontext",
  62. } {
  63. for _, imp := range jsonImports[fromPath] {
  64. imp.Path.Value = strconv.Quote(toPath)
  65. jsonImports[toPath] = append(jsonImports[toPath], imp)
  66. }
  67. delete(jsonImports, fromPath)
  68. }
  69. // While in a transitory state, where both v1 and v2 json imports
  70. // may exist in our codebase, always explicitly import with
  71. // either jsonv1 or jsonv2 in the package name to avoid ambiguities
  72. // when looking at a particular Marshal or Unmarshal call site.
  73. renames := make(map[string]string) // mapping of old names to new names
  74. deletes := make(map[*ast.ImportSpec]bool) // set of imports to delete
  75. for pkgPath, imps := range jsonImports {
  76. var newName string
  77. switch pkgPath {
  78. case "encoding/json":
  79. newName = "jsonv1"
  80. // If "github.com/go-json-experiment/json/v1" is also imported,
  81. // then use jsonv1std for "encoding/json" to avoid a conflict.
  82. if len(jsonImports["github.com/go-json-experiment/json/v1"]) > 0 {
  83. newName += "std"
  84. }
  85. case "github.com/go-json-experiment/json":
  86. newName = "jsonv2"
  87. case "github.com/go-json-experiment/json/v1":
  88. newName = "jsonv1"
  89. }
  90. // Rename the import if different than expected.
  91. if oldName := importName(imps[0]); oldName != newName && newName != "" {
  92. renames[oldName] = newName
  93. pos := imps[0].Pos() // preserve original positioning
  94. imps[0].Name = ast.NewIdent(newName)
  95. imps[0].Name.NamePos = pos
  96. }
  97. // For all redundant imports, use the first imported name.
  98. for _, imp := range imps[1:] {
  99. renames[importName(imp)] = importName(imps[0])
  100. deletes[imp] = true
  101. }
  102. }
  103. if len(deletes) > 0 {
  104. f.Imports = slices.DeleteFunc(f.Imports, func(imp *ast.ImportSpec) bool {
  105. return deletes[imp]
  106. })
  107. for _, decl := range f.Decls {
  108. if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT {
  109. genDecl.Specs = slices.DeleteFunc(genDecl.Specs, func(spec ast.Spec) bool {
  110. return deletes[spec.(*ast.ImportSpec)]
  111. })
  112. }
  113. }
  114. }
  115. if len(renames) > 0 {
  116. ast.Walk(astVisitor(func(n ast.Node) bool {
  117. if sel, ok := n.(*ast.SelectorExpr); ok {
  118. if id, ok := sel.X.(*ast.Ident); ok {
  119. // Just because the selector looks like "json.Marshal"
  120. // does not mean that it is referencing the "json" package.
  121. // There could be a local "json" declaration that shadows
  122. // the package import. Check partial type information
  123. // to see if there was a local declaration.
  124. if obj, ok := typeInfo.Uses[id]; ok {
  125. if _, ok := obj.(*types.PkgName); !ok {
  126. return true
  127. }
  128. }
  129. if newName, ok := renames[id.String()]; ok {
  130. id.Name = newName
  131. }
  132. }
  133. }
  134. return true
  135. }), f)
  136. }
  137. bb := new(bytes.Buffer)
  138. must.Do(format.Node(bb, fset, f))
  139. return must.Get(format.Source(bb.Bytes()))
  140. }
  141. // importName is the local package name used for an import.
  142. // If no explicit local name is used, then it uses string parsing
  143. // to derive the package name from the path, relying on the convention
  144. // that the package name is the base name of the package path.
  145. func importName(imp *ast.ImportSpec) string {
  146. if imp.Name != nil {
  147. return imp.Name.String()
  148. }
  149. pkgPath, _ := strconv.Unquote(imp.Path.Value)
  150. pkgPath = strings.TrimRight(pkgPath, "/v0123456789") // exclude version directories
  151. return path.Base(pkgPath)
  152. }
  153. // astVisitor is a function that implements [ast.Visitor].
  154. type astVisitor func(ast.Node) bool
  155. func (f astVisitor) Visit(node ast.Node) ast.Visitor {
  156. if !f(node) {
  157. return nil
  158. }
  159. return f
  160. }