codegen_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package codegen
  4. import (
  5. "cmp"
  6. "go/types"
  7. "net/netip"
  8. "strings"
  9. "sync"
  10. "testing"
  11. "time"
  12. "unique"
  13. "unsafe"
  14. "golang.org/x/exp/constraints"
  15. )
  16. type AnyParam[T any] struct {
  17. V T
  18. }
  19. type AnyParamPhantom[T any] struct {
  20. }
  21. type IntegerParam[T constraints.Integer] struct {
  22. V T
  23. }
  24. type FloatParam[T constraints.Float] struct {
  25. V T
  26. }
  27. type StringLikeParam[T ~string] struct {
  28. V T
  29. }
  30. type BasicType interface {
  31. ~bool | constraints.Integer | constraints.Float | constraints.Complex | ~string
  32. }
  33. type BasicTypeParam[T BasicType] struct {
  34. V T
  35. }
  36. type IntPtr *int
  37. type IntPtrParam[T IntPtr] struct {
  38. V T
  39. }
  40. type IntegerPtr interface {
  41. *int | *int32 | *int64
  42. }
  43. type IntegerPtrParam[T IntegerPtr] struct {
  44. V T
  45. }
  46. type IntegerParamPtr[T constraints.Integer] struct {
  47. V *T
  48. }
  49. type IntegerSliceParam[T constraints.Integer] struct {
  50. V []T
  51. }
  52. type IntegerMapParam[T constraints.Integer] struct {
  53. V []T
  54. }
  55. type UnsafePointerParam[T unsafe.Pointer] struct {
  56. V T
  57. }
  58. type ValueUnionParam[T netip.Prefix | BasicType] struct {
  59. V T
  60. }
  61. type ValueUnionParamPtr[T netip.Prefix | BasicType] struct {
  62. V *T
  63. }
  64. type PointerUnionParam[T netip.Prefix | BasicType | IntPtr] struct {
  65. V T
  66. }
  67. type StructWithUniqueHandle struct{ _ unique.Handle[[32]byte] }
  68. type StructWithTime struct{ _ time.Time }
  69. type StructWithNetipTypes struct {
  70. _ netip.Addr
  71. _ netip.AddrPort
  72. _ netip.Prefix
  73. }
  74. type Interface interface {
  75. Method()
  76. }
  77. type InterfaceParam[T Interface] struct {
  78. V T
  79. }
  80. func TestGenericContainsPointers(t *testing.T) {
  81. tests := []struct {
  82. typ string
  83. wantPointer bool
  84. }{
  85. {
  86. typ: "AnyParam",
  87. wantPointer: true,
  88. },
  89. {
  90. typ: "AnyParamPhantom",
  91. wantPointer: false, // has a pointer type parameter, but no pointer fields
  92. },
  93. {
  94. typ: "IntegerParam",
  95. wantPointer: false,
  96. },
  97. {
  98. typ: "FloatParam",
  99. wantPointer: false,
  100. },
  101. {
  102. typ: "StringLikeParam",
  103. wantPointer: false,
  104. },
  105. {
  106. typ: "BasicTypeParam",
  107. wantPointer: false,
  108. },
  109. {
  110. typ: "IntPtrParam",
  111. wantPointer: true,
  112. },
  113. {
  114. typ: "IntegerPtrParam",
  115. wantPointer: true,
  116. },
  117. {
  118. typ: "IntegerParamPtr",
  119. wantPointer: true,
  120. },
  121. {
  122. typ: "IntegerSliceParam",
  123. wantPointer: true,
  124. },
  125. {
  126. typ: "IntegerMapParam",
  127. wantPointer: true,
  128. },
  129. {
  130. typ: "UnsafePointerParam",
  131. wantPointer: true,
  132. },
  133. {
  134. typ: "InterfaceParam",
  135. wantPointer: true,
  136. },
  137. {
  138. typ: "ValueUnionParam",
  139. wantPointer: false,
  140. },
  141. {
  142. typ: "ValueUnionParamPtr",
  143. wantPointer: true,
  144. },
  145. {
  146. typ: "PointerUnionParam",
  147. wantPointer: true,
  148. },
  149. {
  150. typ: "StructWithUniqueHandle",
  151. wantPointer: false,
  152. },
  153. {
  154. typ: "StructWithTime",
  155. wantPointer: false,
  156. },
  157. {
  158. typ: "StructWithNetipTypes",
  159. wantPointer: false,
  160. },
  161. }
  162. for _, tt := range tests {
  163. t.Run(tt.typ, func(t *testing.T) {
  164. typ := lookupTestType(t, tt.typ)
  165. if isPointer := ContainsPointers(typ); isPointer != tt.wantPointer {
  166. t.Fatalf("ContainsPointers: got %v, want: %v", isPointer, tt.wantPointer)
  167. }
  168. })
  169. }
  170. }
  171. func TestAssertStructUnchanged(t *testing.T) {
  172. type args struct {
  173. t *types.Struct
  174. tname string
  175. params *types.TypeParamList
  176. ctx string
  177. it *ImportTracker
  178. }
  179. // package t1 with a struct T1 with two fields
  180. p1 := types.NewPackage("t1", "t1")
  181. t1 := types.NewNamed(types.NewTypeName(0, p1, "T1", nil), types.NewStruct([]*types.Var{
  182. types.NewField(0, nil, "P1", types.Typ[types.Int], false),
  183. types.NewField(0, nil, "P2", types.Typ[types.String], false),
  184. }, nil), nil)
  185. p1.Scope().Insert(t1.Obj())
  186. tests := []struct {
  187. name string
  188. args args
  189. want []byte
  190. }{
  191. {
  192. name: "t1-internally_defined",
  193. args: args{
  194. t: t1.Underlying().(*types.Struct),
  195. tname: "prefix_",
  196. params: nil,
  197. ctx: "",
  198. it: NewImportTracker(p1),
  199. },
  200. want: []byte("var _prefix_NeedsRegeneration = prefix_(struct {\n\tP1 int \n\tP2 string \n}{})"),
  201. },
  202. {
  203. name: "t2-with_named_field",
  204. args: args{
  205. t: types.NewStruct([]*types.Var{
  206. types.NewField(0, nil, "T1", t1, false),
  207. types.NewField(0, nil, "P1", types.Typ[types.Int], false),
  208. types.NewField(0, nil, "P2", types.Typ[types.String], false),
  209. }, nil),
  210. tname: "prefix_",
  211. params: nil,
  212. ctx: "",
  213. it: NewImportTracker(types.NewPackage("t2", "t2")),
  214. },
  215. // the struct should be regenerated with the named field
  216. want: []byte("var _prefix_NeedsRegeneration = prefix_(struct {\n\tT1 t1.T1 \n\tP1 int \n\tP2 string \n}{})"),
  217. },
  218. {
  219. name: "t3-with_embedded_field",
  220. args: args{
  221. t: types.NewStruct([]*types.Var{
  222. types.NewField(0, nil, "T1", t1, true),
  223. types.NewField(0, nil, "P1", types.Typ[types.Int], false),
  224. types.NewField(0, nil, "P2", types.Typ[types.String], false),
  225. }, nil),
  226. tname: "prefix_",
  227. params: nil,
  228. ctx: "",
  229. it: NewImportTracker(types.NewPackage("t3", "t3")),
  230. },
  231. // the struct should be regenerated with the embedded field
  232. want: []byte("var _prefix_NeedsRegeneration = prefix_(struct {\n\tt1.T1 \n\tP1 int \n\tP2 string \n}{})"),
  233. },
  234. }
  235. for _, tt := range tests {
  236. t.Run(tt.name, func(t *testing.T) {
  237. if got := AssertStructUnchanged(tt.args.t, tt.args.tname, tt.args.params, tt.args.ctx, tt.args.it); !strings.Contains(string(got), string(tt.want)) {
  238. t.Errorf("AssertStructUnchanged() = \n%s\nwant: \n%s", string(got), string(tt.want))
  239. }
  240. })
  241. }
  242. }
  243. type NamedType struct{}
  244. func (NamedType) Method() {}
  245. type NamedTypeAlias = NamedType
  246. type NamedInterface interface {
  247. Method()
  248. }
  249. type NamedInterfaceAlias = NamedInterface
  250. type GenericType[T NamedInterface] struct {
  251. TypeParamField T
  252. TypeParamPtrField *T
  253. }
  254. type GenericTypeWithAliasConstraint[T NamedInterfaceAlias] struct {
  255. TypeParamField T
  256. TypeParamPtrField *T
  257. }
  258. func TestLookupMethod(t *testing.T) {
  259. tests := []struct {
  260. name string
  261. typ types.Type
  262. methodName string
  263. wantHasMethod bool
  264. wantReceiver types.Type
  265. }{
  266. {
  267. name: "NamedType/HasMethod",
  268. typ: lookupTestType(t, "NamedType"),
  269. methodName: "Method",
  270. wantHasMethod: true,
  271. },
  272. {
  273. name: "NamedType/NoMethod",
  274. typ: lookupTestType(t, "NamedType"),
  275. methodName: "NoMethod",
  276. wantHasMethod: false,
  277. },
  278. {
  279. name: "NamedTypeAlias/HasMethod",
  280. typ: lookupTestType(t, "NamedTypeAlias"),
  281. methodName: "Method",
  282. wantHasMethod: true,
  283. wantReceiver: lookupTestType(t, "NamedType"),
  284. },
  285. {
  286. name: "NamedTypeAlias/NoMethod",
  287. typ: lookupTestType(t, "NamedTypeAlias"),
  288. methodName: "NoMethod",
  289. wantHasMethod: false,
  290. },
  291. {
  292. name: "PtrToNamedType/HasMethod",
  293. typ: types.NewPointer(lookupTestType(t, "NamedType")),
  294. methodName: "Method",
  295. wantHasMethod: true,
  296. wantReceiver: lookupTestType(t, "NamedType"),
  297. },
  298. {
  299. name: "PtrToNamedType/NoMethod",
  300. typ: types.NewPointer(lookupTestType(t, "NamedType")),
  301. methodName: "NoMethod",
  302. wantHasMethod: false,
  303. },
  304. {
  305. name: "PtrToNamedTypeAlias/HasMethod",
  306. typ: types.NewPointer(lookupTestType(t, "NamedTypeAlias")),
  307. methodName: "Method",
  308. wantHasMethod: true,
  309. wantReceiver: lookupTestType(t, "NamedType"),
  310. },
  311. {
  312. name: "PtrToNamedTypeAlias/NoMethod",
  313. typ: types.NewPointer(lookupTestType(t, "NamedTypeAlias")),
  314. methodName: "NoMethod",
  315. wantHasMethod: false,
  316. },
  317. {
  318. name: "NamedInterface/HasMethod",
  319. typ: lookupTestType(t, "NamedInterface"),
  320. methodName: "Method",
  321. wantHasMethod: true,
  322. },
  323. {
  324. name: "NamedInterface/NoMethod",
  325. typ: lookupTestType(t, "NamedInterface"),
  326. methodName: "NoMethod",
  327. wantHasMethod: false,
  328. },
  329. {
  330. name: "Interface/HasMethod",
  331. typ: types.NewInterfaceType([]*types.Func{types.NewFunc(0, nil, "Method", types.NewSignatureType(nil, nil, nil, nil, nil, false))}, nil),
  332. methodName: "Method",
  333. wantHasMethod: true,
  334. },
  335. {
  336. name: "Interface/NoMethod",
  337. typ: types.NewInterfaceType(nil, nil),
  338. methodName: "NoMethod",
  339. wantHasMethod: false,
  340. },
  341. {
  342. name: "TypeParam/HasMethod",
  343. typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(0).Type(),
  344. methodName: "Method",
  345. wantHasMethod: true,
  346. wantReceiver: lookupTestType(t, "NamedInterface"),
  347. },
  348. {
  349. name: "TypeParam/NoMethod",
  350. typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(0).Type(),
  351. methodName: "NoMethod",
  352. wantHasMethod: false,
  353. },
  354. {
  355. name: "TypeParamPtr/HasMethod",
  356. typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(1).Type(),
  357. methodName: "Method",
  358. wantHasMethod: true,
  359. wantReceiver: lookupTestType(t, "NamedInterface"),
  360. },
  361. {
  362. name: "TypeParamPtr/NoMethod",
  363. typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(1).Type(),
  364. methodName: "NoMethod",
  365. wantHasMethod: false,
  366. },
  367. {
  368. name: "TypeParamWithAlias/HasMethod",
  369. typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(0).Type(),
  370. methodName: "Method",
  371. wantHasMethod: true,
  372. wantReceiver: lookupTestType(t, "NamedInterface"),
  373. },
  374. {
  375. name: "TypeParamWithAlias/NoMethod",
  376. typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(0).Type(),
  377. methodName: "NoMethod",
  378. wantHasMethod: false,
  379. },
  380. {
  381. name: "TypeParamWithAliasPtr/HasMethod",
  382. typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(1).Type(),
  383. methodName: "Method",
  384. wantHasMethod: true,
  385. wantReceiver: lookupTestType(t, "NamedInterface"),
  386. },
  387. {
  388. name: "TypeParamWithAliasPtr/NoMethod",
  389. typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(1).Type(),
  390. methodName: "NoMethod",
  391. wantHasMethod: false,
  392. },
  393. }
  394. for _, tt := range tests {
  395. t.Run(tt.name, func(t *testing.T) {
  396. gotMethod := LookupMethod(tt.typ, tt.methodName)
  397. if gotHasMethod := gotMethod != nil; gotHasMethod != tt.wantHasMethod {
  398. t.Fatalf("HasMethod: got %v; want %v", gotMethod, tt.wantHasMethod)
  399. }
  400. if gotMethod == nil {
  401. return
  402. }
  403. if gotMethod.Name() != tt.methodName {
  404. t.Errorf("Name: got %v; want %v", gotMethod.Name(), tt.methodName)
  405. }
  406. if gotRecv, wantRecv := gotMethod.Signature().Recv().Type(), cmp.Or(tt.wantReceiver, tt.typ); !types.Identical(gotRecv, wantRecv) {
  407. t.Errorf("Recv: got %v; want %v", gotRecv, wantRecv)
  408. }
  409. })
  410. }
  411. }
  412. var namedTestTypes = sync.OnceValues(func() (map[string]types.Type, error) {
  413. _, namedTypes, err := LoadTypes("test", ".")
  414. return namedTypes, err
  415. })
  416. func lookupTestType(t *testing.T, name string) types.Type {
  417. t.Helper()
  418. types, err := namedTestTypes()
  419. if err != nil {
  420. t.Fatal(err)
  421. }
  422. typ, ok := types[name]
  423. if !ok {
  424. t.Fatalf("type %q is not declared in the current package", name)
  425. }
  426. return typ
  427. }