device_stack.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. //go:build with_gvisor
  2. package wireguard
  3. import (
  4. "context"
  5. "net"
  6. "net/netip"
  7. "os"
  8. E "github.com/sagernet/sing/common/exceptions"
  9. M "github.com/sagernet/sing/common/metadata"
  10. N "github.com/sagernet/sing/common/network"
  11. "github.com/sagernet/wireguard-go/tun"
  12. "gvisor.dev/gvisor/pkg/bufferv2"
  13. "gvisor.dev/gvisor/pkg/tcpip"
  14. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  15. "gvisor.dev/gvisor/pkg/tcpip/header"
  16. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  17. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  18. "gvisor.dev/gvisor/pkg/tcpip/stack"
  19. "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
  20. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  21. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  22. )
  23. var _ Device = (*StackDevice)(nil)
  24. const defaultNIC tcpip.NICID = 1
  25. type StackDevice struct {
  26. stack *stack.Stack
  27. mtu uint32
  28. events chan tun.Event
  29. outbound chan *stack.PacketBuffer
  30. done chan struct{}
  31. dispatcher stack.NetworkDispatcher
  32. addr4 tcpip.Address
  33. addr6 tcpip.Address
  34. }
  // NewStackDevice builds a gVisor stack with IPv4/IPv6, TCP, UDP and ICMP
// support. It attaches a single NIC backed by the returned device, assigns
// localAddresses to that NIC and installs catch-all routes for both address
// families.
//
// For each family, the last matching prefix in localAddresses becomes the
// source address used by DialContext and ListenPacket. mtu is reported by
// the device and its link endpoint.
//
// If NIC creation or address assignment fails, the partially built stack is
// shut down before the error is returned, so it does not outlive the failed
// constructor.
func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
	ipStack := stack.New(stack.Options{
		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
		HandleLocal:        true,
	})
	// abort releases the half-initialized stack on an error path, using the
	// same Close/Wait sequence as StackDevice.Close.
	abort := func() {
		ipStack.Close()
		ipStack.Wait()
	}
	tunDevice := &StackDevice{
		stack:    ipStack,
		mtu:      mtu,
		events:   make(chan tun.Event, 1),
		outbound: make(chan *stack.PacketBuffer, 256),
		done:     make(chan struct{}),
	}
	if err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)); err != nil {
		abort()
		return nil, E.New(err.String())
	}
	for _, prefix := range localAddresses {
		addr := tcpip.Address(prefix.Addr().AsSlice())
		protoAddr := tcpip.ProtocolAddress{
			AddressWithPrefix: tcpip.AddressWithPrefix{
				Address:   addr,
				PrefixLen: prefix.Bits(),
			},
		}
		if prefix.Addr().Is4() {
			tunDevice.addr4 = addr
			protoAddr.Protocol = ipv4.ProtocolNumber
		} else {
			tunDevice.addr6 = addr
			protoAddr.Protocol = ipv6.ProtocolNumber
		}
		if err := ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{}); err != nil {
			abort()
			return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
		}
	}
	// Best-effort TCP tuning: errors from SetTransportProtocolOption are not checked.
	sOpt := tcpip.TCPSACKEnabled(true)
	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
	cOpt := tcpip.CongestionControlOption("cubic")
	ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
	// Send all traffic of both families out of the single NIC.
	ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
	ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
	return tunDevice, nil
}
  80. func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
  81. return (*wireEndpoint)(w), nil
  82. }
  // DialContext opens a TCP or UDP connection from the stack to destination.
//
// The local bind address is the device's IPv4 address for IPv4 destinations
// and its IPv6 address otherwise. network must be N.NetworkTCP or
// N.NetworkUDP; anything else yields an error extending N.ErrUnknownNetwork.
// ctx is only passed to the TCP dial.
//
// On failure the returned net.Conn is a true nil interface. The gonet
// dialers return concrete pointer types, and forwarding their (nil, err)
// result directly would hand callers a non-nil net.Conn wrapping a nil
// pointer.
func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	addr := tcpip.FullAddress{
		NIC:  defaultNIC,
		Port: destination.Port,
		Addr: tcpip.Address(destination.Addr.AsSlice()),
	}
	bind := tcpip.FullAddress{
		NIC: defaultNIC,
	}
	var networkProtocol tcpip.NetworkProtocolNumber
	if destination.IsIPv4() {
		networkProtocol = header.IPv4ProtocolNumber
		bind.Addr = w.addr4
	} else {
		networkProtocol = header.IPv6ProtocolNumber
		bind.Addr = w.addr6
	}
	switch N.NetworkName(network) {
	case N.NetworkTCP:
		conn, err := gonet.DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
		if err != nil {
			return nil, err
		}
		return conn, nil
	case N.NetworkUDP:
		conn, err := gonet.DialUDP(w.stack, &bind, &addr, networkProtocol)
		if err != nil {
			return nil, err
		}
		return conn, nil
	default:
		return nil, E.Extend(N.ErrUnknownNetwork, network)
	}
}
  // ListenPacket opens a UDP socket on the stack with no remote address.
