device_stack.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. //go:build with_gvisor
  2. package wireguard
  3. import (
  4. "context"
  5. "net"
  6. "net/netip"
  7. "os"
  8. E "github.com/sagernet/sing/common/exceptions"
  9. M "github.com/sagernet/sing/common/metadata"
  10. N "github.com/sagernet/sing/common/network"
  11. "github.com/sagernet/wireguard-go/tun"
  12. "gvisor.dev/gvisor/pkg/bufferv2"
  13. "gvisor.dev/gvisor/pkg/tcpip"
  14. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  15. "gvisor.dev/gvisor/pkg/tcpip/header"
  16. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  17. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  18. "gvisor.dev/gvisor/pkg/tcpip/stack"
  19. "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
  20. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  21. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  22. )
  23. var _ Device = (*StackDevice)(nil)
  24. const defaultNIC tcpip.NICID = 1
  25. type StackDevice struct {
  26. stack *stack.Stack
  27. mtu uint32
  28. events chan tun.Event
  29. outbound chan *stack.PacketBuffer
  30. done chan struct{}
  31. dispatcher stack.NetworkDispatcher
  32. addr4 tcpip.Address
  33. addr6 tcpip.Address
  34. }
  35. func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
  36. ipStack := stack.New(stack.Options{
  37. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  38. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
  39. HandleLocal: true,
  40. })
  41. tunDevice := &StackDevice{
  42. stack: ipStack,
  43. mtu: mtu,
  44. events: make(chan tun.Event, 1),
  45. outbound: make(chan *stack.PacketBuffer, 256),
  46. done: make(chan struct{}),
  47. }
  48. err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
  49. if err != nil {
  50. return nil, E.New(err.String())
  51. }
  52. for _, prefix := range localAddresses {
  53. addr := tcpip.Address(prefix.Addr().AsSlice())
  54. protoAddr := tcpip.ProtocolAddress{
  55. AddressWithPrefix: tcpip.AddressWithPrefix{
  56. Address: addr,
  57. PrefixLen: prefix.Bits(),
  58. },
  59. }
  60. if prefix.Addr().Is4() {
  61. tunDevice.addr4 = addr
  62. protoAddr.Protocol = ipv4.ProtocolNumber
  63. } else {
  64. tunDevice.addr6 = addr
  65. protoAddr.Protocol = ipv6.ProtocolNumber
  66. }
  67. err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
  68. if err != nil {
  69. return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
  70. }
  71. }
  72. sOpt := tcpip.TCPSACKEnabled(true)
  73. ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
  74. cOpt := tcpip.CongestionControlOption("cubic")
  75. ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
  76. ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
  77. ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
  78. return tunDevice, nil
  79. }
  80. func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
  81. return (*wireEndpoint)(w), nil
  82. }
  83. func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  84. addr := tcpip.FullAddress{
  85. NIC: defaultNIC,
  86. Port: destination.Port,
  87. Addr: tcpip.Address(destination.Addr.AsSlice()),
  88. }
  89. bind := tcpip.FullAddress{
  90. NIC: defaultNIC,
  91. }
  92. var networkProtocol tcpip.NetworkProtocolNumber
  93. if destination.IsIPv4() {
  94. networkProtocol = header.IPv4ProtocolNumber
  95. bind.Addr = w.addr4
  96. } else {
  97. networkProtocol = header.IPv6ProtocolNumber
  98. bind.Addr = w.addr6
  99. }
  100. switch N.NetworkName(network) {
  101. case N.NetworkTCP:
  102. tcpConn, err := gonet.DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
  103. if err != nil {
  104. return nil, err
  105. }
  106. return tcpConn, nil
  107. case N.NetworkUDP:
  108. udpConn, err := gonet.DialUDP(w.stack, &bind, &addr, networkProtocol)
  109. if err != nil {
  110. return nil, err
  111. }
  112. return udpConn, nil
  113. default:
  114. return nil, E.Extend(N.ErrUnknownNetwork, network)
  115. }
  116. }
  117. func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  118. bind := tcpip.FullAddress{
  119. NIC: defaultNIC,
  120. }
  121. var networkProtocol tcpip.NetworkProtocolNumber
  122. if destination.IsIPv4() || w.addr6 == "" {
  123. networkProtocol = header.IPv4ProtocolNumber
  124. bind.Addr = w.addr4
  125. } else {
  126. networkProtocol = header.IPv6ProtocolNumber
  127. bind.Addr = w.addr6
  128. }
  129. udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
  130. if err != nil {
  131. return nil, err
  132. }
  133. return udpConn, nil
  134. }
  135. func (w *StackDevice) Start() error {
  136. w.events <- tun.EventUp
  137. return nil
  138. }
  139. func (w *StackDevice) File() *os.File {
  140. return nil
  141. }
  142. func (w *StackDevice) Read(p []byte, offset int) (n int, err error) {
  143. select {
  144. case packetBuffer, ok := <-w.outbound:
  145. if !ok {
  146. return 0, os.ErrClosed
  147. }
  148. defer packetBuffer.DecRef()
  149. p = p[offset:]
  150. for _, slice := range packetBuffer.AsSlices() {
  151. n += copy(p[n:], slice)
  152. }
  153. return
  154. case <-w.done:
  155. return 0, os.ErrClosed
  156. }
  157. }
  158. func (w *StackDevice) Write(p []byte, offset int) (n int, err error) {
  159. p = p[offset:]
  160. if len(p) == 0 {
  161. return
  162. }
  163. var networkProtocol tcpip.NetworkProtocolNumber
  164. switch header.IPVersion(p) {
  165. case header.IPv4Version:
  166. networkProtocol = header.IPv4ProtocolNumber
  167. case header.IPv6Version:
  168. networkProtocol = header.IPv6ProtocolNumber
  169. }
  170. packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
  171. Payload: bufferv2.MakeWithData(p),
  172. })
  173. defer packetBuffer.DecRef()
  174. w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
  175. n = len(p)
  176. return
  177. }
  178. func (w *StackDevice) Flush() error {
  179. return nil
  180. }
  181. func (w *StackDevice) MTU() (int, error) {
  182. return int(w.mtu), nil
  183. }
  184. func (w *StackDevice) Name() (string, error) {
  185. return "sing-box", nil
  186. }
  187. func (w *StackDevice) Events() chan tun.Event {
  188. return w.events
  189. }
  190. func (w *StackDevice) Close() error {
  191. select {
  192. case <-w.done:
  193. return os.ErrClosed
  194. default:
  195. }
  196. w.stack.Close()
  197. for _, endpoint := range w.stack.CleanupEndpoints() {
  198. endpoint.Abort()
  199. }
  200. w.stack.Wait()
  201. close(w.done)
  202. return nil
  203. }
  204. var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
  205. type wireEndpoint StackDevice
  206. func (ep *wireEndpoint) MTU() uint32 {
  207. return ep.mtu
  208. }
  209. func (ep *wireEndpoint) MaxHeaderLength() uint16 {
  210. return 0
  211. }
  212. func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
  213. return ""
  214. }
  215. func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
  216. return stack.CapabilityNone
  217. }
  218. func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
  219. ep.dispatcher = dispatcher
  220. }
  221. func (ep *wireEndpoint) IsAttached() bool {
  222. return ep.dispatcher != nil
  223. }
  224. func (ep *wireEndpoint) Wait() {
  225. }
  226. func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
  227. return header.ARPHardwareNone
  228. }
  229. func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
  230. }
  231. func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
  232. for _, packetBuffer := range list.AsSlice() {
  233. packetBuffer.IncRef()
  234. select {
  235. case <-ep.done:
  236. return 0, &tcpip.ErrClosedForSend{}
  237. case ep.outbound <- packetBuffer:
  238. }
  239. }
  240. return list.Len(), nil
  241. }