cloner.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. // Cloner is a tool to automate the creation of a Clone method.
  4. //
  5. // The result of the Clone method aliases no memory that can be edited
  6. // with the original.
  7. //
  8. // This tool makes lots of implicit assumptions about the types you feed it.
  9. // In particular, it can only write relatively "shallow" Clone methods.
  10. // That is, if a type contains another named struct type, cloner assumes that
  11. // named type will also have a Clone method.
  12. package main
  13. import (
  14. "bytes"
  15. "flag"
  16. "fmt"
  17. "go/types"
  18. "log"
  19. "os"
  20. "strings"
  21. "tailscale.com/util/codegen"
  22. )
  23. var (
  24. flagTypes = flag.String("type", "", "comma-separated list of types; required")
  25. flagBuildTags = flag.String("tags", "", "compiler build tags to apply")
  26. flagCloneFunc = flag.Bool("clonefunc", false, "add a top-level Clone func")
  27. )
  28. func main() {
  29. log.SetFlags(0)
  30. log.SetPrefix("cloner: ")
  31. flag.Parse()
  32. if len(*flagTypes) == 0 {
  33. flag.Usage()
  34. os.Exit(2)
  35. }
  36. typeNames := strings.Split(*flagTypes, ",")
  37. pkg, namedTypes, err := codegen.LoadTypes(*flagBuildTags, ".")
  38. if err != nil {
  39. log.Fatal(err)
  40. }
  41. it := codegen.NewImportTracker(pkg.Types)
  42. buf := new(bytes.Buffer)
  43. for _, typeName := range typeNames {
  44. typ, ok := namedTypes[typeName].(*types.Named)
  45. if !ok {
  46. log.Fatalf("could not find type %s", typeName)
  47. }
  48. gen(buf, it, typ)
  49. }
  50. w := func(format string, args ...any) {
  51. fmt.Fprintf(buf, format+"\n", args...)
  52. }
  53. if *flagCloneFunc {
  54. w("// Clone duplicates src into dst and reports whether it succeeded.")
  55. w("// To succeed, <src, dst> must be of types <*T, *T> or <*T, **T>,")
  56. w("// where T is one of %s.", *flagTypes)
  57. w("func Clone(dst, src any) bool {")
  58. w(" switch src := src.(type) {")
  59. for _, typeName := range typeNames {
  60. w(" case *%s:", typeName)
  61. w(" switch dst := dst.(type) {")
  62. w(" case *%s:", typeName)
  63. w(" *dst = *src.Clone()")
  64. w(" return true")
  65. w(" case **%s:", typeName)
  66. w(" *dst = src.Clone()")
  67. w(" return true")
  68. w(" }")
  69. }
  70. w(" }")
  71. w(" return false")
  72. w("}")
  73. }
  74. cloneOutput := pkg.Name + "_clone"
  75. if *flagBuildTags == "test" {
  76. cloneOutput += "_test"
  77. }
  78. cloneOutput += ".go"
  79. if err := codegen.WritePackageFile("tailscale.com/cmd/cloner", pkg, cloneOutput, it, buf); err != nil {
  80. log.Fatal(err)
  81. }
  82. }
  83. func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) {
  84. t, ok := typ.Underlying().(*types.Struct)
  85. if !ok {
  86. return
  87. }
  88. name := typ.Obj().Name()
  89. typeParams := typ.Origin().TypeParams()
  90. _, typeParamNames := codegen.FormatTypeParams(typeParams, it)
  91. nameWithParams := name + typeParamNames
  92. fmt.Fprintf(buf, "// Clone makes a deep copy of %s.\n", name)
  93. fmt.Fprintf(buf, "// The result aliases no memory with the original.\n")
  94. fmt.Fprintf(buf, "func (src *%s) Clone() *%s {\n", nameWithParams, nameWithParams)
  95. writef := func(format string, args ...any) {
  96. fmt.Fprintf(buf, "\t"+format+"\n", args...)
  97. }
  98. writef("if src == nil {")
  99. writef("\treturn nil")
  100. writef("}")
  101. writef("dst := new(%s)", nameWithParams)
  102. writef("*dst = *src")
  103. for i := range t.NumFields() {
  104. fname := t.Field(i).Name()
  105. ft := t.Field(i).Type()
  106. if !codegen.ContainsPointers(ft) || codegen.HasNoClone(t.Tag(i)) {
  107. continue
  108. }
  109. if named, _ := codegen.NamedTypeOf(ft); named != nil {
  110. if codegen.IsViewType(ft) {
  111. writef("dst.%s = src.%s", fname, fname)
  112. continue
  113. }
  114. if !hasBasicUnderlying(ft) {
  115. // don't dereference if the underlying type is an interface
  116. if _, isInterface := ft.Underlying().(*types.Interface); isInterface {
  117. writef("if src.%s != nil { dst.%s = src.%s.Clone() }", fname, fname, fname)
  118. } else {
  119. writef("dst.%s = *src.%s.Clone()", fname, fname)
  120. }
  121. continue
  122. }
  123. }
  124. switch ft := ft.Underlying().(type) {
  125. case *types.Slice:
  126. if codegen.ContainsPointers(ft.Elem()) {
  127. n := it.QualifiedName(ft.Elem())
  128. writef("if src.%s != nil {", fname)
  129. writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname)
  130. writef("for i := range dst.%s {", fname)
  131. if ptr, isPtr := ft.Elem().(*types.Pointer); isPtr {
  132. writef("if src.%s[i] == nil { dst.%s[i] = nil } else {", fname, fname)
  133. if codegen.ContainsPointers(ptr.Elem()) {
  134. if _, isIface := ptr.Elem().Underlying().(*types.Interface); isIface {
  135. it.Import("", "tailscale.com/types/ptr")
  136. writef("\tdst.%s[i] = ptr.To((*src.%s[i]).Clone())", fname, fname)
  137. } else {
  138. writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname)
  139. }
  140. } else {
  141. it.Import("", "tailscale.com/types/ptr")
  142. writef("\tdst.%s[i] = ptr.To(*src.%s[i])", fname, fname)
  143. }
  144. writef("}")
  145. } else if ft.Elem().String() == "encoding/json.RawMessage" {
  146. writef("\tdst.%s[i] = append(src.%s[i][:0:0], src.%s[i]...)", fname, fname, fname)
  147. } else if _, isIface := ft.Elem().Underlying().(*types.Interface); isIface {
  148. writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname)
  149. } else {
  150. writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname)
  151. }
  152. writef("}")
  153. writef("}")
  154. } else {
  155. writef("dst.%s = append(src.%s[:0:0], src.%s...)", fname, fname, fname)
  156. }
  157. case *types.Pointer:
  158. base := ft.Elem()
  159. hasPtrs := codegen.ContainsPointers(base)
  160. if named, _ := codegen.NamedTypeOf(base); named != nil && hasPtrs {
  161. writef("dst.%s = src.%s.Clone()", fname, fname)
  162. continue
  163. }
  164. it.Import("", "tailscale.com/types/ptr")
  165. writef("if dst.%s != nil {", fname)
  166. if _, isIface := base.Underlying().(*types.Interface); isIface && hasPtrs {
  167. writef("\tdst.%s = ptr.To((*src.%s).Clone())", fname, fname)
  168. } else if !hasPtrs {
  169. writef("\tdst.%s = ptr.To(*src.%s)", fname, fname)
  170. } else {
  171. writef("\t" + `panic("TODO pointers in pointers")`)
  172. }
  173. writef("}")
  174. case *types.Map:
  175. elem := ft.Elem()
  176. if sliceType, isSlice := elem.(*types.Slice); isSlice {
  177. n := it.QualifiedName(sliceType.Elem())
  178. writef("if dst.%s != nil {", fname)
  179. writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem))
  180. writef("\tfor k := range src.%s {", fname)
  181. // use zero-length slice instead of nil to ensure
  182. // the key is always copied.
  183. writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname)
  184. writef("\t}")
  185. writef("}")
  186. } else if codegen.IsViewType(elem) || !codegen.ContainsPointers(elem) {
  187. // If the map values are view types (which are
  188. // immutable and don't need cloning) or don't
  189. // themselves contain pointers, we can just
  190. // clone the map itself.
  191. it.Import("", "maps")
  192. writef("\tdst.%s = maps.Clone(src.%s)", fname, fname)
  193. } else {
  194. // Otherwise we need to clone each element of
  195. // the map using our recursive helper.
  196. writef("if dst.%s != nil {", fname)
  197. writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem))
  198. writef("\tfor k, v := range src.%s {", fname)
  199. // Use a recursive helper here; this handles
  200. // arbitrarily nested maps in addition to
  201. // simpler types.
  202. writeMapValueClone(mapValueCloneParams{
  203. Buf: buf,
  204. It: it,
  205. Elem: elem,
  206. SrcExpr: "v",
  207. DstExpr: fmt.Sprintf("dst.%s[k]", fname),
  208. BaseIndent: "\t",
  209. Depth: 1,
  210. })
  211. writef("\t}")
  212. writef("}")
  213. }
  214. case *types.Interface:
  215. // If ft is an interface with a "Clone() ft" method, it can be used to clone the field.
  216. // This includes scenarios where ft is a constrained type parameter.
  217. if cloneResultType := methodResultType(ft, "Clone"); cloneResultType.Underlying() == ft {
  218. writef("dst.%s = src.%s.Clone()", fname, fname)
  219. continue
  220. }
  221. writef(`panic("%s (%v) does not have a compatible Clone method")`, fname, ft)
  222. default:
  223. writef(`panic("TODO: %s (%T)")`, fname, ft)
  224. }
  225. }
  226. writef("return dst")
  227. fmt.Fprintf(buf, "}\n\n")
  228. buf.Write(codegen.AssertStructUnchanged(t, name, typeParams, "Clone", it))
  229. }
  230. // hasBasicUnderlying reports true when typ.Underlying() is a slice or a map.
  231. func hasBasicUnderlying(typ types.Type) bool {
  232. switch typ.Underlying().(type) {
  233. case *types.Slice, *types.Map:
  234. return true
  235. default:
  236. return false
  237. }
  238. }
  239. func methodResultType(typ types.Type, method string) types.Type {
  240. viewMethod := codegen.LookupMethod(typ, method)
  241. if viewMethod == nil {
  242. return nil
  243. }
  244. sig, ok := viewMethod.Type().(*types.Signature)
  245. if !ok || sig.Results().Len() != 1 {
  246. return nil
  247. }
  248. return sig.Results().At(0).Type()
  249. }
  250. type mapValueCloneParams struct {
  251. // Buf is the buffer to write generated code to
  252. Buf *bytes.Buffer
  253. // It is the import tracker for managing imports.
  254. It *codegen.ImportTracker
  255. // Elem is the type of the map value to clone
  256. Elem types.Type
  257. // SrcExpr is the expression for the source value (e.g., "v", "v2", "v3")
  258. SrcExpr string
  259. // DstExpr is the expression for the destination (e.g., "dst.Field[k]", "dst.Field[k][k2]")
  260. DstExpr string
  261. // BaseIndent is the "base" indentation string for the generated code
  262. // (i.e. 1 or more tabs). Additional indentation will be added based on
  263. // the Depth parameter.
  264. BaseIndent string
  265. // Depth is the current nesting depth (1 for first level, 2 for second, etc.)
  266. Depth int
  267. }
  268. // writeMapValueClone generates code to clone a map value recursively.
  269. // It handles arbitrary nesting of maps, pointers, and interfaces.
  270. func writeMapValueClone(params mapValueCloneParams) {
  271. indent := params.BaseIndent + strings.Repeat("\t", params.Depth)
  272. writef := func(format string, args ...any) {
  273. fmt.Fprintf(params.Buf, indent+format+"\n", args...)
  274. }
  275. switch elem := params.Elem.Underlying().(type) {
  276. case *types.Pointer:
  277. writef("if %s == nil { %s = nil } else {", params.SrcExpr, params.DstExpr)
  278. if base := elem.Elem().Underlying(); codegen.ContainsPointers(base) {
  279. if _, isIface := base.(*types.Interface); isIface {
  280. params.It.Import("", "tailscale.com/types/ptr")
  281. writef("\t%s = ptr.To((*%s).Clone())", params.DstExpr, params.SrcExpr)
  282. } else {
  283. writef("\t%s = %s.Clone()", params.DstExpr, params.SrcExpr)
  284. }
  285. } else {
  286. params.It.Import("", "tailscale.com/types/ptr")
  287. writef("\t%s = ptr.To(*%s)", params.DstExpr, params.SrcExpr)
  288. }
  289. writef("}")
  290. case *types.Map:
  291. // Recursively handle nested maps
  292. innerElem := elem.Elem()
  293. if codegen.IsViewType(innerElem) || !codegen.ContainsPointers(innerElem) {
  294. // Inner map values don't need deep cloning
  295. params.It.Import("", "maps")
  296. writef("%s = maps.Clone(%s)", params.DstExpr, params.SrcExpr)
  297. } else {
  298. // Inner map values need cloning
  299. keyType := params.It.QualifiedName(elem.Key())
  300. valueType := params.It.QualifiedName(innerElem)
  301. // Generate unique variable names for nested loops based on depth
  302. keyVar := fmt.Sprintf("k%d", params.Depth+1)
  303. valVar := fmt.Sprintf("v%d", params.Depth+1)
  304. writef("if %s == nil {", params.SrcExpr)
  305. writef("\t%s = nil", params.DstExpr)
  306. writef("\tcontinue")
  307. writef("}")
  308. writef("%s = map[%s]%s{}", params.DstExpr, keyType, valueType)
  309. writef("for %s, %s := range %s {", keyVar, valVar, params.SrcExpr)
  310. // Recursively generate cloning code for the nested map value
  311. nestedDstExpr := fmt.Sprintf("%s[%s]", params.DstExpr, keyVar)
  312. writeMapValueClone(mapValueCloneParams{
  313. Buf: params.Buf,
  314. It: params.It,
  315. Elem: innerElem,
  316. SrcExpr: valVar,
  317. DstExpr: nestedDstExpr,
  318. BaseIndent: params.BaseIndent,
  319. Depth: params.Depth + 1,
  320. })
  321. writef("}")
  322. }
  323. case *types.Interface:
  324. if cloneResultType := methodResultType(elem, "Clone"); cloneResultType != nil {
  325. if _, isPtr := cloneResultType.(*types.Pointer); isPtr {
  326. writef("%s = *(%s.Clone())", params.DstExpr, params.SrcExpr)
  327. } else {
  328. writef("%s = %s.Clone()", params.DstExpr, params.SrcExpr)
  329. }
  330. } else {
  331. writef(`panic("map value (%%v) does not have a Clone method")`, elem)
  332. }
  333. default:
  334. writef("%s = *(%s.Clone())", params.DstExpr, params.SrcExpr)
  335. }
  336. }