compat_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. package srs
  import (
  	"bufio"
  	"bytes"
  	"encoding/binary"
  	"fmt"
  	"net/netip"
  	"strings"
  	"testing"
  	"unsafe"

  	M "github.com/sagernet/sing/common/metadata"
  	"github.com/sagernet/sing/common/varbin"
  	"github.com/stretchr/testify/require"
  	"go4.org/netipx"
  )
  15. // Old implementations using varbin reflection-based serialization
  // oldWriteStringSlice encodes value with the deprecated, reflection-based
  // varbin.Write so its output can serve as the legacy reference encoding.
  func oldWriteStringSlice(writer varbin.Writer, value []string) error {
  	err := varbin.Write(writer, binary.BigEndian, value) //nolint:staticcheck // legacy API kept on purpose
  	return err
  }
  // oldReadStringSlice decodes a []string with the deprecated, reflection-based
  // varbin.ReadValue; it plays the role of the legacy reader in cross-compat checks.
  func oldReadStringSlice(reader varbin.Reader) ([]string, error) {
  	decoded, err := varbin.ReadValue[[]string](reader, binary.BigEndian) //nolint:staticcheck // legacy API kept on purpose
  	return decoded, err
  }
  // oldWriteUint8Slice encodes a slice of any uint8-based type with the
  // deprecated varbin.Write to produce the legacy reference bytes.
  func oldWriteUint8Slice[E ~uint8](writer varbin.Writer, value []E) error {
  	err := varbin.Write(writer, binary.BigEndian, value) //nolint:staticcheck // legacy API kept on purpose
  	return err
  }
  // oldReadUint8Slice decodes a slice of a uint8-based type with the deprecated
  // varbin.ReadValue, i.e. the legacy reader.
  func oldReadUint8Slice[E ~uint8](reader varbin.Reader) ([]E, error) {
  	decoded, err := varbin.ReadValue[[]E](reader, binary.BigEndian) //nolint:staticcheck // legacy API kept on purpose
  	return decoded, err
  }
  // oldWriteUint16Slice encodes a []uint16 with the deprecated varbin.Write
  // (big-endian) to produce the legacy reference bytes.
  func oldWriteUint16Slice(writer varbin.Writer, value []uint16) error {
  	err := varbin.Write(writer, binary.BigEndian, value) //nolint:staticcheck // legacy API kept on purpose
  	return err
  }
  // oldReadUint16Slice decodes a big-endian []uint16 with the deprecated
  // varbin.ReadValue, i.e. the legacy reader.
  func oldReadUint16Slice(reader varbin.Reader) ([]uint16, error) {
  	decoded, err := varbin.ReadValue[[]uint16](reader, binary.BigEndian) //nolint:staticcheck // legacy API kept on purpose
  	return decoded, err
  }
  // oldWritePrefix writes a prefix the legacy way: the address bytes encoded by
  // the deprecated varbin.Write, followed by the prefix length as one byte.
  func oldWritePrefix(writer varbin.Writer, prefix netip.Prefix) error {
  	addrBytes := prefix.Addr().AsSlice()
  	if err := varbin.Write(writer, binary.BigEndian, addrBytes); err != nil { //nolint:staticcheck // legacy API kept on purpose
  		return err
  	}
  	bits := uint8(prefix.Bits())
  	return binary.Write(writer, binary.BigEndian, bits)
  }
  // oldIPRangeData mirrors one legacy IP-range record: the raw bytes of the
  // first and last address. The fields are exported because varbin's
  // reflection-based decoder can only set exported fields.
  type oldIPRangeData struct {
  	From, To []byte
  }
  // Note: the old writeIPSet had a bug: varbin.Write(writer, binary.BigEndian, data)
  // called with a struct VALUE (not a pointer) silently wrote nothing, because
  // field.CanSet() returned false for the struct's fields. As a result, IP range
  // data was missing from its output. The new implementation writes all range data.
  //
  // The old readIPSet used varbin.Read with a pre-allocated slice, which worked
  // because slice elements are addressable, so CanSet() returns true for them.
  //
  // For compatibility testing, we verify that:
  //  1. the new writer produces output that includes the range data,
  //  2. the new reader parses that format correctly, and
  //  3. a round trip reproduces the original set.
  //
  // oldReadIPSet reproduces the legacy IP-set decoder. It reads a version byte
  // (which must be 1), then a big-endian uint64 range count, then that many
  // From/To records decoded through varbin reflection.
  //
  // Unlike the legacy decoder, an unexpected version is reported as an error.
  // The legacy code returned (nil, nil), a nil set that callers would then
  // dereference.
  func oldReadIPSet(reader varbin.Reader) (*netipx.IPSet, error) {
  	version, err := reader.ReadByte()
  	if err != nil {
  		return nil, err
  	}
  	if version != 1 {
  		return nil, fmt.Errorf("unsupported IP set version: %d", version)
  	}
  	var length uint64
  	if err = binary.Read(reader, binary.BigEndian, &length); err != nil {
  		return nil, err
  	}
  	// varbin.Read can only populate these records because slice elements are
  	// addressable (see the note above).
  	ranges := make([]oldIPRangeData, length)
  	if err = varbin.Read(reader, binary.BigEndian, &ranges); err != nil { //nolint:staticcheck // legacy API kept on purpose
  		return nil, err
  	}
  	// The unsafe cast below relies on myIPSet (declared elsewhere in this
  	// package) mirroring netipx.IPSet's unexported memory layout.
  	mySet := &myIPSet{rr: make([]myIPRange, len(ranges))}
  	for i := range ranges {
  		mySet.rr[i].from = M.AddrFromIP(ranges[i].From)
  		mySet.rr[i].to = M.AddrFromIP(ranges[i].To)
  	}
  	return (*netipx.IPSet)(unsafe.Pointer(mySet)), nil
  }
  92. // New write functions (without itemType prefix for testing)
  93. func newWriteStringSlice(writer varbin.Writer, value []string) error {
  94. _, err := varbin.WriteUvarint(writer, uint64(len(value)))
  95. if err != nil {
  96. return err
  97. }
  98. for _, s := range value {
  99. _, err = varbin.WriteUvarint(writer, uint64(len(s)))
  100. if err != nil {
  101. return err
  102. }
  103. _, err = writer.Write([]byte(s))
  104. if err != nil {
  105. return err
  106. }
  107. }
  108. return nil
  109. }
  // newWriteUint8Slice writes a uvarint length followed by the raw element bytes.
  // Because E's underlying type is uint8, []E shares []byte's memory layout, so
  // the slice header is reinterpreted rather than copied.
  func newWriteUint8Slice[E ~uint8](writer varbin.Writer, value []E) error {
  	if _, err := varbin.WriteUvarint(writer, uint64(len(value))); err != nil {
  		return err
  	}
  	raw := *(*[]byte)(unsafe.Pointer(&value))
  	_, err := writer.Write(raw)
  	return err
  }
  // newWriteUint16Slice writes a uvarint element count followed by every
  // element as two big-endian bytes.
  func newWriteUint16Slice(writer varbin.Writer, value []uint16) error {
  	if _, err := varbin.WriteUvarint(writer, uint64(len(value))); err != nil {
  		return err
  	}
  	// binary.Write emits the elements back to back with no extra framing.
  	return binary.Write(writer, binary.BigEndian, value)
  }
  // newWritePrefix writes a prefix as a uvarint address length, the raw address
  // bytes (4 for IPv4, 16 for IPv6), and finally the prefix length as one byte.
  func newWritePrefix(writer varbin.Writer, prefix netip.Prefix) error {
  	raw := prefix.Addr().AsSlice()
  	if _, err := varbin.WriteUvarint(writer, uint64(len(raw))); err != nil {
  		return err
  	}
  	if _, err := writer.Write(raw); err != nil {
  		return err
  	}
  	bits := uint8(prefix.Bits())
  	return writer.WriteByte(bits)
  }
  137. // Tests
  // TestStringSliceCompat checks that newWriteStringSlice produces exactly the
  // bytes of the legacy varbin encoder, and that each encoding is decoded
  // correctly by the other side's reader.
  func TestStringSliceCompat(t *testing.T) {
  	t.Parallel()
  	cases := []struct {
  		name  string
  		input []string
  	}{
  		{"nil", nil},
  		{"empty", []string{}},
  		{"single_empty", []string{""}},
  		{"single", []string{"test"}},
  		{"multi", []string{"a", "b", "c"}},
  		{"with_empty", []string{"a", "", "c"}},
  		{"utf8", []string{"测试", "テスト", "тест"}},
  		// 1<<14 is the smallest length whose uvarint prefix needs 3 bytes; the
  		// 127/128-byte cases below already cover the 1-/2-byte boundary.
  		{"long_string", []string{strings.Repeat("x", 1<<14)}},
  		{"many_elements", generateStrings(128)},
  		{"many_elements_256", generateStrings(256)},
  		{"127_byte_string", []string{strings.Repeat("x", 127)}},
  		{"128_byte_string", []string{strings.Repeat("x", 128)}},
  		{"mixed_lengths", []string{"a", strings.Repeat("b", 100), "", strings.Repeat("c", 200)}},
  	}
  	for _, tc := range cases {
  		t.Run(tc.name, func(t *testing.T) {
  			t.Parallel()
  			// Old write
  			var oldBuf bytes.Buffer
  			err := oldWriteStringSlice(&oldBuf, tc.input)
  			require.NoError(t, err)
  			// New write
  			var newBuf bytes.Buffer
  			err = newWriteStringSlice(&newBuf, tc.input)
  			require.NoError(t, err)
  			// Bytes must match
  			require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(),
  				"mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes())
  			// New write -> old read
  			readBack, err := oldReadStringSlice(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
  			require.NoError(t, err)
  			requireStringSliceEqual(t, tc.input, readBack)
  			// Old write -> new read
  			readBack2, err := readRuleItemString(bufio.NewReader(bytes.NewReader(oldBuf.Bytes())))
  			require.NoError(t, err)
  			requireStringSliceEqual(t, tc.input, readBack2)
  		})
  	}
  }
  // TestUint8SliceCompat checks that newWriteUint8Slice is byte-identical to the
  // legacy varbin encoding and that each encoding is decodable by the other
  // side's reader.
  func TestUint8SliceCompat(t *testing.T) {
  	t.Parallel()
  	testCases := []struct {
  		name  string
  		input []uint8
  	}{
  		{"nil", nil},
  		{"empty", []uint8{}},
  		{"single_zero", []uint8{0}},
  		{"single_max", []uint8{255}},
  		{"multi", []uint8{0, 1, 127, 128, 255}},
  		{"boundary", []uint8{0x00, 0x7f, 0x80, 0xff}},
  		{"sequential", generateUint8Slice(256)},
  		{"127_elements", generateUint8Slice(127)},
  		{"128_elements", generateUint8Slice(128)},
  	}
  	for _, testCase := range testCases {
  		t.Run(testCase.name, func(t *testing.T) {
  			t.Parallel()
  			input := testCase.input

  			var legacy, current bytes.Buffer
  			require.NoError(t, oldWriteUint8Slice(&legacy, input))
  			require.NoError(t, newWriteUint8Slice(&current, input))
  			require.Equal(t, legacy.Bytes(), current.Bytes(),
  				"mismatch for %q\nold: %x\nnew: %x", testCase.name, legacy.Bytes(), current.Bytes())

  			// Cross-decode: new bytes through the legacy reader...
  			decodedByOld, err := oldReadUint8Slice[uint8](bufio.NewReader(bytes.NewReader(current.Bytes())))
  			require.NoError(t, err)
  			requireUint8SliceEqual(t, input, decodedByOld)

  			// ...and legacy bytes through the new reader.
  			decodedByNew, err := readRuleItemUint8[uint8](bufio.NewReader(bytes.NewReader(legacy.Bytes())))
  			require.NoError(t, err)
  			requireUint8SliceEqual(t, input, decodedByNew)
  		})
  	}
  }
  // TestUint16SliceCompat checks that newWriteUint16Slice is byte-identical to
  // the legacy varbin encoding and that each encoding is decodable by the other
  // side's reader.
  func TestUint16SliceCompat(t *testing.T) {
  	t.Parallel()
  	testCases := []struct {
  		name  string
  		input []uint16
  	}{
  		{"nil", nil},
  		{"empty", []uint16{}},
  		{"single_zero", []uint16{0}},
  		{"single_max", []uint16{65535}},
  		{"multi", []uint16{0, 255, 256, 32767, 32768, 65535}},
  		{"ports", []uint16{80, 443, 8080, 8443}},
  		{"127_elements", generateUint16Slice(127)},
  		{"128_elements", generateUint16Slice(128)},
  		{"256_elements", generateUint16Slice(256)},
  	}
  	for _, testCase := range testCases {
  		t.Run(testCase.name, func(t *testing.T) {
  			t.Parallel()
  			input := testCase.input

  			var legacy, current bytes.Buffer
  			require.NoError(t, oldWriteUint16Slice(&legacy, input))
  			require.NoError(t, newWriteUint16Slice(&current, input))
  			require.Equal(t, legacy.Bytes(), current.Bytes(),
  				"mismatch for %q\nold: %x\nnew: %x", testCase.name, legacy.Bytes(), current.Bytes())

  			// Cross-decode: new bytes through the legacy reader...
  			decodedByOld, err := oldReadUint16Slice(bufio.NewReader(bytes.NewReader(current.Bytes())))
  			require.NoError(t, err)
  			requireUint16SliceEqual(t, input, decodedByOld)

  			// ...and legacy bytes through the new reader.
  			decodedByNew, err := readRuleItemUint16(bufio.NewReader(bytes.NewReader(legacy.Bytes())))
  			require.NoError(t, err)
  			requireUint16SliceEqual(t, input, decodedByNew)
  		})
  	}
  }
  // TestPrefixCompat checks that newWritePrefix matches the legacy prefix
  // encoding byte-for-byte and that readPrefix decodes both encodings.
  func TestPrefixCompat(t *testing.T) {
  	t.Parallel()
  	testCases := []struct {
  		name  string
  		input netip.Prefix
  	}{
  		{"ipv4_0", netip.MustParsePrefix("0.0.0.0/0")},
  		{"ipv4_8", netip.MustParsePrefix("10.0.0.0/8")},
  		{"ipv4_16", netip.MustParsePrefix("192.168.0.0/16")},
  		{"ipv4_24", netip.MustParsePrefix("192.168.1.0/24")},
  		{"ipv4_32", netip.MustParsePrefix("1.2.3.4/32")},
  		{"ipv6_0", netip.MustParsePrefix("::/0")},
  		{"ipv6_64", netip.MustParsePrefix("2001:db8::/64")},
  		{"ipv6_128", netip.MustParsePrefix("::1/128")},
  		{"ipv6_full", netip.MustParsePrefix("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128")},
  		{"ipv4_private", netip.MustParsePrefix("172.16.0.0/12")},
  		{"ipv6_link_local", netip.MustParsePrefix("fe80::/10")},
  	}
  	for _, testCase := range testCases {
  		t.Run(testCase.name, func(t *testing.T) {
  			t.Parallel()
  			input := testCase.input

  			var legacy, current bytes.Buffer
  			require.NoError(t, oldWritePrefix(&legacy, input))
  			require.NoError(t, newWritePrefix(&current, input))
  			require.Equal(t, legacy.Bytes(), current.Bytes(),
  				"mismatch for %q\nold: %x\nnew: %x", testCase.name, legacy.Bytes(), current.Bytes())

  			// There is no legacy prefix reader, so both encodings (new first,
  			// then legacy) go through readPrefix.
  			for _, encoded := range [][]byte{current.Bytes(), legacy.Bytes()} {
  				decoded, err := readPrefix(bufio.NewReader(bytes.NewReader(encoded)))
  				require.NoError(t, err)
  				require.Equal(t, input, decoded)
  			}
  		})
  	}
  }
  // TestIPSetCompat checks that writeIPSet emits the expected header (version
  // byte 1 plus a big-endian uint64 range count) and that its output round-trips
  // through both the legacy and the new IP-set readers.
  func TestIPSetCompat(t *testing.T) {
  	t.Parallel()
  	// Note: the old writeIPSet was buggy (varbin.Write with struct values wrote
  	// nothing), so there is no old-write comparison. This test verifies that the
  	// new implementation writes correct data and round-trips correctly.
  	cases := []struct {
  		name  string
  		input *netipx.IPSet
  	}{
  		{"single_ipv4", buildIPSet("1.2.3.4")},
  		{"ipv4_range", buildIPSet("192.168.0.0/16")},
  		{"multi_ipv4", buildIPSet("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")},
  		{"single_ipv6", buildIPSet("::1")},
  		{"ipv6_range", buildIPSet("2001:db8::/32")},
  		{"mixed", buildIPSet("10.0.0.0/8", "::1", "2001:db8::/32")},
  		{"large", buildLargeIPSet(100)},
  		{"adjacent_ranges", buildIPSet("192.168.0.0/24", "192.168.1.0/24", "192.168.2.0/24")},
  	}
  	for _, tc := range cases {
  		t.Run(tc.name, func(t *testing.T) {
  			t.Parallel()
  			// New write
  			var newBuf bytes.Buffer
  			err := writeIPSet(&newBuf, tc.input)
  			require.NoError(t, err)
  			encoded := newBuf.Bytes()
  			// Header: version byte (1) followed by a big-endian uint64 range count.
  			require.GreaterOrEqual(t, len(encoded), 9, "output too short")
  			require.Equal(t, byte(1), encoded[0], "version byte mismatch")
  			require.Equal(t, uint64(len(tc.input.Ranges())), binary.BigEndian.Uint64(encoded[1:9]),
  				"range count mismatch")
  			// New write -> old read (varbin.Read with a pre-allocated slice works correctly)
  			readBack, err := oldReadIPSet(bufio.NewReader(bytes.NewReader(encoded)))
  			require.NoError(t, err)
  			requireIPSetEqual(t, tc.input, readBack)
  			// New write -> new read
  			readBack2, err := readIPSet(bufio.NewReader(bytes.NewReader(encoded)))
  			require.NoError(t, err)
  			requireIPSetEqual(t, tc.input, readBack2)
  		})
  	}
  }
  346. // Helper functions
  // generateStrings returns count strings made of 'x'. The i-th string has
  // length i%50, so lengths cycle through 0..49.
  func generateStrings(count int) []string {
  	out := make([]string, 0, count)
  	for i := 0; i < count; i++ {
  		out = append(out, strings.Repeat("x", i%50))
  	}
  	return out
  }
  // generateUint8Slice returns count bytes counting up from 0 and wrapping
  // back to 0 after 255.
  func generateUint8Slice(count int) []uint8 {
  	out := make([]uint8, count)
  	for i := range out {
  		out[i] = uint8(i) // conversion truncates, equivalent to i % 256
  	}
  	return out
  }
  // generateUint16Slice returns count values where element i is i*257 modulo
  // 65536. The first 256 elements therefore spread evenly from 0 to 65535.
  func generateUint16Slice(count int) []uint16 {
  	out := make([]uint16, count)
  	var next uint16
  	for i := range out {
  		out[i] = next
  		next += 257 // uint16 arithmetic wraps exactly like uint16(i*257)
  	}
  	return out
  }
  // buildIPSet builds a test fixture from entries that are either CIDR prefixes
  // ("10.0.0.0/8") or bare addresses ("::1"). Each entry is tried as a prefix
  // first and as an address second. The function panics if an entry is neither,
  // or if the builder reports an error. Failing loudly is preferable to handing
  // a nil set to the tests.
  func buildIPSet(cidrs ...string) *netipx.IPSet {
  	var builder netipx.IPSetBuilder
  	for _, cidr := range cidrs {
  		if prefix, err := netip.ParsePrefix(cidr); err == nil {
  			builder.AddPrefix(prefix)
  			continue
  		}
  		addr, err := netip.ParseAddr(cidr)
  		if err != nil {
  			panic(err)
  		}
  		builder.Add(addr)
  	}
  	set, err := builder.IPSet()
  	if err != nil {
  		panic(err)
  	}
  	return set
  }
  // buildLargeIPSet builds a set of count /24 prefixes 10.x.y.0/24, where x and
  // y are the high and low bytes of the index. Indexes wrap at 65536, so larger
  // counts produce duplicate prefixes. It panics if the builder reports an error.
  func buildLargeIPSet(count int) *netipx.IPSet {
  	var builder netipx.IPSetBuilder
  	for i := 0; i < count; i++ {
  		prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{10, byte(i / 256), byte(i % 256), 0}), 24)
  		builder.AddPrefix(prefix)
  	}
  	set, err := builder.IPSet()
  	if err != nil {
  		panic(err)
  	}
  	return set
  }
  // requireSliceEqualNilAsEmpty asserts that expected and actual contain the same
  // elements in order. A nil slice and an empty slice are treated as equal, so a
  // nil input matches an empty decoded result and vice versa.
  func requireSliceEqualNilAsEmpty[E any](t *testing.T, expected, actual []E) {
  	t.Helper()
  	if len(expected) == 0 && len(actual) == 0 {
  		return
  	}
  	require.Equal(t, expected, actual)
  }

  // requireStringSliceEqual is requireSliceEqualNilAsEmpty for []string.
  func requireStringSliceEqual(t *testing.T, expected, actual []string) {
  	t.Helper()
  	requireSliceEqualNilAsEmpty(t, expected, actual)
  }

  // requireUint8SliceEqual is requireSliceEqualNilAsEmpty for []uint8.
  func requireUint8SliceEqual(t *testing.T, expected, actual []uint8) {
  	t.Helper()
  	requireSliceEqualNilAsEmpty(t, expected, actual)
  }

  // requireUint16SliceEqual is requireSliceEqualNilAsEmpty for []uint16.
  func requireUint16SliceEqual(t *testing.T, expected, actual []uint16) {
  	t.Helper()
  	requireSliceEqualNilAsEmpty(t, expected, actual)
  }
  // requireIPSetEqual asserts that both sets consist of the same ranges in the
  // same order. It checks the range count first, then each range's endpoints.
  func requireIPSetEqual(t *testing.T, expected, actual *netipx.IPSet) {
  	t.Helper()
  	want, got := expected.Ranges(), actual.Ranges()
  	require.Equal(t, len(want), len(got), "range count mismatch")
  	for i, w := range want {
  		g := got[i]
  		require.Equal(t, w.From(), g.From(), "range[%d].from mismatch", i)
  		require.Equal(t, w.To(), g.To(), "range[%d].to mismatch", i)
  	}
  }