stack_gvisor.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. package tun
  2. import (
  3. "context"
  4. "time"
  5. "github.com/xtls/xray-core/common/errors"
  6. "github.com/xtls/xray-core/common/net"
  7. "gvisor.dev/gvisor/pkg/buffer"
  8. "gvisor.dev/gvisor/pkg/tcpip"
  9. "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
  10. "gvisor.dev/gvisor/pkg/tcpip/checksum"
  11. "gvisor.dev/gvisor/pkg/tcpip/header"
  12. "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
  13. "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
  14. "gvisor.dev/gvisor/pkg/tcpip/stack"
  15. "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
  16. "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
  17. "gvisor.dev/gvisor/pkg/waiter"
  18. )
  19. const (
  20. defaultNIC tcpip.NICID = 1
  21. tcpRXBufMinSize = tcp.MinBufferSize
  22. tcpRXBufDefSize = tcp.DefaultSendBufferSize
  23. tcpRXBufMaxSize = 8 << 20 // 8MiB
  24. tcpTXBufMinSize = tcp.MinBufferSize
  25. tcpTXBufDefSize = tcp.DefaultReceiveBufferSize
  26. tcpTXBufMaxSize = 6 << 20 // 6MiB
  27. )
  28. // stackGVisor is ip stack implemented by gVisor package
  29. type stackGVisor struct {
  30. ctx context.Context
  31. tun GVisorTun
  32. idleTimeout time.Duration
  33. handler *Handler
  34. stack *stack.Stack
  35. endpoint stack.LinkEndpoint
  36. }
  37. // GVisorTun implements a bridge to connect gVisor ip stack to tun interface
  38. type GVisorTun interface {
  39. newEndpoint() (stack.LinkEndpoint, error)
  40. }
  41. // NewStack builds new ip stack (using gVisor)
  42. func NewStack(ctx context.Context, options StackOptions, handler *Handler) (Stack, error) {
  43. gStack := &stackGVisor{
  44. ctx: ctx,
  45. tun: options.Tun.(GVisorTun),
  46. idleTimeout: options.IdleTimeout,
  47. handler: handler,
  48. }
  49. return gStack, nil
  50. }
  51. // Start is called by Handler to bring stack to life
  52. func (t *stackGVisor) Start() error {
  53. linkEndpoint, err := t.tun.newEndpoint()
  54. if err != nil {
  55. return err
  56. }
  57. ipStack, err := createStack(linkEndpoint)
  58. if err != nil {
  59. return err
  60. }
  61. tcpForwarder := tcp.NewForwarder(ipStack, 0, 65535, func(r *tcp.ForwarderRequest) {
  62. go func(r *tcp.ForwarderRequest) {
  63. var wq waiter.Queue
  64. var id = r.ID()
  65. // Perform a TCP three-way handshake.
  66. ep, err := r.CreateEndpoint(&wq)
  67. if err != nil {
  68. errors.LogError(t.ctx, err.String())
  69. r.Complete(true)
  70. return
  71. }
  72. options := ep.SocketOptions()
  73. options.SetKeepAlive(false)
  74. options.SetReuseAddress(true)
  75. options.SetReusePort(true)
  76. t.handler.HandleConnection(
  77. gonet.NewTCPConn(&wq, ep),
  78. // local address on the gVisor side is connection destination
  79. net.TCPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort)),
  80. )
  81. // close the socket
  82. ep.Close()
  83. // send connection complete upstream
  84. r.Complete(false)
  85. }(r)
  86. })
  87. ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
  88. // Use custom UDP packet handler, instead of strict gVisor forwarder, for FullCone NAT support
  89. udpForwarder := newUdpConnectionHandler(t.handler.HandleConnection, t.writeRawUDPPacket)
  90. ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
  91. data := pkt.Data().AsRange().ToSlice()
  92. if len(data) == 0 {
  93. return false
  94. }
  95. // source/destination of the packet we process as incoming, on gVisor side are Remote/Local
  96. // in other terms, src is the side behind tun, dst is the side behind gVisor
  97. // this function handle packets passing from the tun to the gVisor, therefore the src/dst assignement
  98. src := net.UDPDestination(net.IPAddress(id.RemoteAddress.AsSlice()), net.Port(id.RemotePort))
  99. dst := net.UDPDestination(net.IPAddress(id.LocalAddress.AsSlice()), net.Port(id.LocalPort))
  100. return udpForwarder.HandlePacket(src, dst, data)
  101. })
  102. t.stack = ipStack
  103. t.endpoint = linkEndpoint
  104. return nil
  105. }
  106. func (t *stackGVisor) writeRawUDPPacket(payload []byte, src net.Destination, dst net.Destination) error {
  107. udpLen := header.UDPMinimumSize + len(payload)
  108. srcIP := tcpip.AddrFromSlice(src.Address.IP())
  109. dstIP := tcpip.AddrFromSlice(dst.Address.IP())
  110. // build packet with appropriate IP header size
  111. isIPv4 := dst.Address.Family().IsIPv4()
  112. ipHdrSize := header.IPv6MinimumSize
  113. ipProtocol := header.IPv6ProtocolNumber
  114. if isIPv4 {
  115. ipHdrSize = header.IPv4MinimumSize
  116. ipProtocol = header.IPv4ProtocolNumber
  117. }
  118. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
  119. ReserveHeaderBytes: ipHdrSize + header.UDPMinimumSize,
  120. Payload: buffer.MakeWithData(payload),
  121. })
  122. defer pkt.DecRef()
  123. // Build UDP header
  124. udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
  125. udpHdr.Encode(&header.UDPFields{
  126. SrcPort: uint16(src.Port),
  127. DstPort: uint16(dst.Port),
  128. Length: uint16(udpLen),
  129. })
  130. // Calculate and set UDP checksum
  131. xsum := header.PseudoHeaderChecksum(header.UDPProtocolNumber, srcIP, dstIP, uint16(udpLen))
  132. udpHdr.SetChecksum(^udpHdr.CalculateChecksum(checksum.Checksum(payload, xsum)))
  133. // Build IP header
  134. if isIPv4 {
  135. ipHdr := header.IPv4(pkt.NetworkHeader().Push(header.IPv4MinimumSize))
  136. ipHdr.Encode(&header.IPv4Fields{
  137. TotalLength: uint16(header.IPv4MinimumSize + udpLen),
  138. TTL: 64,
  139. Protocol: uint8(header.UDPProtocolNumber),
  140. SrcAddr: srcIP,
  141. DstAddr: dstIP,
  142. })
  143. ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
  144. } else {
  145. ipHdr := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
  146. ipHdr.Encode(&header.IPv6Fields{
  147. PayloadLength: uint16(udpLen),
  148. TransportProtocol: header.UDPProtocolNumber,
  149. HopLimit: 64,
  150. SrcAddr: srcIP,
  151. DstAddr: dstIP,
  152. })
  153. }
  154. // dispatch the packet
  155. err := t.stack.WriteRawPacket(defaultNIC, ipProtocol, buffer.MakeWithView(pkt.ToView()))
  156. if err != nil {
  157. return errors.New("failed to write raw udp packet back to stack", err)
  158. }
  159. return nil
  160. }
  161. // Close is called by Handler to shut down the stack
  162. func (t *stackGVisor) Close() error {
  163. if t.stack == nil {
  164. return nil
  165. }
  166. t.endpoint.Attach(nil)
  167. t.stack.Close()
  168. for _, endpoint := range t.stack.CleanupEndpoints() {
  169. endpoint.Abort()
  170. }
  171. return nil
  172. }
  173. // createStack configure gVisor ip stack
  174. func createStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
  175. opts := stack.Options{
  176. NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
  177. TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
  178. HandleLocal: false,
  179. }
  180. gStack := stack.New(opts)
  181. err := gStack.CreateNIC(defaultNIC, ep)
  182. if err != nil {
  183. return nil, errors.New(err.String())
  184. }
  185. gStack.SetRouteTable([]tcpip.Route{
  186. {Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
  187. {Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
  188. })
  189. err = gStack.SetSpoofing(defaultNIC, true)
  190. if err != nil {
  191. return nil, errors.New(err.String())
  192. }
  193. err = gStack.SetPromiscuousMode(defaultNIC, true)
  194. if err != nil {
  195. return nil, errors.New(err.String())
  196. }
  197. cOpt := tcpip.CongestionControlOption("cubic")
  198. gStack.SetTransportProtocolOption(tcp.ProtocolNumber, &cOpt)
  199. sOpt := tcpip.TCPSACKEnabled(true)
  200. gStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
  201. mOpt := tcpip.TCPModerateReceiveBufferOption(true)
  202. gStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
  203. tcpRXBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{
  204. Min: tcpRXBufMinSize,
  205. Default: tcpRXBufDefSize,
  206. Max: tcpRXBufMaxSize,
  207. }
  208. err = gStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRXBufOpt)
  209. if err != nil {
  210. return nil, errors.New(err.String())
  211. }
  212. tcpTXBufOpt := tcpip.TCPSendBufferSizeRangeOption{
  213. Min: tcpTXBufMinSize,
  214. Default: tcpTXBufDefSize,
  215. Max: tcpTXBufMaxSize,
  216. }
  217. err = gStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTXBufOpt)
  218. if err != nil {
  219. return nil, errors.New(err.String())
  220. }
  221. return gStack, nil
  222. }