filter_test.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. package windivert
  2. import (
  3. "encoding/binary"
  4. "net/netip"
  5. "testing"
  6. )
  7. func TestRejectFilter(t *testing.T) {
  8. t.Parallel()
  9. bin, flags, err := reject().encode()
  10. if err != nil {
  11. t.Fatal(err)
  12. }
  13. if len(bin) != filterInstBytes {
  14. t.Fatalf("reject filter len: got %d, want %d", len(bin), filterInstBytes)
  15. }
  16. if flags != 0 {
  17. t.Fatalf("reject filter flags: got %x, want 0", flags)
  18. }
  19. // word0: field=ZERO=0, test=EQ=0, success=REJECT=0x7FFF
  20. word0 := binary.LittleEndian.Uint32(bin[0:4])
  21. if word0 != uint32(resultReject)<<16 {
  22. t.Fatalf("reject word0 = %08x", word0)
  23. }
  24. // word1: failure=REJECT
  25. word1 := binary.LittleEndian.Uint32(bin[4:8])
  26. if word1 != uint32(resultReject) {
  27. t.Fatalf("reject word1 = %08x", word1)
  28. }
  29. }
  30. func TestOutboundTCPFilterIPv4(t *testing.T) {
  31. t.Parallel()
  32. src := netip.MustParseAddrPort("10.1.2.3:54321")
  33. dst := netip.MustParseAddrPort("1.2.3.4:443")
  34. f, err := OutboundTCP(src, dst)
  35. if err != nil {
  36. t.Fatal(err)
  37. }
  38. bin, flags, err := f.encode()
  39. if err != nil {
  40. t.Fatal(err)
  41. }
  42. if want := filterFlagOutbound | filterFlagIP; flags != want {
  43. t.Fatalf("flags: got %x, want %x", flags, want)
  44. }
  45. // 7 instructions: OUTBOUND, IP, TCP, IP_SRCADDR, IP_DSTADDR, TCP_SRCPORT, TCP_DSTPORT
  46. const wantInsts = 7
  47. if len(bin) != wantInsts*filterInstBytes {
  48. t.Fatalf("instruction count: got %d, want %d", len(bin)/filterInstBytes, wantInsts)
  49. }
  50. // Inst 0: OUTBOUND == 1, success=1, failure=REJECT
  51. checkInst(t, bin[0*filterInstBytes:], 0, fieldOutbound, testEQ, 1, resultReject, 1)
  52. // Inst 1: IP == 1, success=2
  53. checkInst(t, bin[1*filterInstBytes:], 1, fieldIP, testEQ, 2, resultReject, 1)
  54. // Inst 2: TCP == 1, success=3
  55. checkInst(t, bin[2*filterInstBytes:], 2, fieldTCP, testEQ, 3, resultReject, 1)
  56. // Inst 3: IP_SRCADDR == 10.1.2.3 (host-order uint32 = 0x0A010203, arg[1]=0x0000FFFF marker)
  57. checkInst(t, bin[3*filterInstBytes:], 3, fieldIPSrcAddr, testEQ, 4, resultReject, 0x0A010203)
  58. checkArg1(t, bin[3*filterInstBytes:], 3, 0x0000FFFF)
  59. // Inst 4: IP_DSTADDR == 1.2.3.4
  60. checkInst(t, bin[4*filterInstBytes:], 4, fieldIPDstAddr, testEQ, 5, resultReject, 0x01020304)
  61. checkArg1(t, bin[4*filterInstBytes:], 4, 0x0000FFFF)
  62. // Inst 5: TCP_SRCPORT == 54321
  63. checkInst(t, bin[5*filterInstBytes:], 5, fieldTCPSrcPort, testEQ, 6, resultReject, 54321)
  64. // Last inst 6: TCP_DSTPORT == 443, success=ACCEPT
  65. checkInst(t, bin[6*filterInstBytes:], 6, fieldTCPDstPort, testEQ, resultAccept, resultReject, 443)
  66. }
  67. func TestOutboundTCPFilterIPv6(t *testing.T) {
  68. t.Parallel()
  69. src := netip.MustParseAddrPort("[2001:db8::1]:54321")
  70. dst := netip.MustParseAddrPort("[2001:db8::2]:443")
  71. f, err := OutboundTCP(src, dst)
  72. if err != nil {
  73. t.Fatal(err)
  74. }
  75. bin, flags, err := f.encode()
  76. if err != nil {
  77. t.Fatal(err)
  78. }
  79. if want := filterFlagOutbound | filterFlagIPv6; flags != want {
  80. t.Fatalf("flags: got %x, want %x", flags, want)
  81. }
  82. // Inst 3: IPv6_SRCADDR. The driver stores the address in reversed
  83. // word order: arg[0]=low (bytes 12..15)=1, arg[3]=high (bytes 0..3)=0x20010db8.
  84. off := 3 * filterInstBytes
  85. a0 := binary.LittleEndian.Uint32(bin[off+8:])
  86. a1 := binary.LittleEndian.Uint32(bin[off+12:])
  87. a2 := binary.LittleEndian.Uint32(bin[off+16:])
  88. a3 := binary.LittleEndian.Uint32(bin[off+20:])
  89. if a0 != 1 || a1 != 0 || a2 != 0 || a3 != 0x20010db8 {
  90. t.Fatalf("ipv6 src arg=[%08x %08x %08x %08x], want [1 0 0 0x20010db8]", a0, a1, a2, a3)
  91. }
  92. }
  93. func TestOutboundTCPFilterMixedFamily(t *testing.T) {
  94. t.Parallel()
  95. src := netip.MustParseAddrPort("10.0.0.1:1234")
  96. dst := netip.MustParseAddrPort("[2001:db8::1]:443")
  97. if _, err := OutboundTCP(src, dst); err == nil {
  98. t.Fatal("expected error for mixed families")
  99. }
  100. }
  101. func checkArg1(t *testing.T, raw []byte, idx int, arg1 uint32) {
  102. t.Helper()
  103. got := binary.LittleEndian.Uint32(raw[12:16])
  104. if got != arg1 {
  105. t.Errorf("inst %d arg[1]: got %08x, want %08x", idx, got, arg1)
  106. }
  107. }
  108. func checkInst(t *testing.T, raw []byte, idx int, field uint16, test uint8, success, failure uint16, arg0 uint32) {
  109. t.Helper()
  110. word0 := binary.LittleEndian.Uint32(raw[0:4])
  111. word1 := binary.LittleEndian.Uint32(raw[4:8])
  112. a0 := binary.LittleEndian.Uint32(raw[8:12])
  113. gotField := uint16(word0 & 0x7FF)
  114. gotTest := uint8((word0 >> 11) & 0x1F)
  115. gotSuccess := uint16(word0 >> 16)
  116. gotFailure := uint16(word1 & 0xFFFF)
  117. if gotField != field {
  118. t.Errorf("inst %d field: got %d, want %d", idx, gotField, field)
  119. }
  120. if gotTest != test {
  121. t.Errorf("inst %d test: got %d, want %d", idx, gotTest, test)
  122. }
  123. if gotSuccess != success {
  124. t.Errorf("inst %d success: got %d, want %d", idx, gotSuccess, success)
  125. }
  126. if gotFailure != failure {
  127. t.Errorf("inst %d failure: got %d, want %d", idx, gotFailure, failure)
  128. }
  129. if a0 != arg0 {
  130. t.Errorf("inst %d arg[0]: got %08x, want %08x", idx, a0, arg0)
  131. }
  132. }