device_stack.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. //go:build with_gvisor
  2. package wireguard
  3. import (
  4. "context"
  5. "net"
  6. "net/netip"
  7. "os"
  8. "github.com/sagernet/gvisor/pkg/buffer"
  9. "github.com/sagernet/gvisor/pkg/tcpip"
  10. "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
  11. "github.com/sagernet/gvisor/pkg/tcpip/header"
  12. "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
  13. "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
  14. "github.com/sagernet/gvisor/pkg/tcpip/stack"
  15. "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
  16. "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
  17. "github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
  18. "github.com/sagernet/sing-tun"
  19. "github.com/sagernet/sing/common/buf"
  20. E "github.com/sagernet/sing/common/exceptions"
  21. M "github.com/sagernet/sing/common/metadata"
  22. N "github.com/sagernet/sing/common/network"
  23. wgTun "github.com/sagernet/wireguard-go/tun"
  24. )
  25. var _ Device = (*StackDevice)(nil)
  26. const defaultNIC tcpip.NICID = 1
  27. type StackDevice struct {
  28. stack *stack.Stack
  29. mtu uint32
  30. events chan wgTun.Event
  31. outbound chan *stack.PacketBuffer
  32. packetOutbound chan *buf.Buffer
  33. done chan struct{}
  34. dispatcher stack.NetworkDispatcher
  35. addr4 tcpip.Address
  36. addr6 tcpip.Address
  37. }
  38. func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) {
  39. ipStack := stack.New(stack.Options{
  40. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  41. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
  42. HandleLocal: true,
  43. })
  44. tunDevice := &StackDevice{
  45. stack: ipStack,
  46. mtu: mtu,
  47. events: make(chan wgTun.Event, 1),
  48. outbound: make(chan *stack.PacketBuffer, 256),
  49. packetOutbound: make(chan *buf.Buffer, 256),
  50. done: make(chan struct{}),
  51. }
  52. err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice))
  53. if err != nil {
  54. return nil, E.New(err.String())
  55. }
  56. for _, prefix := range localAddresses {
  57. addr := tun.AddressFromAddr(prefix.Addr())
  58. protoAddr := tcpip.ProtocolAddress{
  59. AddressWithPrefix: tcpip.AddressWithPrefix{
  60. Address: addr,
  61. PrefixLen: prefix.Bits(),
  62. },
  63. }
  64. if prefix.Addr().Is4() {
  65. tunDevice.addr4 = addr
  66. protoAddr.Protocol = ipv4.ProtocolNumber
  67. } else {
  68. tunDevice.addr6 = addr
  69. protoAddr.Protocol = ipv6.ProtocolNumber
  70. }
  71. err = ipStack.AddProtocolAddress(defaultNIC, protoAddr, stack.AddressProperties{})
  72. if err != nil {
  73. return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String())
  74. }
  75. }
  76. sOpt := tcpip.TCPSACKEnabled(true)
  77. ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
  78. cOpt := tcpip.CongestionControlOption("cubic")
  79. ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
  80. ipStack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: defaultNIC})
  81. ipStack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: defaultNIC})
  82. return tunDevice, nil
  83. }
  84. func (w *StackDevice) NewEndpoint() (stack.LinkEndpoint, error) {
  85. return (*wireEndpoint)(w), nil
  86. }
  87. func (w *StackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
  88. addr := tcpip.FullAddress{
  89. NIC: defaultNIC,
  90. Port: destination.Port,
  91. Addr: tun.AddressFromAddr(destination.Addr),
  92. }
  93. bind := tcpip.FullAddress{
  94. NIC: defaultNIC,
  95. }
  96. var networkProtocol tcpip.NetworkProtocolNumber
  97. if destination.IsIPv4() {
  98. networkProtocol = header.IPv4ProtocolNumber
  99. bind.Addr = w.addr4
  100. } else {
  101. networkProtocol = header.IPv6ProtocolNumber
  102. bind.Addr = w.addr6
  103. }
  104. switch N.NetworkName(network) {
  105. case N.NetworkTCP:
  106. tcpConn, err := DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
  107. if err != nil {
  108. return nil, err
  109. }
  110. return tcpConn, nil
  111. case N.NetworkUDP:
  112. udpConn, err := gonet.DialUDP(w.stack, &bind, &addr, networkProtocol)
  113. if err != nil {
  114. return nil, err
  115. }
  116. return udpConn, nil
  117. default:
  118. return nil, E.Extend(N.ErrUnknownNetwork, network)
  119. }
  120. }
  121. func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
  122. bind := tcpip.FullAddress{
  123. NIC: defaultNIC,
  124. }
  125. var networkProtocol tcpip.NetworkProtocolNumber
  126. if destination.IsIPv4() {
  127. networkProtocol = header.IPv4ProtocolNumber
  128. bind.Addr = w.addr4
  129. } else {
  130. networkProtocol = header.IPv6ProtocolNumber
  131. bind.Addr = w.addr6
  132. }
  133. udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
  134. if err != nil {
  135. return nil, err
  136. }
  137. return udpConn, nil
  138. }
  139. func (w *StackDevice) Inet4Address() netip.Addr {
  140. return tun.AddrFromAddress(w.addr4)
  141. }
  142. func (w *StackDevice) Inet6Address() netip.Addr {
  143. return tun.AddrFromAddress(w.addr6)
  144. }
  145. func (w *StackDevice) Start() error {
  146. w.events <- wgTun.EventUp
  147. return nil
  148. }
  149. func (w *StackDevice) File() *os.File {
  150. return nil
  151. }
  152. func (w *StackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
  153. select {
  154. case packetBuffer, ok := <-w.outbound:
  155. if !ok {
  156. return 0, os.ErrClosed
  157. }
  158. defer packetBuffer.DecRef()
  159. p := bufs[0]
  160. p = p[offset:]
  161. n := 0
  162. for _, slice := range packetBuffer.AsSlices() {
  163. n += copy(p[n:], slice)
  164. }
  165. sizes[0] = n
  166. count = 1
  167. return
  168. case packet := <-w.packetOutbound:
  169. defer packet.Release()
  170. sizes[0] = copy(bufs[0][offset:], packet.Bytes())
  171. count = 1
  172. return
  173. case <-w.done:
  174. return 0, os.ErrClosed
  175. }
  176. }
  177. func (w *StackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
  178. for _, b := range bufs {
  179. b = b[offset:]
  180. if len(b) == 0 {
  181. continue
  182. }
  183. var networkProtocol tcpip.NetworkProtocolNumber
  184. switch header.IPVersion(b) {
  185. case header.IPv4Version:
  186. networkProtocol = header.IPv4ProtocolNumber
  187. case header.IPv6Version:
  188. networkProtocol = header.IPv6ProtocolNumber
  189. }
  190. packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
  191. Payload: buffer.MakeWithData(b),
  192. })
  193. w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
  194. packetBuffer.DecRef()
  195. count++
  196. }
  197. return
  198. }
  199. func (w *StackDevice) Flush() error {
  200. return nil
  201. }
  202. func (w *StackDevice) MTU() (int, error) {
  203. return int(w.mtu), nil
  204. }
  205. func (w *StackDevice) Name() (string, error) {
  206. return "sing-box", nil
  207. }
  208. func (w *StackDevice) Events() <-chan wgTun.Event {
  209. return w.events
  210. }
  211. func (w *StackDevice) Close() error {
  212. close(w.done)
  213. close(w.events)
  214. w.stack.Close()
  215. for _, endpoint := range w.stack.CleanupEndpoints() {
  216. endpoint.Abort()
  217. }
  218. w.stack.Wait()
  219. return nil
  220. }
  221. func (w *StackDevice) BatchSize() int {
  222. return 1
  223. }
  224. var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
  225. type wireEndpoint StackDevice
  226. func (ep *wireEndpoint) MTU() uint32 {
  227. return ep.mtu
  228. }
  229. func (ep *wireEndpoint) MaxHeaderLength() uint16 {
  230. return 0
  231. }
  232. func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
  233. return ""
  234. }
  235. func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
  236. return stack.CapabilityRXChecksumOffload
  237. }
  238. func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
  239. ep.dispatcher = dispatcher
  240. }
  241. func (ep *wireEndpoint) IsAttached() bool {
  242. return ep.dispatcher != nil
  243. }
  244. func (ep *wireEndpoint) Wait() {
  245. }
  246. func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
  247. return header.ARPHardwareNone
  248. }
  249. func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
  250. }
  251. func (ep *wireEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
  252. return true
  253. }
  254. func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
  255. for _, packetBuffer := range list.AsSlice() {
  256. packetBuffer.IncRef()
  257. select {
  258. case <-ep.done:
  259. return 0, &tcpip.ErrClosedForSend{}
  260. case ep.outbound <- packetBuffer:
  261. }
  262. }
  263. return list.Len(), nil
  264. }