netstack_test.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. // Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package netstack
  5. import (
  6. "runtime"
  7. "testing"
  8. "inet.af/netaddr"
  9. "tailscale.com/net/packet"
  10. "tailscale.com/net/tsdial"
  11. "tailscale.com/net/tstun"
  12. "tailscale.com/wgengine"
  13. "tailscale.com/wgengine/filter"
  14. )
  15. // TestInjectInboundLeak tests that injectInbound doesn't leak memory.
  16. // See https://github.com/tailscale/tailscale/issues/3762
  17. func TestInjectInboundLeak(t *testing.T) {
  18. tunDev := tstun.NewFake()
  19. dialer := new(tsdial.Dialer)
  20. logf := func(format string, args ...interface{}) {
  21. if !t.Failed() {
  22. t.Logf(format, args...)
  23. }
  24. }
  25. eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
  26. Tun: tunDev,
  27. Dialer: dialer,
  28. })
  29. if err != nil {
  30. t.Fatal(err)
  31. }
  32. defer eng.Close()
  33. ig, ok := eng.(wgengine.InternalsGetter)
  34. if !ok {
  35. t.Fatal("not an InternalsGetter")
  36. }
  37. tunWrap, magicSock, ok := ig.GetInternals()
  38. if !ok {
  39. t.Fatal("failed to get internals")
  40. }
  41. ns, err := Create(logf, tunWrap, eng, magicSock, dialer)
  42. if err != nil {
  43. t.Fatal(err)
  44. }
  45. defer ns.Close()
  46. ns.ProcessLocalIPs = true
  47. if err := ns.Start(); err != nil {
  48. t.Fatalf("Start: %v", err)
  49. }
  50. ns.atomicIsLocalIPFunc.Store(func(netaddr.IP) bool { return true })
  51. pkt := &packet.Parsed{}
  52. const N = 10_000
  53. ms0 := getMemStats()
  54. for i := 0; i < N; i++ {
  55. outcome := ns.injectInbound(pkt, tunWrap)
  56. if outcome != filter.DropSilently {
  57. t.Fatalf("got outcome %v; want DropSilently", outcome)
  58. }
  59. }
  60. ms1 := getMemStats()
  61. if grew := int64(ms1.HeapObjects) - int64(ms0.HeapObjects); grew >= N {
  62. t.Fatalf("grew by %v (which is too much and >= the %v packets we sent)", grew, N)
  63. }
  64. }
  // getMemStats forces a garbage collection and then returns a snapshot of the
// runtime's memory statistics. Collecting first makes comparisons between two
// snapshots less sensitive to garbage that has not yet been reclaimed.
func getMemStats() runtime.MemStats {
	var snapshot runtime.MemStats
	runtime.GC()
	runtime.ReadMemStats(&snapshot)
	return snapshot
}