dispatcher.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. package udp
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "sync"
  7. "time"
  8. "github.com/xtls/xray-core/common"
  9. "github.com/xtls/xray-core/common/buf"
  10. "github.com/xtls/xray-core/common/net"
  11. "github.com/xtls/xray-core/common/protocol/udp"
  12. "github.com/xtls/xray-core/common/session"
  13. "github.com/xtls/xray-core/common/signal"
  14. "github.com/xtls/xray-core/common/signal/done"
  15. "github.com/xtls/xray-core/features/routing"
  16. "github.com/xtls/xray-core/transport"
  17. )
  18. type ResponseCallback func(ctx context.Context, packet *udp.Packet)
  // connEntry tracks one outbound flow that a Dispatcher opened for a single
  // destination.
  19. type connEntry struct {
  // link is the reader/writer pair returned by routing.Dispatcher.Dispatch.
  20. link *transport.Link
  // timer is refreshed after every successful read in handleInput; it was
  // created by signal.CancelAfterInactivity with a one-minute window and runs
  // the same teardown as cancel when it expires.
  21. timer signal.ActivityUpdater
  // cancel cancels the flow's context and removes the flow from the Dispatcher.
  22. cancel context.CancelFunc
  23. }
  // Dispatcher multiplexes UDP payloads onto per-destination outbound flows
  // opened through a routing.Dispatcher and delivers responses to callback.
  24. type Dispatcher struct {
  // The embedded lock guards conns; within this file only Lock/Unlock are used.
  25. sync.RWMutex
  // conns maps each destination to its live flow.
  26. conns map[net.Destination]*connEntry
  // dispatcher opens new outbound flows.
  27. dispatcher routing.Dispatcher
  // callback receives every response packet read from any flow.
  28. callback ResponseCallback
  // callClose, when non-nil, is invoked each time a flow's reader goroutine
  // (handleInput) exits.
  29. callClose func() error
  30. }
  31. func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
  32. return &Dispatcher{
  33. conns: make(map[net.Destination]*connEntry),
  34. dispatcher: dispatcher,
  35. callback: callback,
  36. }
  37. }
  38. func (v *Dispatcher) RemoveRay(dest net.Destination) {
  39. v.Lock()
  40. defer v.Unlock()
  41. if conn, found := v.conns[dest]; found {
  42. common.Close(conn.link.Reader)
  43. common.Close(conn.link.Writer)
  44. delete(v.conns, dest)
  45. }
  46. }
  47. func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*connEntry, error) {
  48. v.Lock()
  49. defer v.Unlock()
  50. if entry, found := v.conns[dest]; found {
  51. return entry, nil
  52. }
  53. newError("establishing new connection for ", dest).WriteToLog()
  54. ctx, cancel := context.WithCancel(ctx)
  55. removeRay := func() {
  56. cancel()
  57. v.RemoveRay(dest)
  58. }
  59. timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
  60. link, err := v.dispatcher.Dispatch(ctx, dest)
  61. if err != nil {
  62. return nil, newError("failed to dispatch request to ", dest).Base(err)
  63. }
  64. entry := &connEntry{
  65. link: link,
  66. timer: timer,
  67. cancel: removeRay,
  68. }
  69. v.conns[dest] = entry
  70. go handleInput(ctx, entry, dest, v.callback, v.callClose)
  71. return entry, nil
  72. }
  73. func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) {
  74. // TODO: Add user to destString
  75. newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx))
  76. conn, err := v.getInboundRay(ctx, destination)
  77. if err != nil {
  78. newError("failed to get inbound").Base(err).WriteToLog(session.ExportIDToError(ctx))
  79. return
  80. }
  81. outputStream := conn.link.Writer
  82. if outputStream != nil {
  83. if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
  84. newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
  85. conn.cancel()
  86. return
  87. }
  88. }
  89. }
  90. func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
  91. defer func() {
  92. conn.cancel()
  93. if callClose != nil {
  94. callClose()
  95. }
  96. }()
  97. input := conn.link.Reader
  98. timer := conn.timer
  99. for {
  100. select {
  101. case <-ctx.Done():
  102. return
  103. default:
  104. }
  105. mb, err := input.ReadMultiBuffer()
  106. if err != nil {
  107. if !errors.Is(err, io.EOF) {
  108. newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx))
  109. }
  110. return
  111. }
  112. timer.Update()
  113. for _, b := range mb {
  114. callback(ctx, &udp.Packet{
  115. Payload: b,
  116. Source: dest,
  117. })
  118. }
  119. }
  120. }
  // dispatcherConn adapts a Dispatcher to the net.PacketConn interface; it is
  // constructed by DialDispatcher.
  121. type dispatcherConn struct {
  // dispatcher carries outgoing WriteTo payloads to their destinations.
  122. dispatcher *Dispatcher
  // cache queues inbound packets until ReadFrom consumes them; callback drops
  // (releases) packets when the queue is full.
  123. cache chan *udp.Packet
  // done is closed by Close; ReadFrom reports io.EOF once it is closed and
  // cache is empty.
  124. done *done.Instance
  125. }
  126. func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) {
  127. c := &dispatcherConn{
  128. cache: make(chan *udp.Packet, 16),
  129. done: done.New(),
  130. }
  131. d := &Dispatcher{
  132. conns: make(map[net.Destination]*connEntry),
  133. dispatcher: dispatcher,
  134. callback: c.callback,
  135. callClose: c.Close,
  136. }
  137. c.dispatcher = d
  138. return c, nil
  139. }
  140. func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) {
  141. select {
  142. case <-c.done.Wait():
  143. packet.Payload.Release()
  144. return
  145. case c.cache <- packet:
  146. default:
  147. packet.Payload.Release()
  148. return
  149. }
  150. }
  151. func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) {
  152. var packet *udp.Packet
  153. s:
  154. select {
  155. case <-c.done.Wait():
  156. select {
  157. case packet = <-c.cache:
  158. break s
  159. default:
  160. return 0, nil, io.EOF
  161. }
  162. case packet = <-c.cache:
  163. }
  164. return copy(p, packet.Payload.Bytes()), &net.UDPAddr{
  165. IP: packet.Source.Address.IP(),
  166. Port: int(packet.Source.Port),
  167. }, nil
  168. }
  169. func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {
  170. buffer := buf.New()
  171. raw := buffer.Extend(buf.Size)
  172. n := copy(raw, p)
  173. buffer.Resize(0, int32(n))
  174. ctx := context.Background()
  175. c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer)
  176. return n, nil
  177. }
  178. func (c *dispatcherConn) Close() error {
  179. return c.done.Close()
  180. }
  181. func (c *dispatcherConn) LocalAddr() net.Addr {
  182. return &net.UDPAddr{
  183. IP: []byte{0, 0, 0, 0},
  184. Port: 0,
  185. }
  186. }
  187. func (c *dispatcherConn) SetDeadline(t time.Time) error {
  188. return nil
  189. }
  190. func (c *dispatcherConn) SetReadDeadline(t time.Time) error {
  191. return nil
  192. }
  193. func (c *dispatcherConn) SetWriteDeadline(t time.Time) error {
  194. return nil
  195. }