| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494 |
- package srs
import (
	"bufio"
	"bytes"
	"encoding/binary"
	"fmt"
	"net/netip"
	"strings"
	"testing"
	"unsafe"

	M "github.com/sagernet/sing/common/metadata"
	"github.com/sagernet/sing/common/varbin"
	"github.com/stretchr/testify/require"
	"go4.org/netipx"
)
- // Old implementations using varbin reflection-based serialization
- func oldWriteStringSlice(writer varbin.Writer, value []string) error {
- //nolint:staticcheck
- return varbin.Write(writer, binary.BigEndian, value)
- }
- func oldReadStringSlice(reader varbin.Reader) ([]string, error) {
- //nolint:staticcheck
- return varbin.ReadValue[[]string](reader, binary.BigEndian)
- }
- func oldWriteUint8Slice[E ~uint8](writer varbin.Writer, value []E) error {
- //nolint:staticcheck
- return varbin.Write(writer, binary.BigEndian, value)
- }
- func oldReadUint8Slice[E ~uint8](reader varbin.Reader) ([]E, error) {
- //nolint:staticcheck
- return varbin.ReadValue[[]E](reader, binary.BigEndian)
- }
- func oldWriteUint16Slice(writer varbin.Writer, value []uint16) error {
- //nolint:staticcheck
- return varbin.Write(writer, binary.BigEndian, value)
- }
- func oldReadUint16Slice(reader varbin.Reader) ([]uint16, error) {
- //nolint:staticcheck
- return varbin.ReadValue[[]uint16](reader, binary.BigEndian)
- }
- func oldWritePrefix(writer varbin.Writer, prefix netip.Prefix) error {
- //nolint:staticcheck
- err := varbin.Write(writer, binary.BigEndian, prefix.Addr().AsSlice())
- if err != nil {
- return err
- }
- return binary.Write(writer, binary.BigEndian, uint8(prefix.Bits()))
- }
// oldIPRangeData holds one legacy-encoded IP range: the raw bytes of its
// first (From) and last (To) address. It is filled by varbin's
// reflection-based Read in oldReadIPSet. Keep From declared before To; this
// is presumably the wire order.
type oldIPRangeData struct {
	From, To []byte
}
- // Note: The old writeIPSet had a bug where varbin.Write(writer, binary.BigEndian, data)
- // with a struct VALUE (not pointer) silently wrote nothing because field.CanSet() returned false.
- // This caused IP range data to be missing from the output.
- // The new implementation correctly writes all range data.
- //
- // The old readIPSet used varbin.Read with a pre-allocated slice, which worked because
- // slice elements are addressable and CanSet() returns true for them.
- //
- // For compatibility testing, we verify:
- // 1. New write produces correct output with range data
- // 2. New read can parse the new format correctly
- // 3. Round-trip works correctly
- func oldReadIPSet(reader varbin.Reader) (*netipx.IPSet, error) {
- version, err := reader.ReadByte()
- if err != nil {
- return nil, err
- }
- if version != 1 {
- return nil, err
- }
- var length uint64
- err = binary.Read(reader, binary.BigEndian, &length)
- if err != nil {
- return nil, err
- }
- ranges := make([]oldIPRangeData, length)
- //nolint:staticcheck
- err = varbin.Read(reader, binary.BigEndian, &ranges)
- if err != nil {
- return nil, err
- }
- mySet := &myIPSet{
- rr: make([]myIPRange, len(ranges)),
- }
- for i, rangeData := range ranges {
- mySet.rr[i].from = M.AddrFromIP(rangeData.From)
- mySet.rr[i].to = M.AddrFromIP(rangeData.To)
- }
- return (*netipx.IPSet)(unsafe.Pointer(mySet)), nil
- }
// New hand-written write functions. They omit the leading itemType byte so
// their output can be compared byte-for-byte with the old varbin encoding.
- func newWriteStringSlice(writer varbin.Writer, value []string) error {
- _, err := varbin.WriteUvarint(writer, uint64(len(value)))
- if err != nil {
- return err
- }
- for _, s := range value {
- _, err = varbin.WriteUvarint(writer, uint64(len(s)))
- if err != nil {
- return err
- }
- _, err = writer.Write([]byte(s))
- if err != nil {
- return err
- }
- }
- return nil
- }
- func newWriteUint8Slice[E ~uint8](writer varbin.Writer, value []E) error {
- _, err := varbin.WriteUvarint(writer, uint64(len(value)))
- if err != nil {
- return err
- }
- _, err = writer.Write(*(*[]byte)(unsafe.Pointer(&value)))
- return err
- }
- func newWriteUint16Slice(writer varbin.Writer, value []uint16) error {
- _, err := varbin.WriteUvarint(writer, uint64(len(value)))
- if err != nil {
- return err
- }
- return binary.Write(writer, binary.BigEndian, value)
- }
- func newWritePrefix(writer varbin.Writer, prefix netip.Prefix) error {
- addrSlice := prefix.Addr().AsSlice()
- _, err := varbin.WriteUvarint(writer, uint64(len(addrSlice)))
- if err != nil {
- return err
- }
- _, err = writer.Write(addrSlice)
- if err != nil {
- return err
- }
- return writer.WriteByte(uint8(prefix.Bits()))
- }
- // Tests
- func TestStringSliceCompat(t *testing.T) {
- t.Parallel()
- cases := []struct {
- name string
- input []string
- }{
- {"nil", nil},
- {"empty", []string{}},
- {"single_empty", []string{""}},
- {"single", []string{"test"}},
- {"multi", []string{"a", "b", "c"}},
- {"with_empty", []string{"a", "", "c"}},
- {"utf8", []string{"测试", "テスト", "тест"}},
- {"long_string", []string{strings.Repeat("x", 128)}},
- {"many_elements", generateStrings(128)},
- {"many_elements_256", generateStrings(256)},
- {"127_byte_string", []string{strings.Repeat("x", 127)}},
- {"128_byte_string", []string{strings.Repeat("x", 128)}},
- {"mixed_lengths", []string{"a", strings.Repeat("b", 100), "", strings.Repeat("c", 200)}},
- }
- for _, tc := range cases {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
- // Old write
- var oldBuf bytes.Buffer
- err := oldWriteStringSlice(&oldBuf, tc.input)
- require.NoError(t, err)
- // New write
- var newBuf bytes.Buffer
- err = newWriteStringSlice(&newBuf, tc.input)
- require.NoError(t, err)
- // Bytes must match
- require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(),
- "mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes())
- // New write -> old read
- readBack, err := oldReadStringSlice(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
- require.NoError(t, err)
- requireStringSliceEqual(t, tc.input, readBack)
- // Old write -> new read
- readBack2, err := readRuleItemString(bufio.NewReader(bytes.NewReader(oldBuf.Bytes())))
- require.NoError(t, err)
- requireStringSliceEqual(t, tc.input, readBack2)
- })
- }
- }
- func TestUint8SliceCompat(t *testing.T) {
- t.Parallel()
- cases := []struct {
- name string
- input []uint8
- }{
- {"nil", nil},
- {"empty", []uint8{}},
- {"single_zero", []uint8{0}},
- {"single_max", []uint8{255}},
- {"multi", []uint8{0, 1, 127, 128, 255}},
- {"boundary", []uint8{0x00, 0x7f, 0x80, 0xff}},
- {"sequential", generateUint8Slice(256)},
- {"127_elements", generateUint8Slice(127)},
- {"128_elements", generateUint8Slice(128)},
- }
- for _, tc := range cases {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
- // Old write
- var oldBuf bytes.Buffer
- err := oldWriteUint8Slice(&oldBuf, tc.input)
- require.NoError(t, err)
- // New write
- var newBuf bytes.Buffer
- err = newWriteUint8Slice(&newBuf, tc.input)
- require.NoError(t, err)
- // Bytes must match
- require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(),
- "mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes())
- // New write -> old read
- readBack, err := oldReadUint8Slice[uint8](bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
- require.NoError(t, err)
- requireUint8SliceEqual(t, tc.input, readBack)
- // Old write -> new read
- readBack2, err := readRuleItemUint8[uint8](bufio.NewReader(bytes.NewReader(oldBuf.Bytes())))
- require.NoError(t, err)
- requireUint8SliceEqual(t, tc.input, readBack2)
- })
- }
- }
- func TestUint16SliceCompat(t *testing.T) {
- t.Parallel()
- cases := []struct {
- name string
- input []uint16
- }{
- {"nil", nil},
- {"empty", []uint16{}},
- {"single_zero", []uint16{0}},
- {"single_max", []uint16{65535}},
- {"multi", []uint16{0, 255, 256, 32767, 32768, 65535}},
- {"ports", []uint16{80, 443, 8080, 8443}},
- {"127_elements", generateUint16Slice(127)},
- {"128_elements", generateUint16Slice(128)},
- {"256_elements", generateUint16Slice(256)},
- }
- for _, tc := range cases {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
- // Old write
- var oldBuf bytes.Buffer
- err := oldWriteUint16Slice(&oldBuf, tc.input)
- require.NoError(t, err)
- // New write
- var newBuf bytes.Buffer
- err = newWriteUint16Slice(&newBuf, tc.input)
- require.NoError(t, err)
- // Bytes must match
- require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(),
- "mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes())
- // New write -> old read
- readBack, err := oldReadUint16Slice(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
- require.NoError(t, err)
- requireUint16SliceEqual(t, tc.input, readBack)
- // Old write -> new read
- readBack2, err := readRuleItemUint16(bufio.NewReader(bytes.NewReader(oldBuf.Bytes())))
- require.NoError(t, err)
- requireUint16SliceEqual(t, tc.input, readBack2)
- })
- }
- }
- func TestPrefixCompat(t *testing.T) {
- t.Parallel()
- cases := []struct {
- name string
- input netip.Prefix
- }{
- {"ipv4_0", netip.MustParsePrefix("0.0.0.0/0")},
- {"ipv4_8", netip.MustParsePrefix("10.0.0.0/8")},
- {"ipv4_16", netip.MustParsePrefix("192.168.0.0/16")},
- {"ipv4_24", netip.MustParsePrefix("192.168.1.0/24")},
- {"ipv4_32", netip.MustParsePrefix("1.2.3.4/32")},
- {"ipv6_0", netip.MustParsePrefix("::/0")},
- {"ipv6_64", netip.MustParsePrefix("2001:db8::/64")},
- {"ipv6_128", netip.MustParsePrefix("::1/128")},
- {"ipv6_full", netip.MustParsePrefix("2001:0db8:85a3:0000:0000:8a2e:0370:7334/128")},
- {"ipv4_private", netip.MustParsePrefix("172.16.0.0/12")},
- {"ipv6_link_local", netip.MustParsePrefix("fe80::/10")},
- }
- for _, tc := range cases {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
- // Old write
- var oldBuf bytes.Buffer
- err := oldWritePrefix(&oldBuf, tc.input)
- require.NoError(t, err)
- // New write
- var newBuf bytes.Buffer
- err = newWritePrefix(&newBuf, tc.input)
- require.NoError(t, err)
- // Bytes must match
- require.Equal(t, oldBuf.Bytes(), newBuf.Bytes(),
- "mismatch for %q\nold: %x\nnew: %x", tc.name, oldBuf.Bytes(), newBuf.Bytes())
- // New write -> new read (no old read for prefix)
- readBack, err := readPrefix(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
- require.NoError(t, err)
- require.Equal(t, tc.input, readBack)
- // Old write -> new read
- readBack2, err := readPrefix(bufio.NewReader(bytes.NewReader(oldBuf.Bytes())))
- require.NoError(t, err)
- require.Equal(t, tc.input, readBack2)
- })
- }
- }
- func TestIPSetCompat(t *testing.T) {
- t.Parallel()
- // Note: The old writeIPSet was buggy (varbin.Write with struct values wrote nothing).
- // This test verifies the new implementation writes correct data and round-trips correctly.
- cases := []struct {
- name string
- input *netipx.IPSet
- }{
- {"single_ipv4", buildIPSet("1.2.3.4")},
- {"ipv4_range", buildIPSet("192.168.0.0/16")},
- {"multi_ipv4", buildIPSet("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")},
- {"single_ipv6", buildIPSet("::1")},
- {"ipv6_range", buildIPSet("2001:db8::/32")},
- {"mixed", buildIPSet("10.0.0.0/8", "::1", "2001:db8::/32")},
- {"large", buildLargeIPSet(100)},
- {"adjacent_ranges", buildIPSet("192.168.0.0/24", "192.168.1.0/24", "192.168.2.0/24")},
- }
- for _, tc := range cases {
- t.Run(tc.name, func(t *testing.T) {
- t.Parallel()
- // New write
- var newBuf bytes.Buffer
- err := writeIPSet(&newBuf, tc.input)
- require.NoError(t, err)
- // Verify format starts with version byte (1) + uint64 count
- require.True(t, len(newBuf.Bytes()) >= 9, "output too short")
- require.Equal(t, byte(1), newBuf.Bytes()[0], "version byte mismatch")
- // New write -> old read (varbin.Read with pre-allocated slice works correctly)
- readBack, err := oldReadIPSet(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
- require.NoError(t, err)
- requireIPSetEqual(t, tc.input, readBack)
- // New write -> new read
- readBack2, err := readIPSet(bufio.NewReader(bytes.NewReader(newBuf.Bytes())))
- require.NoError(t, err)
- requireIPSetEqual(t, tc.input, readBack2)
- })
- }
- }
- // Helper functions
// generateStrings returns count strings where element i is i%50 copies of
// "x", so the lengths cycle through 0..49.
func generateStrings(count int) []string {
	out := make([]string, 0, count)
	for i := 0; i < count; i++ {
		out = append(out, strings.Repeat("x", i%50))
	}
	return out
}
// generateUint8Slice returns count bytes counting up from 0 and wrapping
// back to 0 after 255.
func generateUint8Slice(count int) []uint8 {
	out := make([]uint8, count)
	for i := range out {
		// Converting a non-negative int to byte keeps the low 8 bits, i.e. i % 256.
		out[i] = byte(i)
	}
	return out
}
// generateUint16Slice returns count values stepping by 257 from 0
// (0, 257, 514, ...), wrapping modulo 65536. The first 256 values span the
// whole uint16 range, ending at 65535.
func generateUint16Slice(count int) []uint16 {
	out := make([]uint16, count)
	var next uint16
	for i := range out {
		out[i] = next
		next += 257
	}
	return out
}
- func buildIPSet(cidrs ...string) *netipx.IPSet {
- var builder netipx.IPSetBuilder
- for _, cidr := range cidrs {
- prefix, err := netip.ParsePrefix(cidr)
- if err != nil {
- addr, err := netip.ParseAddr(cidr)
- if err != nil {
- panic(err)
- }
- builder.Add(addr)
- } else {
- builder.AddPrefix(prefix)
- }
- }
- set, _ := builder.IPSet()
- return set
- }
- func buildLargeIPSet(count int) *netipx.IPSet {
- var builder netipx.IPSetBuilder
- for i := 0; i < count; i++ {
- prefix := netip.PrefixFrom(netip.AddrFrom4([4]byte{10, byte(i / 256), byte(i % 256), 0}), 24)
- builder.AddPrefix(prefix)
- }
- set, _ := builder.IPSet()
- return set
- }
- func requireStringSliceEqual(t *testing.T, expected, actual []string) {
- t.Helper()
- if len(expected) == 0 && len(actual) == 0 {
- return
- }
- require.Equal(t, expected, actual)
- }
- func requireUint8SliceEqual(t *testing.T, expected, actual []uint8) {
- t.Helper()
- if len(expected) == 0 && len(actual) == 0 {
- return
- }
- require.Equal(t, expected, actual)
- }
- func requireUint16SliceEqual(t *testing.T, expected, actual []uint16) {
- t.Helper()
- if len(expected) == 0 && len(actual) == 0 {
- return
- }
- require.Equal(t, expected, actual)
- }
- func requireIPSetEqual(t *testing.T, expected, actual *netipx.IPSet) {
- t.Helper()
- expectedRanges := expected.Ranges()
- actualRanges := actual.Ranges()
- require.Equal(t, len(expectedRanges), len(actualRanges), "range count mismatch")
- for i := range expectedRanges {
- require.Equal(t, expectedRanges[i].From(), actualRanges[i].From(), "range[%d].from mismatch", i)
- require.Equal(t, expectedRanges[i].To(), actualRanges[i].To(), "range[%d].to mismatch", i)
- }
- }
|