codegen_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package codegen
  4. import (
  5. "cmp"
  6. "go/types"
  7. "net/netip"
  8. "strings"
  9. "sync"
  10. "testing"
  11. "unsafe"
  12. "golang.org/x/exp/constraints"
  13. )
  14. type AnyParam[T any] struct {
  15. V T
  16. }
  17. type AnyParamPhantom[T any] struct {
  18. }
  19. type IntegerParam[T constraints.Integer] struct {
  20. V T
  21. }
  22. type FloatParam[T constraints.Float] struct {
  23. V T
  24. }
  25. type StringLikeParam[T ~string] struct {
  26. V T
  27. }
  28. type BasicType interface {
  29. ~bool | constraints.Integer | constraints.Float | constraints.Complex | ~string
  30. }
  31. type BasicTypeParam[T BasicType] struct {
  32. V T
  33. }
  34. type IntPtr *int
  35. type IntPtrParam[T IntPtr] struct {
  36. V T
  37. }
  38. type IntegerPtr interface {
  39. *int | *int32 | *int64
  40. }
  41. type IntegerPtrParam[T IntegerPtr] struct {
  42. V T
  43. }
  44. type IntegerParamPtr[T constraints.Integer] struct {
  45. V *T
  46. }
  47. type IntegerSliceParam[T constraints.Integer] struct {
  48. V []T
  49. }
  50. type IntegerMapParam[T constraints.Integer] struct {
  51. V []T
  52. }
  53. type UnsafePointerParam[T unsafe.Pointer] struct {
  54. V T
  55. }
  56. type ValueUnionParam[T netip.Prefix | BasicType] struct {
  57. V T
  58. }
  59. type ValueUnionParamPtr[T netip.Prefix | BasicType] struct {
  60. V *T
  61. }
  62. type PointerUnionParam[T netip.Prefix | BasicType | IntPtr] struct {
  63. V T
  64. }
  65. type Interface interface {
  66. Method()
  67. }
  68. type InterfaceParam[T Interface] struct {
  69. V T
  70. }
  71. func TestGenericContainsPointers(t *testing.T) {
  72. tests := []struct {
  73. typ string
  74. wantPointer bool
  75. }{
  76. {
  77. typ: "AnyParam",
  78. wantPointer: true,
  79. },
  80. {
  81. typ: "AnyParamPhantom",
  82. wantPointer: false, // has a pointer type parameter, but no pointer fields
  83. },
  84. {
  85. typ: "IntegerParam",
  86. wantPointer: false,
  87. },
  88. {
  89. typ: "FloatParam",
  90. wantPointer: false,
  91. },
  92. {
  93. typ: "StringLikeParam",
  94. wantPointer: false,
  95. },
  96. {
  97. typ: "BasicTypeParam",
  98. wantPointer: false,
  99. },
  100. {
  101. typ: "IntPtrParam",
  102. wantPointer: true,
  103. },
  104. {
  105. typ: "IntegerPtrParam",
  106. wantPointer: true,
  107. },
  108. {
  109. typ: "IntegerParamPtr",
  110. wantPointer: true,
  111. },
  112. {
  113. typ: "IntegerSliceParam",
  114. wantPointer: true,
  115. },
  116. {
  117. typ: "IntegerMapParam",
  118. wantPointer: true,
  119. },
  120. {
  121. typ: "UnsafePointerParam",
  122. wantPointer: true,
  123. },
  124. {
  125. typ: "InterfaceParam",
  126. wantPointer: true,
  127. },
  128. {
  129. typ: "ValueUnionParam",
  130. wantPointer: false,
  131. },
  132. {
  133. typ: "ValueUnionParamPtr",
  134. wantPointer: true,
  135. },
  136. {
  137. typ: "PointerUnionParam",
  138. wantPointer: true,
  139. },
  140. }
  141. for _, tt := range tests {
  142. t.Run(tt.typ, func(t *testing.T) {
  143. typ := lookupTestType(t, tt.typ)
  144. if isPointer := ContainsPointers(typ); isPointer != tt.wantPointer {
  145. t.Fatalf("ContainsPointers: got %v, want: %v", isPointer, tt.wantPointer)
  146. }
  147. })
  148. }
  149. }
  150. func TestAssertStructUnchanged(t *testing.T) {
  151. type args struct {
  152. t *types.Struct
  153. tname string
  154. params *types.TypeParamList
  155. ctx string
  156. it *ImportTracker
  157. }
  158. // package t1 with a struct T1 with two fields
  159. p1 := types.NewPackage("t1", "t1")
  160. t1 := types.NewNamed(types.NewTypeName(0, p1, "T1", nil), types.NewStruct([]*types.Var{
  161. types.NewField(0, nil, "P1", types.Typ[types.Int], false),
  162. types.NewField(0, nil, "P2", types.Typ[types.String], false),
  163. }, nil), nil)
  164. p1.Scope().Insert(t1.Obj())
  165. tests := []struct {
  166. name string
  167. args args
  168. want []byte
  169. }{
  170. {
  171. name: "t1-internally_defined",
  172. args: args{
  173. t: t1.Underlying().(*types.Struct),
  174. tname: "prefix_",
  175. params: nil,
  176. ctx: "",
  177. it: NewImportTracker(p1),
  178. },
  179. want: []byte("var _prefix_NeedsRegeneration = prefix_(struct {\n\tP1 int \n\tP2 string \n}{})"),
  180. },
  181. {
  182. name: "t2-with_named_field",
  183. args: args{
  184. t: types.NewStruct([]*types.Var{
  185. types.NewField(0, nil, "T1", t1, false),
  186. types.NewField(0, nil, "P1", types.Typ[types.Int], false),
  187. types.NewField(0, nil, "P2", types.Typ[types.String], false),
  188. }, nil),
  189. tname: "prefix_",
  190. params: nil,
  191. ctx: "",
  192. it: NewImportTracker(types.NewPackage("t2", "t2")),
  193. },
  194. // the struct should be regenerated with the named field
  195. want: []byte("var _prefix_NeedsRegeneration = prefix_(struct {\n\tT1 t1.T1 \n\tP1 int \n\tP2 string \n}{})"),
  196. },
  197. {
  198. name: "t3-with_embedded_field",
  199. args: args{
  200. t: types.NewStruct([]*types.Var{
  201. types.NewField(0, nil, "T1", t1, true),
  202. types.NewField(0, nil, "P1", types.Typ[types.Int], false),
  203. types.NewField(0, nil, "P2", types.Typ[types.String], false),
  204. }, nil),
  205. tname: "prefix_",
  206. params: nil,
  207. ctx: "",
  208. it: NewImportTracker(types.NewPackage("t3", "t3")),
  209. },
  210. // the struct should be regenerated with the embedded field
  211. want: []byte("var _prefix_NeedsRegeneration = prefix_(struct {\n\tt1.T1 \n\tP1 int \n\tP2 string \n}{})"),
  212. },
  213. }
  214. for _, tt := range tests {
  215. t.Run(tt.name, func(t *testing.T) {
  216. if got := AssertStructUnchanged(tt.args.t, tt.args.tname, tt.args.params, tt.args.ctx, tt.args.it); !strings.Contains(string(got), string(tt.want)) {
  217. t.Errorf("AssertStructUnchanged() = \n%s\nwant: \n%s", string(got), string(tt.want))
  218. }
  219. })
  220. }
  221. }
  222. type NamedType struct{}
  223. func (NamedType) Method() {}
  224. type NamedTypeAlias = NamedType
  225. type NamedInterface interface {
  226. Method()
  227. }
  228. type NamedInterfaceAlias = NamedInterface
  229. type GenericType[T NamedInterface] struct {
  230. TypeParamField T
  231. TypeParamPtrField *T
  232. }
  233. type GenericTypeWithAliasConstraint[T NamedInterfaceAlias] struct {
  234. TypeParamField T
  235. TypeParamPtrField *T
  236. }
  237. func TestLookupMethod(t *testing.T) {
  238. tests := []struct {
  239. name string
  240. typ types.Type
  241. methodName string
  242. wantHasMethod bool
  243. wantReceiver types.Type
  244. }{
  245. {
  246. name: "NamedType/HasMethod",
  247. typ: lookupTestType(t, "NamedType"),
  248. methodName: "Method",
  249. wantHasMethod: true,
  250. },
  251. {
  252. name: "NamedType/NoMethod",
  253. typ: lookupTestType(t, "NamedType"),
  254. methodName: "NoMethod",
  255. wantHasMethod: false,
  256. },
  257. {
  258. name: "NamedTypeAlias/HasMethod",
  259. typ: lookupTestType(t, "NamedTypeAlias"),
  260. methodName: "Method",
  261. wantHasMethod: true,
  262. wantReceiver: lookupTestType(t, "NamedType"),
  263. },
  264. {
  265. name: "NamedTypeAlias/NoMethod",
  266. typ: lookupTestType(t, "NamedTypeAlias"),
  267. methodName: "NoMethod",
  268. wantHasMethod: false,
  269. },
  270. {
  271. name: "PtrToNamedType/HasMethod",
  272. typ: types.NewPointer(lookupTestType(t, "NamedType")),
  273. methodName: "Method",
  274. wantHasMethod: true,
  275. wantReceiver: lookupTestType(t, "NamedType"),
  276. },
  277. {
  278. name: "PtrToNamedType/NoMethod",
  279. typ: types.NewPointer(lookupTestType(t, "NamedType")),
  280. methodName: "NoMethod",
  281. wantHasMethod: false,
  282. },
  283. {
  284. name: "PtrToNamedTypeAlias/HasMethod",
  285. typ: types.NewPointer(lookupTestType(t, "NamedTypeAlias")),
  286. methodName: "Method",
  287. wantHasMethod: true,
  288. wantReceiver: lookupTestType(t, "NamedType"),
  289. },
  290. {
  291. name: "PtrToNamedTypeAlias/NoMethod",
  292. typ: types.NewPointer(lookupTestType(t, "NamedTypeAlias")),
  293. methodName: "NoMethod",
  294. wantHasMethod: false,
  295. },
  296. {
  297. name: "NamedInterface/HasMethod",
  298. typ: lookupTestType(t, "NamedInterface"),
  299. methodName: "Method",
  300. wantHasMethod: true,
  301. },
  302. {
  303. name: "NamedInterface/NoMethod",
  304. typ: lookupTestType(t, "NamedInterface"),
  305. methodName: "NoMethod",
  306. wantHasMethod: false,
  307. },
  308. {
  309. name: "Interface/HasMethod",
  310. typ: types.NewInterfaceType([]*types.Func{types.NewFunc(0, nil, "Method", types.NewSignatureType(nil, nil, nil, nil, nil, false))}, nil),
  311. methodName: "Method",
  312. wantHasMethod: true,
  313. },
  314. {
  315. name: "Interface/NoMethod",
  316. typ: types.NewInterfaceType(nil, nil),
  317. methodName: "NoMethod",
  318. wantHasMethod: false,
  319. },
  320. {
  321. name: "TypeParam/HasMethod",
  322. typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(0).Type(),
  323. methodName: "Method",
  324. wantHasMethod: true,
  325. wantReceiver: lookupTestType(t, "NamedInterface"),
  326. },
  327. {
  328. name: "TypeParam/NoMethod",
  329. typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(0).Type(),
  330. methodName: "NoMethod",
  331. wantHasMethod: false,
  332. },
  333. {
  334. name: "TypeParamPtr/HasMethod",
  335. typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(1).Type(),
  336. methodName: "Method",
  337. wantHasMethod: true,
  338. wantReceiver: lookupTestType(t, "NamedInterface"),
  339. },
  340. {
  341. name: "TypeParamPtr/NoMethod",
  342. typ: lookupTestType(t, "GenericType").Underlying().(*types.Struct).Field(1).Type(),
  343. methodName: "NoMethod",
  344. wantHasMethod: false,
  345. },
  346. {
  347. name: "TypeParamWithAlias/HasMethod",
  348. typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(0).Type(),
  349. methodName: "Method",
  350. wantHasMethod: true,
  351. wantReceiver: lookupTestType(t, "NamedInterface"),
  352. },
  353. {
  354. name: "TypeParamWithAlias/NoMethod",
  355. typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(0).Type(),
  356. methodName: "NoMethod",
  357. wantHasMethod: false,
  358. },
  359. {
  360. name: "TypeParamWithAliasPtr/HasMethod",
  361. typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(1).Type(),
  362. methodName: "Method",
  363. wantHasMethod: true,
  364. wantReceiver: lookupTestType(t, "NamedInterface"),
  365. },
  366. {
  367. name: "TypeParamWithAliasPtr/NoMethod",
  368. typ: lookupTestType(t, "GenericTypeWithAliasConstraint").Underlying().(*types.Struct).Field(1).Type(),
  369. methodName: "NoMethod",
  370. wantHasMethod: false,
  371. },
  372. }
  373. for _, tt := range tests {
  374. t.Run(tt.name, func(t *testing.T) {
  375. gotMethod := LookupMethod(tt.typ, tt.methodName)
  376. if gotHasMethod := gotMethod != nil; gotHasMethod != tt.wantHasMethod {
  377. t.Fatalf("HasMethod: got %v; want %v", gotMethod, tt.wantHasMethod)
  378. }
  379. if gotMethod == nil {
  380. return
  381. }
  382. if gotMethod.Name() != tt.methodName {
  383. t.Errorf("Name: got %v; want %v", gotMethod.Name(), tt.methodName)
  384. }
  385. if gotRecv, wantRecv := gotMethod.Signature().Recv().Type(), cmp.Or(tt.wantReceiver, tt.typ); !types.Identical(gotRecv, wantRecv) {
  386. t.Errorf("Recv: got %v; want %v", gotRecv, wantRecv)
  387. }
  388. })
  389. }
  390. }
  391. var namedTestTypes = sync.OnceValues(func() (map[string]types.Type, error) {
  392. _, namedTypes, err := LoadTypes("test", ".")
  393. return namedTypes, err
  394. })
  395. func lookupTestType(t *testing.T, name string) types.Type {
  396. t.Helper()
  397. types, err := namedTestTypes()
  398. if err != nil {
  399. t.Fatal(err)
  400. }
  401. typ, ok := types[name]
  402. if !ok {
  403. t.Fatalf("type %q is not declared in the current package", name)
  404. }
  405. return typ
  406. }