tun_device_unix.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. //go:build with_gvisor && !windows
  2. package tailscale
  3. import (
  4. "encoding/hex"
  5. "errors"
  6. "io"
  7. "os"
  8. "sync"
  9. "sync/atomic"
  10. singTun "github.com/sagernet/sing-tun"
  11. "github.com/sagernet/sing/common"
  12. "github.com/sagernet/sing/common/logger"
  13. wgTun "github.com/sagernet/wireguard-go/tun"
  14. )
  15. type tunDeviceAdapter struct {
  16. tun singTun.Tun
  17. linuxTUN singTun.LinuxTUN
  18. events chan wgTun.Event
  19. mtu int
  20. logger logger.ContextLogger
  21. debugTun bool
  22. readCount atomic.Uint32
  23. writeCount atomic.Uint32
  24. closeOnce sync.Once
  25. }
  26. func newTunDeviceAdapter(tun singTun.Tun, mtu int, logger logger.ContextLogger) (wgTun.Device, error) {
  27. if tun == nil {
  28. return nil, os.ErrInvalid
  29. }
  30. if mtu == 0 {
  31. mtu = 1500
  32. }
  33. adapter := &tunDeviceAdapter{
  34. tun: tun,
  35. events: make(chan wgTun.Event, 1),
  36. mtu: mtu,
  37. logger: logger,
  38. debugTun: os.Getenv("SINGBOX_TS_TUN_DEBUG") != "",
  39. }
  40. if linuxTUN, ok := tun.(singTun.LinuxTUN); ok {
  41. adapter.linuxTUN = linuxTUN
  42. }
  43. adapter.events <- wgTun.EventUp
  44. return adapter, nil
  45. }
  46. func (a *tunDeviceAdapter) File() *os.File {
  47. return nil
  48. }
  49. func (a *tunDeviceAdapter) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
  50. if a.linuxTUN != nil {
  51. n, err := a.linuxTUN.BatchRead(bufs, offset-singTun.PacketOffset, sizes)
  52. if err == nil {
  53. for i := 0; i < n; i++ {
  54. a.debugPacket("read", bufs[i][offset:offset+sizes[i]])
  55. }
  56. }
  57. return n, err
  58. }
  59. if offset < singTun.PacketOffset {
  60. return 0, io.ErrShortBuffer
  61. }
  62. readBuf := bufs[0][offset-singTun.PacketOffset:]
  63. n, err := a.tun.Read(readBuf)
  64. if err == nil {
  65. if n < singTun.PacketOffset {
  66. return 0, io.ErrUnexpectedEOF
  67. }
  68. sizes[0] = n - singTun.PacketOffset
  69. a.debugPacket("read", readBuf[singTun.PacketOffset:n])
  70. return 1, nil
  71. }
  72. if errors.Is(err, singTun.ErrTooManySegments) {
  73. err = wgTun.ErrTooManySegments
  74. }
  75. return 0, err
  76. }
  77. func (a *tunDeviceAdapter) Write(bufs [][]byte, offset int) (count int, err error) {
  78. if a.linuxTUN != nil {
  79. for i := range bufs {
  80. a.debugPacket("write", bufs[i][offset:])
  81. }
  82. return a.linuxTUN.BatchWrite(bufs, offset)
  83. }
  84. for _, packet := range bufs {
  85. a.debugPacket("write", packet[offset:])
  86. if singTun.PacketOffset > 0 {
  87. common.ClearArray(packet[offset-singTun.PacketOffset : offset])
  88. singTun.PacketFillHeader(packet[offset-singTun.PacketOffset:], singTun.PacketIPVersion(packet[offset:]))
  89. }
  90. _, err = a.tun.Write(packet[offset-singTun.PacketOffset:])
  91. if err != nil {
  92. return 0, err
  93. }
  94. }
  95. // WireGuard will not read count.
  96. return 0, nil
  97. }
  98. func (a *tunDeviceAdapter) MTU() (int, error) {
  99. return a.mtu, nil
  100. }
  101. func (a *tunDeviceAdapter) Name() (string, error) {
  102. return a.tun.Name()
  103. }
  104. func (a *tunDeviceAdapter) Events() <-chan wgTun.Event {
  105. return a.events
  106. }
  107. func (a *tunDeviceAdapter) Close() error {
  108. var err error
  109. a.closeOnce.Do(func() {
  110. close(a.events)
  111. err = a.tun.Close()
  112. })
  113. return err
  114. }
  115. func (a *tunDeviceAdapter) BatchSize() int {
  116. if a.linuxTUN != nil {
  117. return a.linuxTUN.BatchSize()
  118. }
  119. return 1
  120. }
  121. func (a *tunDeviceAdapter) debugPacket(direction string, packet []byte) {
  122. if !a.debugTun || a.logger == nil {
  123. return
  124. }
  125. var counter *atomic.Uint32
  126. switch direction {
  127. case "read":
  128. counter = &a.readCount
  129. case "write":
  130. counter = &a.writeCount
  131. default:
  132. return
  133. }
  134. if counter.Add(1) > 8 {
  135. return
  136. }
  137. sample := packet
  138. if len(sample) > 64 {
  139. sample = sample[:64]
  140. }
  141. a.logger.Trace("tailscale tun ", direction, " len=", len(packet), " head=", hex.EncodeToString(sample))
  142. }