server.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. package wireguard
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "github.com/xtls/xray-core/common"
  7. "github.com/xtls/xray-core/common/buf"
  8. "github.com/xtls/xray-core/common/log"
  9. "github.com/xtls/xray-core/common/net"
  10. "github.com/xtls/xray-core/common/session"
  11. "github.com/xtls/xray-core/common/signal"
  12. "github.com/xtls/xray-core/common/task"
  13. "github.com/xtls/xray-core/core"
  14. "github.com/xtls/xray-core/features/dns"
  15. "github.com/xtls/xray-core/features/policy"
  16. "github.com/xtls/xray-core/features/routing"
  17. "github.com/xtls/xray-core/transport/internet/stat"
  18. )
  19. var nullDestination = net.TCPDestination(net.AnyIP, 0)
  20. type Server struct {
  21. bindServer *netBindServer
  22. info routingInfo
  23. policyManager policy.Manager
  24. }
  25. type routingInfo struct {
  26. ctx context.Context
  27. dispatcher routing.Dispatcher
  28. inboundTag *session.Inbound
  29. outboundTag *session.Outbound
  30. contentTag *session.Content
  31. }
  32. func NewServer(ctx context.Context, conf *DeviceConfig) (*Server, error) {
  33. v := core.MustFromContext(ctx)
  34. endpoints, hasIPv4, hasIPv6, err := parseEndpoints(conf)
  35. if err != nil {
  36. return nil, err
  37. }
  38. server := &Server{
  39. bindServer: &netBindServer{
  40. netBind: netBind{
  41. dns: v.GetFeature(dns.ClientType()).(dns.Client),
  42. dnsOption: dns.IPOption{
  43. IPv4Enable: hasIPv4,
  44. IPv6Enable: hasIPv6,
  45. },
  46. },
  47. },
  48. policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
  49. }
  50. tun, err := conf.createTun()(endpoints, int(conf.Mtu), server.forwardConnection)
  51. if err != nil {
  52. return nil, err
  53. }
  54. if err = tun.BuildDevice(createIPCRequest(conf), server.bindServer); err != nil {
  55. _ = tun.Close()
  56. return nil, err
  57. }
  58. return server, nil
  59. }
  60. // Network implements proxy.Inbound.
  61. func (*Server) Network() []net.Network {
  62. return []net.Network{net.Network_UDP}
  63. }
  64. // Process implements proxy.Inbound.
  65. func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
  66. inbound := session.InboundFromContext(ctx)
  67. inbound.Name = "wireguard"
  68. inbound.SetCanSpliceCopy(3)
  69. s.info = routingInfo{
  70. ctx: core.ToBackgroundDetachedContext(ctx),
  71. dispatcher: dispatcher,
  72. inboundTag: session.InboundFromContext(ctx),
  73. outboundTag: session.OutboundFromContext(ctx),
  74. contentTag: session.ContentFromContext(ctx),
  75. }
  76. ep, err := s.bindServer.ParseEndpoint(conn.RemoteAddr().String())
  77. if err != nil {
  78. return err
  79. }
  80. nep := ep.(*netEndpoint)
  81. nep.conn = conn
  82. reader := buf.NewPacketReader(conn)
  83. for {
  84. mpayload, err := reader.ReadMultiBuffer()
  85. if err != nil {
  86. return err
  87. }
  88. for _, payload := range mpayload {
  89. v, ok := <-s.bindServer.readQueue
  90. if !ok {
  91. return nil
  92. }
  93. i, err := payload.Read(v.buff)
  94. v.bytes = i
  95. v.endpoint = nep
  96. v.err = err
  97. v.waiter.Done()
  98. if err != nil && errors.Is(err, io.EOF) {
  99. nep.conn = nil
  100. return nil
  101. }
  102. }
  103. }
  104. }
  105. func (s *Server) forwardConnection(dest net.Destination, conn net.Conn) {
  106. if s.info.dispatcher == nil {
  107. newError("unexpected: dispatcher == nil").AtError().WriteToLog()
  108. return
  109. }
  110. defer conn.Close()
  111. ctx, cancel := context.WithCancel(core.ToBackgroundDetachedContext(s.info.ctx))
  112. plcy := s.policyManager.ForLevel(0)
  113. timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
  114. ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
  115. From: nullDestination,
  116. To: dest,
  117. Status: log.AccessAccepted,
  118. Reason: "",
  119. })
  120. if s.info.inboundTag != nil {
  121. ctx = session.ContextWithInbound(ctx, s.info.inboundTag)
  122. }
  123. if s.info.outboundTag != nil {
  124. ctx = session.ContextWithOutbound(ctx, s.info.outboundTag)
  125. }
  126. if s.info.contentTag != nil {
  127. ctx = session.ContextWithContent(ctx, s.info.contentTag)
  128. }
  129. link, err := s.info.dispatcher.Dispatch(ctx, dest)
  130. if err != nil {
  131. newError("dispatch connection").Base(err).AtError().WriteToLog()
  132. }
  133. defer cancel()
  134. requestDone := func() error {
  135. defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
  136. if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
  137. return newError("failed to transport all TCP request").Base(err)
  138. }
  139. return nil
  140. }
  141. responseDone := func() error {
  142. defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
  143. if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
  144. return newError("failed to transport all TCP response").Base(err)
  145. }
  146. return nil
  147. }
  148. requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
  149. if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
  150. common.Interrupt(link.Reader)
  151. common.Interrupt(link.Writer)
  152. newError("connection ends").Base(err).AtDebug().WriteToLog()
  153. return
  154. }
  155. }