server.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. package wireguard
  2. import (
  3. "context"
  4. goerrors "errors"
  5. "io"
  6. "sync"
  7. "github.com/xtls/xray-core/common"
  8. "github.com/xtls/xray-core/common/buf"
  9. c "github.com/xtls/xray-core/common/ctx"
  10. "github.com/xtls/xray-core/common/errors"
  11. "github.com/xtls/xray-core/common/log"
  12. "github.com/xtls/xray-core/common/net"
  13. "github.com/xtls/xray-core/common/session"
  14. "github.com/xtls/xray-core/common/signal"
  15. "github.com/xtls/xray-core/common/task"
  16. "github.com/xtls/xray-core/core"
  17. "github.com/xtls/xray-core/features/dns"
  18. "github.com/xtls/xray-core/features/policy"
  19. "github.com/xtls/xray-core/features/routing"
  20. "github.com/xtls/xray-core/transport/internet/stat"
  21. )
  22. var nullDestination = net.TCPDestination(net.AnyIP, 0)
  23. type Server struct {
  24. bindServer *netBindServer
  25. infoMu sync.RWMutex
  26. info routingInfo
  27. policyManager policy.Manager
  28. }
  29. type routingInfo struct {
  30. ctx context.Context
  31. dispatcher routing.Dispatcher
  32. inboundTag *session.Inbound
  33. contentTag *session.Content
  34. }
  35. func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
  36. v := core.MustFromContext(ctx)
  37. endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
  38. if err != nil {
  39. return nil, err
  40. }
  41. server := &Server{
  42. bindServer: &netBindServer{
  43. netBind: netBind{
  44. dns: v.GetFeature(dns.ClientType()).(dns.Client),
  45. dnsOption: dns.IPOption{
  46. IPv4Enable: hasIPv4,
  47. IPv6Enable: hasIPv6,
  48. },
  49. },
  50. },
  51. policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
  52. }
  53. tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
  54. if err != nil {
  55. return nil, err
  56. }
  57. if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
  58. _ = tun.Close()
  59. return nil, err
  60. }
  61. return server, nil
  62. }
  63. // Network implements proxy.Inbound.
  64. func (*Server) Network() []net.Network {
  65. return []net.Network{net.Network_UDP}
  66. }
  67. // Process implements proxy.Inbound.
  68. func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
  69. s.infoMu.Lock()
  70. s.info = routingInfo{
  71. ctx: ctx,
  72. dispatcher: dispatcher,
  73. inboundTag: session.InboundFromContext(ctx),
  74. contentTag: session.ContentFromContext(ctx),
  75. }
  76. s.infoMu.Unlock()
  77. ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
  78. if err != nil {
  79. return err
  80. }
  81. nep := ep.(*netEndpoint)
  82. nep.conn = conn
  83. reader := buf.NewPacketReader(conn)
  84. for {
  85. mpayload, err := reader.ReadMultiBuffer()
  86. if err != nil {
  87. return err
  88. }
  89. for _, payload := range mpayload {
  90. v, ok := <-s.bindServer.readQueue
  91. if !ok {
  92. return nil
  93. }
  94. i, err := payload.Read(v.buff)
  95. v.bytes = i
  96. v.endpoint = nep
  97. v.err = err
  98. v.waiter.Done()
  99. if err != nil && goerrors.Is(err, io.EOF) {
  100. nep.conn = nil
  101. return nil
  102. }
  103. }
  104. }
  105. }
  106. func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
  107. // Make a thread-safe copy of routing info
  108. s.infoMu.RLock()
  109. info := s.info
  110. s.infoMu.RUnlock()
  111. if info.dispatcher == nil {
  112. errors.LogError(info.ctx, "unexpected: dispatcher == nil")
  113. return
  114. }
  115. defer conn.Close()
  116. ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(info.ctx))
  117. sid := session.NewID()
  118. ctx = c.ContextWithID(ctx, sid)
  119. inbound := session.Inbound{} // since promiscuousModeHandler mixed-up context, we shallow copy inbound (tag) and content (configs)
  120. if info.inboundTag != nil {
  121. inbound = *info.inboundTag
  122. }
  123. inbound.Name = "wireguard"
  124. inbound.CanSpliceCopy = 3
  125. // overwrite the source to use the tun address for each sub context.
  126. // Since gvisor.ForwarderRequest doesn't provide any info to associate the sub-context with the Parent context
  127. // Currently we have no way to link to the original source address
  128. inbound.Source = net.DestinationFromAddr(conn.RemoteAddr())
  129. ctx = session.ContextWithInbound(ctx, &inbound)
  130. if info.contentTag != nil {
  131. ctx = session.ContextWithContent(ctx, info.contentTag)
  132. }
  133. ctx = session.SubContextFromMuxInbound(ctx)
  134. plcy := s.policyManager.ForLevel(0)
  135. timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
  136. ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
  137. From: nullDestination,
  138. To: dest,
  139. Status: log.AccessAccepted,
  140. Reason: "",
  141. })
  142. // Set inbound and content tags from routing info for proper routing
  143. // These were commented out in PR #4030 but are needed for domain-based routing
  144. if info.inboundTag != nil {
  145. ctx = session.ContextWithInbound(ctx, info.inboundTag)
  146. }
  147. if info.contentTag != nil {
  148. ctx = session.ContextWithContent(ctx, info.contentTag)
  149. }
  150. link, err := info.dispatcher.Dispatch(ctx, dest)
  151. if err != nil {
  152. errors.LogErrorInner(ctx, err, "dispatch connection")
  153. }
  154. defer cancel()
  155. requestDone := func() error {
  156. defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
  157. if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
  158. return errors.New("failed to transport all TCP request").Base(err)
  159. }
  160. return nil
  161. }
  162. responseDone := func() error {
  163. defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
  164. if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
  165. return errors.New("failed to transport all TCP response").Base(err)
  166. }
  167. return nil
  168. }
  169. requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
  170. if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
  171. common.Interrupt(link.Reader)
  172. common.Interrupt(link.Writer)
  173. errors.LogDebugInner(ctx, err, "connection ends")
  174. return
  175. }
  176. }