device_stack.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. //go:build with_gvisor
  2. package wireguard
  3. import (
  4. "context"
  5. "net"
  6. "net/netip"
  7. "os"
  8. E "github.com/sagernet/sing/common/exceptions"
  9. M "github.com/sagernet/sing/common/metadata"
  10. N "github.com/sagernet/sing/common/network"
  11. "golang.zx2c4.com/wireguard/tun"
  12. "gvisor.dev/gvisor/pkg/bufferv2"
  13. "gvisor.dev/gvisor/pkg/tcpip"
  14. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  15. "gvisor.dev/gvisor/pkg/tcpip/header"
  16. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  17. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  18. "gvisor.dev/gvisor/pkg/tcpip/stack"
  19. "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
  20. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  21. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  22. )
  23. var _ Device = (*StackDevice)(nil)
  24. const defaultNIC tcpip.NICID = 1
  25. type StackDevice struct {
  26. stack *stack.Stack
  27. mtu uint32
  28. events chan tun.Event
  29. outbound chan *stack.PacketBuffer
  30. dispatcher stack.NetworkDispatcher
  31. done chan struct{}
  32. addr4 tcpip.Address
  33. addr6 tcpip.Address
  34. }
  35. func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
  36. ipStack := stack.New(stack.Options{
  37. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  38. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
  39. HandleLocal: true,
  40. })
  41. tunDevice := &StackDevice{
  42. stack: ipStack,
  43. mtu: mtu,
  44. events: make(chan tun.Event, 1),
  45. outbound: make(chan *stack.PacketBuffer, 256),
  46. done: make(chan struct{}),
  47. }
  48. err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
  49. if err != nil {
  50. return nil, E.New(err.String())
  51. }
  52. for _, prefix := range localAddresses {
  53. addr := tcpip.Address(prefix.Addr().AsSlice())
  54. protoAddr := tcpip.ProtocolAddress{
  55. AddressWithPrefix: tcpip.AddressWithPrefix{
  56. Address: addr,
  57. PrefixLen: prefix.Bits(),
  58. },
  59. }
  60. if prefix.Addr().Is4() {
  61. tunDevice.addr4 = addr
  62. protoAddr.Protocol = ipv4.ProtocolNumber
  63. } else {
  64. tunDevice.addr6 = addr
  65. protoAddr.Protocol = ipv6.ProtocolNumber
  66. }
  67. err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
  68. if err != nil {
  69. return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
  70. }
  71. }
  72. sOpt := tcpip.TCPSACKEnabled(true)
  73. ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
  74. cOpt := tcpip.CongestionControlOption("cubic")
  75. ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
  76. ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
  77. ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
  78. return tunDevice, nil
  79. }
  80. func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
  81. return (*wireEndpoint)(w), nil
  82. }
  83. func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  84. addr := tcpip.FullAddress{
  85. NIC: defaultNIC,
  86. Port: destination.Port,
  87. Addr: tcpip.Address(destination.Addr.AsSlice()),
  88. }
  89. bind := tcpip.FullAddress{
  90. NIC: defaultNIC,
  91. }
  92. var networkProtocol tcpip.NetworkProtocolNumber
  93. if destination.IsIPv4() {
  94. networkProtocol = header.IPv4ProtocolNumber
  95. bind.Addr = w.addr4
  96. } else {
  97. networkProtocol = header.IPv6ProtocolNumber
  98. bind.Addr = w.addr6
  99. }
  100. switch N.NetworkName(network) {
  101. case N.NetworkTCP:
  102. return gonet.DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
  103. case N.NetworkUDP:
  104. return gonet.DialUDP(w.stack, &bind, &addr, networkProtocol)
  105. default:
  106. return nil, E.Extend(N.ErrUnknownNetwork, network)
  107. }
  108. }
  109. func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  110. bind := tcpip.FullAddress{
  111. NIC: defaultNIC,
  112. }
  113. var networkProtocol tcpip.NetworkProtocolNumber
  114. if destination.IsIPv4() || w.addr6 == "" {
  115. networkProtocol = header.IPv4ProtocolNumber
  116. bind.Addr = w.addr4
  117. } else {
  118. networkProtocol = header.IPv6ProtocolNumber
  119. bind.Addr = w.addr6
  120. }
  121. return gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
  122. }
  123. func (w *StackDevice) Start() error {
  124. w.events <- tun.EventUp
  125. return nil
  126. }
  127. func (w *StackDevice) File() *os.File {
  128. return nil
  129. }
  130. func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
  131. packetBuffer, ok := <-w.outbound
  132. if !ok {
  133. return 0, os.ErrClosed
  134. }
  135. defer packetBuffer.DecRef()
  136. p = p[offset:]
  137. for _, slice := range packetBuffer.AsSlices() {
  138. n += copy(p[n:], slice)
  139. }
  140. return
  141. }
  142. func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
  143. p = p[offset:]
  144. if len(p) == 0 {
  145. return
  146. }
  147. var networkProtocol tcpip.NetworkProtocolNumber
  148. switch header.IPVersion(p) {
  149. case header.IPv4Version:
  150. networkProtocol = header.IPv4ProtocolNumber
  151. case header.IPv6Version:
  152. networkProtocol = header.IPv6ProtocolNumber
  153. }
  154. packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
  155. Payload: bufferv2.MakeWithData(p),
  156. })
  157. defer packetBuffer.DecRef()
  158. w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
  159. n = len(p)
  160. return
  161. }
  162. func (w *StackDevice) Flush() error {
  163. return nil
  164. }
  165. func (w *StackDevice) MTU() (int, error) {
  166. return int(w.mtu), nil
  167. }
  168. func (w *StackDevice) Name() (string, error) {
  169. return "sing-box", nil
  170. }
  171. func (w *StackDevice) Events() chan tun.Event {
  172. return w.events
  173. }
  174. func (w *StackDevice) Close() error {
  175. select {
  176. case <-w.done:
  177. return os.ErrClosed
  178. default:
  179. }
  180. close(w.done)
  181. w.stack.Close()
  182. for _, endpoint := range w.stack.CleanupEndpoints() {
  183. endpoint.Abort()
  184. }
  185. w.stack.Wait()
  186. close(w.outbound)
  187. return nil
  188. }
  189. var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
  190. type wireEndpoint StackDevice
  191. func (ep *wireEndpoint) MTU() uint32 {
  192. return ep.mtu
  193. }
  194. func (ep *wireEndpoint) MaxHeaderLength() uint16 {
  195. return 0
  196. }
  197. func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
  198. return ""
  199. }
  200. func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
  201. return stack.CapabilityNone
  202. }
  203. func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
  204. ep.dispatcher = dispatcher
  205. }
  206. func (ep *wireEndpoint) IsAttached() bool {
  207. return ep.dispatcher != nil
  208. }
  209. func (ep *wireEndpoint) Wait() {
  210. }
  211. func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
  212. return header.ARPHardwareNone
  213. }
  214. func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
  215. }
  216. func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
  217. for _, packetBuffer := range list.AsSlice() {
  218. packetBuffer.IncRef()
  219. ep.outbound <- packetBuffer
  220. }
  221. return list.Len(), nil
  222. }