//
// destination only selects the address family. IPv4 destinations bind to the
// device's IPv4 address, as does any destination when the device has no IPv6
// address; everything else binds to the IPv6 address. ctx is unused.
//
// On failure a true nil net.PacketConn is returned, rather than an interface
// wrapping a nil *gonet.UDPConn.
func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	bind := tcpip.FullAddress{
		NIC: defaultNIC,
	}
	var networkProtocol tcpip.NetworkProtocolNumber
	if destination.IsIPv4() || w.addr6 == "" {
		networkProtocol = header.IPv4ProtocolNumber
		bind.Addr = w.addr4
	} else {
		networkProtocol = header.IPv6ProtocolNumber
		bind.Addr = w.addr6
	}
	// A nil remote address is passed, so the socket is not tied to one peer.
	conn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
	if err != nil {
		return nil, err
	}
	return conn, nil
}
  123. func (w *StackDevice) Start() error {
  124. w.events <- tun.EventUp
  125. return nil
  126. }
  127. func (w *StackDevice) File() *os.File {
  128. return nil
  129. }
  130. func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
  131. select {
  132. case packetBuffer, ok := <-w.outbound:
  133. if !ok {
  134. return 0, os.ErrClosed
  135. }
  136. defer packetBuffer.DecRef()
  137. p = p[offset:]
  138. for _, slice := range packetBuffer.AsSlices() {
  139. n += copy(p[n:], slice)
  140. }
  141. return
  142. case <-w.done:
  143. return 0, os.ErrClosed
  144. }
  145. }
  146. func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
  147. p = p[offset:]
  148. if len(p) == 0 {
  149. return
  150. }
  151. var networkProtocol tcpip.NetworkProtocolNumber
  152. switch header.IPVersion(p) {
  153. case header.IPv4Version:
  154. networkProtocol = header.IPv4ProtocolNumber
  155. case header.IPv6Version:
  156. networkProtocol = header.IPv6ProtocolNumber
  157. }
  158. packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
  159. Payload: bufferv2.MakeWithData(p),
  160. })
  161. defer packetBuffer.DecRef()
  162. w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
  163. n = len(p)
  164. return
  165. }
  166. func (w *StackDevice) Flush() error {
  167. return nil
  168. }
  169. func (w *StackDevice) MTU() (int, error) {
  170. return int(w.mtu), nil
  171. }
  172. func (w *StackDevice) Name() (string, error) {
  173. return "sing-box", nil
  174. }
  175. func (w *StackDevice) Events() chan tun.Event {
  176. return w.events
  177. }
  // Close shuts the device down. It closes the gVisor stack, aborts every
// endpoint reported by CleanupEndpoints, waits for the stack, and finally
// closes done. Closing done makes blocked Read and WritePackets calls return.
// A later call returns os.ErrClosed.
//
// NOTE(review): the closed check and close(w.done) are not atomic. Two
// concurrent Close calls could both pass the check, and the second close
// would panic; callers are assumed to serialize Close, so confirm this.
// Also, done is only closed after stack.Wait returns, so a WritePackets call
// blocked on a full outbound channel during teardown relies on Read still
// draining it.
func (w *StackDevice) Close() error {
	alreadyClosed := false
	select {
	case <-w.done:
		alreadyClosed = true
	default:
	}
	if alreadyClosed {
		return os.ErrClosed
	}

	w.stack.Close()
	pending := w.stack.CleanupEndpoints()
	for i := range pending {
		pending[i].Abort()
	}
	w.stack.Wait()

	close(w.done)
	return nil
}
  192. var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
  193. type wireEndpoint StackDevice
  194. func (ep *wireEndpoint) MTU() uint32 {
  195. return ep.mtu
  196. }
  197. func (ep *wireEndpoint) MaxHeaderLength() uint16 {
  198. return 0
  199. }
  200. func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
  201. return ""
  202. }
  203. func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
  204. return stack.CapabilityNone
  205. }
  206. func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
  207. ep.dispatcher = dispatcher
  208. }
  209. func (ep *wireEndpoint) IsAttached() bool {
  210. return ep.dispatcher != nil
  211. }
  212. func (ep *wireEndpoint) Wait() {
  213. }
  214. func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
  215. return header.ARPHardwareNone
  216. }
  217. func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
  218. }
  219. func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
  220. for _, packetBuffer := range list.AsSlice() {
  221. packetBuffer.IncRef()
  222. select {
  223. case <-ep.done:
  224. return 0, &tcpip.ErrClosedForSend{}
  225. case ep.outbound <- packetBuffer:
  226. }
  227. }
  228. return list.Len(), nil
  229. }