conn_linux_test.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. // Copyright (c) Tailscale Inc & AUTHORS
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. package batching
  4. import (
  5. "encoding/binary"
  6. "net"
  7. "testing"
  8. "github.com/tailscale/wireguard-go/conn"
  9. "golang.org/x/net/ipv6"
  10. "tailscale.com/net/packet"
  11. )
  12. func setGSOSize(control *[]byte, gsoSize uint16) {
  13. *control = (*control)[:cap(*control)]
  14. binary.LittleEndian.PutUint16(*control, gsoSize)
  15. }
  16. func getGSOSize(control []byte) (int, error) {
  17. if len(control) < 2 {
  18. return 0, nil
  19. }
  20. return int(binary.LittleEndian.Uint16(control)), nil
  21. }
  22. func Test_linuxBatchingConn_splitCoalescedMessages(t *testing.T) {
  23. c := &linuxBatchingConn{
  24. setGSOSizeInControl: setGSOSize,
  25. getGSOSizeFromControl: getGSOSize,
  26. }
  27. newMsg := func(n, gso int) ipv6.Message {
  28. msg := ipv6.Message{
  29. Buffers: [][]byte{make([]byte, 1024)},
  30. N: n,
  31. OOB: make([]byte, 2),
  32. }
  33. binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
  34. if gso > 0 {
  35. msg.NN = 2
  36. }
  37. return msg
  38. }
  39. cases := []struct {
  40. name string
  41. msgs []ipv6.Message
  42. firstMsgAt int
  43. wantNumEval int
  44. wantMsgLens []int
  45. wantErr bool
  46. }{
  47. {
  48. name: "second last split last empty",
  49. msgs: []ipv6.Message{
  50. newMsg(0, 0),
  51. newMsg(0, 0),
  52. newMsg(3, 1),
  53. newMsg(0, 0),
  54. },
  55. firstMsgAt: 2,
  56. wantNumEval: 3,
  57. wantMsgLens: []int{1, 1, 1, 0},
  58. wantErr: false,
  59. },
  60. {
  61. name: "second last no split last empty",
  62. msgs: []ipv6.Message{
  63. newMsg(0, 0),
  64. newMsg(0, 0),
  65. newMsg(1, 0),
  66. newMsg(0, 0),
  67. },
  68. firstMsgAt: 2,
  69. wantNumEval: 1,
  70. wantMsgLens: []int{1, 0, 0, 0},
  71. wantErr: false,
  72. },
  73. {
  74. name: "second last no split last no split",
  75. msgs: []ipv6.Message{
  76. newMsg(0, 0),
  77. newMsg(0, 0),
  78. newMsg(1, 0),
  79. newMsg(1, 0),
  80. },
  81. firstMsgAt: 2,
  82. wantNumEval: 2,
  83. wantMsgLens: []int{1, 1, 0, 0},
  84. wantErr: false,
  85. },
  86. {
  87. name: "second last no split last split",
  88. msgs: []ipv6.Message{
  89. newMsg(0, 0),
  90. newMsg(0, 0),
  91. newMsg(1, 0),
  92. newMsg(3, 1),
  93. },
  94. firstMsgAt: 2,
  95. wantNumEval: 4,
  96. wantMsgLens: []int{1, 1, 1, 1},
  97. wantErr: false,
  98. },
  99. {
  100. name: "second last split last split",
  101. msgs: []ipv6.Message{
  102. newMsg(0, 0),
  103. newMsg(0, 0),
  104. newMsg(2, 1),
  105. newMsg(2, 1),
  106. },
  107. firstMsgAt: 2,
  108. wantNumEval: 4,
  109. wantMsgLens: []int{1, 1, 1, 1},
  110. wantErr: false,
  111. },
  112. {
  113. name: "second last no split last split overflow",
  114. msgs: []ipv6.Message{
  115. newMsg(0, 0),
  116. newMsg(0, 0),
  117. newMsg(1, 0),
  118. newMsg(4, 1),
  119. },
  120. firstMsgAt: 2,
  121. wantNumEval: 4,
  122. wantMsgLens: []int{1, 1, 1, 1},
  123. wantErr: true,
  124. },
  125. }
  126. for _, tt := range cases {
  127. t.Run(tt.name, func(t *testing.T) {
  128. got, err := c.splitCoalescedMessages(tt.msgs, 2)
  129. if err != nil && !tt.wantErr {
  130. t.Fatalf("err: %v", err)
  131. }
  132. if got != tt.wantNumEval {
  133. t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
  134. }
  135. for i, msg := range tt.msgs {
  136. if msg.N != tt.wantMsgLens[i] {
  137. t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
  138. }
  139. }
  140. })
  141. }
  142. }
  143. func Test_linuxBatchingConn_coalesceMessages(t *testing.T) {
  144. c := &linuxBatchingConn{
  145. setGSOSizeInControl: setGSOSize,
  146. getGSOSizeFromControl: getGSOSize,
  147. }
  148. withGeneveSpace := func(len, cap int) []byte {
  149. return make([]byte, len+packet.GeneveFixedHeaderLength, cap+packet.GeneveFixedHeaderLength)
  150. }
  151. geneve := packet.GeneveHeader{
  152. Protocol: packet.GeneveProtocolWireGuard,
  153. }
  154. geneve.VNI.Set(1)
  155. cases := []struct {
  156. name string
  157. buffs [][]byte
  158. geneve packet.GeneveHeader
  159. wantLens []int
  160. wantGSO []int
  161. }{
  162. {
  163. name: "one message no coalesce",
  164. buffs: [][]byte{
  165. withGeneveSpace(1, 1),
  166. },
  167. wantLens: []int{1},
  168. wantGSO: []int{0},
  169. },
  170. {
  171. name: "one message no coalesce vni.isSet",
  172. buffs: [][]byte{
  173. withGeneveSpace(1, 1),
  174. },
  175. geneve: geneve,
  176. wantLens: []int{1 + packet.GeneveFixedHeaderLength},
  177. wantGSO: []int{0},
  178. },
  179. {
  180. name: "two messages equal len coalesce",
  181. buffs: [][]byte{
  182. withGeneveSpace(1, 2),
  183. withGeneveSpace(1, 1),
  184. },
  185. wantLens: []int{2},
  186. wantGSO: []int{1},
  187. },
  188. {
  189. name: "two messages equal len coalesce vni.isSet",
  190. buffs: [][]byte{
  191. withGeneveSpace(1, 2+packet.GeneveFixedHeaderLength),
  192. withGeneveSpace(1, 1),
  193. },
  194. geneve: geneve,
  195. wantLens: []int{2 + (2 * packet.GeneveFixedHeaderLength)},
  196. wantGSO: []int{1 + packet.GeneveFixedHeaderLength},
  197. },
  198. {
  199. name: "two messages unequal len coalesce",
  200. buffs: [][]byte{
  201. withGeneveSpace(2, 3),
  202. withGeneveSpace(1, 1),
  203. },
  204. wantLens: []int{3},
  205. wantGSO: []int{2},
  206. },
  207. {
  208. name: "two messages unequal len coalesce vni.isSet",
  209. buffs: [][]byte{
  210. withGeneveSpace(2, 3+packet.GeneveFixedHeaderLength),
  211. withGeneveSpace(1, 1),
  212. },
  213. geneve: geneve,
  214. wantLens: []int{3 + (2 * packet.GeneveFixedHeaderLength)},
  215. wantGSO: []int{2 + packet.GeneveFixedHeaderLength},
  216. },
  217. {
  218. name: "three messages second unequal len coalesce",
  219. buffs: [][]byte{
  220. withGeneveSpace(2, 3),
  221. withGeneveSpace(1, 1),
  222. withGeneveSpace(2, 2),
  223. },
  224. wantLens: []int{3, 2},
  225. wantGSO: []int{2, 0},
  226. },
  227. {
  228. name: "three messages second unequal len coalesce vni.isSet",
  229. buffs: [][]byte{
  230. withGeneveSpace(2, 3+(2*packet.GeneveFixedHeaderLength)),
  231. withGeneveSpace(1, 1),
  232. withGeneveSpace(2, 2),
  233. },
  234. geneve: geneve,
  235. wantLens: []int{3 + (2 * packet.GeneveFixedHeaderLength), 2 + packet.GeneveFixedHeaderLength},
  236. wantGSO: []int{2 + packet.GeneveFixedHeaderLength, 0},
  237. },
  238. {
  239. name: "three messages limited cap coalesce",
  240. buffs: [][]byte{
  241. withGeneveSpace(2, 4),
  242. withGeneveSpace(2, 2),
  243. withGeneveSpace(2, 2),
  244. },
  245. wantLens: []int{4, 2},
  246. wantGSO: []int{2, 0},
  247. },
  248. {
  249. name: "three messages limited cap coalesce vni.isSet",
  250. buffs: [][]byte{
  251. withGeneveSpace(2, 4+packet.GeneveFixedHeaderLength),
  252. withGeneveSpace(2, 2),
  253. withGeneveSpace(2, 2),
  254. },
  255. geneve: geneve,
  256. wantLens: []int{4 + (2 * packet.GeneveFixedHeaderLength), 2 + packet.GeneveFixedHeaderLength},
  257. wantGSO: []int{2 + packet.GeneveFixedHeaderLength, 0},
  258. },
  259. }
  260. for _, tt := range cases {
  261. t.Run(tt.name, func(t *testing.T) {
  262. addr := &net.UDPAddr{
  263. IP: net.ParseIP("127.0.0.1"),
  264. Port: 1,
  265. }
  266. msgs := make([]ipv6.Message, len(tt.buffs))
  267. for i := range msgs {
  268. msgs[i].Buffers = make([][]byte, 1)
  269. msgs[i].OOB = make([]byte, 0, 2)
  270. }
  271. got := c.coalesceMessages(addr, tt.geneve, tt.buffs, msgs, packet.GeneveFixedHeaderLength)
  272. if got != len(tt.wantLens) {
  273. t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
  274. }
  275. for i := range got {
  276. if msgs[i].Addr != addr {
  277. t.Errorf("msgs[%d].Addr != passed addr", i)
  278. }
  279. gotLen := len(msgs[i].Buffers[0])
  280. if gotLen != tt.wantLens[i] {
  281. t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
  282. }
  283. gotGSO, err := getGSOSize(msgs[i].OOB)
  284. if err != nil {
  285. t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
  286. }
  287. if gotGSO != tt.wantGSO[i] {
  288. t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
  289. }
  290. }
  291. })
  292. }
  293. }
  294. func TestMinReadBatchMsgsLen(t *testing.T) {
  295. // So long as magicsock uses [Conn], and [wireguard-go/conn.Bind] API is
  296. // shaped for wireguard-go to control packet memory, these values should be
  297. // aligned.
  298. if IdealBatchSize != conn.IdealBatchSize {
  299. t.Fatalf("IdealBatchSize: %d != conn.IdealBatchSize(): %d", IdealBatchSize, conn.IdealBatchSize)
  300. }
  301. }