device_system_stack.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. //go:build with_gvisor
  2. package wireguard
  3. import (
  4. "net/netip"
  5. "github.com/sagernet/gvisor/pkg/buffer"
  6. "github.com/sagernet/gvisor/pkg/tcpip"
  7. "github.com/sagernet/gvisor/pkg/tcpip/header"
  8. "github.com/sagernet/gvisor/pkg/tcpip/stack"
  9. "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
  10. "github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
  11. "github.com/sagernet/sing-tun"
  12. "github.com/sagernet/sing/common"
  13. "github.com/sagernet/wireguard-go/device"
  14. )
  15. var _ Device = (*systemStackDevice)(nil)
  16. type systemStackDevice struct {
  17. *systemDevice
  18. stack *stack.Stack
  19. endpoint *deviceEndpoint
  20. writeBufs [][]byte
  21. }
  22. func newSystemStackDevice(options DeviceOptions) (*systemStackDevice, error) {
  23. system, err := newSystemDevice(options)
  24. if err != nil {
  25. return nil, err
  26. }
  27. endpoint := &deviceEndpoint{
  28. mtu: options.MTU,
  29. done: make(chan struct{}),
  30. }
  31. ipStack, err := tun.NewGVisorStack(endpoint)
  32. if err != nil {
  33. return nil, err
  34. }
  35. ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
  36. ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
  37. return &systemStackDevice{
  38. systemDevice: system,
  39. stack: ipStack,
  40. endpoint: endpoint,
  41. }, nil
  42. }
  43. func (w *systemStackDevice) SetDevice(device *device.Device) {
  44. w.endpoint.device = device
  45. }
  46. func (w *systemStackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
  47. if w.batchDevice != nil {
  48. w.writeBufs = w.writeBufs[:0]
  49. for _, packet := range bufs {
  50. if !w.writeStack(packet[offset:]) {
  51. w.writeBufs = append(w.writeBufs, packet)
  52. }
  53. }
  54. if len(w.writeBufs) > 0 {
  55. return w.batchDevice.BatchWrite(bufs, offset)
  56. }
  57. } else {
  58. for _, packet := range bufs {
  59. if !w.writeStack(packet[offset:]) {
  60. if tun.PacketOffset > 0 {
  61. common.ClearArray(packet[offset-tun.PacketOffset : offset])
  62. tun.PacketFillHeader(packet[offset-tun.PacketOffset:], tun.PacketIPVersion(packet[offset:]))
  63. }
  64. _, err = w.device.Write(packet[offset-tun.PacketOffset:])
  65. }
  66. if err != nil {
  67. return
  68. }
  69. }
  70. }
  71. // WireGuard will not read count
  72. return
  73. }
  74. func (w *systemStackDevice) Close() error {
  75. close(w.endpoint.done)
  76. w.stack.Close()
  77. for _, endpoint := range w.stack.CleanupEndpoints() {
  78. endpoint.Abort()
  79. }
  80. w.stack.Wait()
  81. return w.systemDevice.Close()
  82. }
  // writeStack offers packet (a raw IP packet with no leading header) to the
  // gVisor stack.
  //
  // It returns false without consuming the packet when the packet's
  // destination lies inside one of w.options.Address; the caller then writes
  // it to the system device. Otherwise it delivers the packet to the stack's
  // dispatcher and returns true.
  //
  // NOTE(review): the header accessors assume a complete IPv4/IPv6 header;
  // presumably WireGuard validates this before calling Write — confirm.
  // Packets that are neither IPv4 nor IPv6 are still delivered, with a zero
  // protocol number.
  func (w *systemStackDevice) writeStack(packet []byte) bool {
  	var proto tcpip.NetworkProtocolNumber
  	var dst netip.Addr
  	version := header.IPVersion(packet)
  	switch version {
  	case header.IPv4Version:
  		ipv4 := header.IPv4(packet)
  		proto = header.IPv4ProtocolNumber
  		dst = netip.AddrFrom4(ipv4.DestinationAddress().As4())
  	case header.IPv6Version:
  		ipv6 := header.IPv6(packet)
  		proto = header.IPv6ProtocolNumber
  		dst = netip.AddrFrom16(ipv6.DestinationAddress().As16())
  	}
  	addresses := w.options.Address
  	for i := range addresses {
  		if addresses[i].Contains(dst) {
  			// Traffic for the device's own addresses belongs to the system device.
  			return false
  		}
  	}
  	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  		Payload: buffer.MakeWithData(packet),
  	})
  	w.endpoint.dispatcher.DeliverNetworkPacket(proto, pkt)
  	pkt.DecRef()
  	return true
  }
  108. type deviceEndpoint struct {
  109. mtu uint32
  110. done chan struct{}
  111. device *device.Device
  112. dispatcher stack.NetworkDispatcher
  113. }
  114. func (ep *deviceEndpoint) MTU() uint32 {
  115. return ep.mtu
  116. }
  117. func (ep *deviceEndpoint) SetMTU(mtu uint32) {
  118. }
  119. func (ep *deviceEndpoint) MaxHeaderLength() uint16 {
  120. return 0
  121. }
  122. func (ep *deviceEndpoint) LinkAddress() tcpip.LinkAddress {
  123. return ""
  124. }
  125. func (ep *deviceEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
  126. }
  127. func (ep *deviceEndpoint) Capabilities() stack.LinkEndpointCapabilities {
  128. return stack.CapabilityRXChecksumOffload
  129. }
  130. func (ep *deviceEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
  131. ep.dispatcher = dispatcher
  132. }
  133. func (ep *deviceEndpoint) IsAttached() bool {
  134. return ep.dispatcher != nil
  135. }
  136. func (ep *deviceEndpoint) Wait() {
  137. }
  138. func (ep *deviceEndpoint) ARPHardwareType() header.ARPHardwareType {
  139. return header.ARPHardwareNone
  140. }
  141. func (ep *deviceEndpoint) AddHeader(buffer *stack.PacketBuffer) {
  142. }
  143. func (ep *deviceEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
  144. return true
  145. }
  146. func (ep *deviceEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
  147. for _, packetBuffer := range list.AsSlice() {
  148. destination := packetBuffer.Network().DestinationAddress()
  149. ep.device.InputPacket(destination.AsSlice(), packetBuffer.AsSlices())
  150. }
  151. return list.Len(), nil
  152. }
  153. func (ep *deviceEndpoint) Close() {
  154. }
  155. func (ep *deviceEndpoint) SetOnCloseAction(f func()) {
  156. }