dispatcher.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package udp
  2. import (
  3. "context"
  4. goerrors "errors"
  5. "io"
  6. "sync"
  7. "time"
  8. "github.com/xtls/xray-core/common"
  9. "github.com/xtls/xray-core/common/buf"
  10. "github.com/xtls/xray-core/common/errors"
  11. "github.com/xtls/xray-core/common/net"
  12. "github.com/xtls/xray-core/common/protocol/udp"
  13. "github.com/xtls/xray-core/common/signal"
  14. "github.com/xtls/xray-core/common/signal/done"
  15. "github.com/xtls/xray-core/features/routing"
  16. "github.com/xtls/xray-core/transport"
  17. )
  18. type ResponseCallback func(ctx context.Context, packet *udp.Packet)
  19. type connEntry struct {
  20. link *transport.Link
  21. timer signal.ActivityUpdater
  22. cancel context.CancelFunc
  23. }
  24. type Dispatcher struct {
  25. sync.RWMutex
  26. conn *connEntry
  27. dispatcher routing.Dispatcher
  28. callback ResponseCallback
  29. callClose func() error
  30. }
  31. func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
  32. return &Dispatcher{
  33. dispatcher: dispatcher,
  34. callback: callback,
  35. }
  36. }
  37. func (v *Dispatcher) RemoveRay() {
  38. v.Lock()
  39. defer v.Unlock()
  40. v.removeRay()
  41. }
  42. func (v *Dispatcher) removeRay() {
  43. if v.conn != nil {
  44. common.Interrupt(v.conn.link.Reader)
  45. common.Close(v.conn.link.Writer)
  46. v.conn = nil
  47. }
  48. }
  49. func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*connEntry, error) {
  50. v.Lock()
  51. defer v.Unlock()
  52. if v.conn != nil {
  53. return v.conn, nil
  54. }
  55. errors.LogInfo(ctx, "establishing new connection for ", dest)
  56. ctx, cancel := context.WithCancel(ctx)
  57. entry := &connEntry{}
  58. removeRay := func() {
  59. v.Lock()
  60. defer v.Unlock()
  61. // sometimes the entry is already removed by others, don't close again
  62. if entry == v.conn {
  63. cancel()
  64. v.removeRay()
  65. }
  66. }
  67. timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
  68. link, err := v.dispatcher.Dispatch(ctx, dest)
  69. if err != nil {
  70. return nil, errors.New("failed to dispatch request to ", dest).Base(err)
  71. }
  72. *entry = connEntry{
  73. link: link,
  74. timer: timer,
  75. cancel: removeRay,
  76. }
  77. v.conn = entry
  78. go handleInput(ctx, entry, dest, v.callback, v.callClose)
  79. return entry, nil
  80. }
  81. func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) {
  82. // TODO: Add user to destString
  83. errors.LogDebug(ctx, "dispatch request to: ", destination)
  84. conn, err := v.getInboundRay(ctx, destination)
  85. if err != nil {
  86. errors.LogInfoInner(ctx, err, "failed to get inbound")
  87. return
  88. }
  89. outputStream := conn.link.Writer
  90. if outputStream != nil {
  91. if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
  92. errors.LogInfoInner(ctx, err, "failed to write first UDP payload")
  93. conn.cancel()
  94. return
  95. }
  96. }
  97. }
  98. func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
  99. defer func() {
  100. conn.cancel()
  101. if callClose != nil {
  102. callClose()
  103. }
  104. }()
  105. input := conn.link.Reader
  106. timer := conn.timer
  107. for {
  108. select {
  109. case <-ctx.Done():
  110. return
  111. default:
  112. }
  113. mb, err := input.ReadMultiBuffer()
  114. if err != nil {
  115. if !goerrors.Is(err, io.EOF) {
  116. errors.LogInfoInner(ctx, err, "failed to handle UDP input")
  117. }
  118. return
  119. }
  120. timer.Update()
  121. for _, b := range mb {
  122. if b.UDP != nil {
  123. dest = *b.UDP
  124. }
  125. callback(ctx, &udp.Packet{
  126. Payload: b,
  127. Source: dest,
  128. })
  129. }
  130. }
  131. }
  132. type dispatcherConn struct {
  133. dispatcher *Dispatcher
  134. cache chan *udp.Packet
  135. done *done.Instance
  136. ctx context.Context
  137. }
  138. func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) {
  139. c := &dispatcherConn{
  140. cache: make(chan *udp.Packet, 16),
  141. done: done.New(),
  142. ctx: ctx,
  143. }
  144. d := &Dispatcher{
  145. dispatcher: dispatcher,
  146. callback: c.callback,
  147. callClose: c.Close,
  148. }
  149. c.dispatcher = d
  150. return c, nil
  151. }
  152. func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) {
  153. select {
  154. case <-c.done.Wait():
  155. packet.Payload.Release()
  156. return
  157. case c.cache <- packet:
  158. default:
  159. packet.Payload.Release()
  160. return
  161. }
  162. }
  163. func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) {
  164. var packet *udp.Packet
  165. s:
  166. select {
  167. case <-c.done.Wait():
  168. select {
  169. case packet = <-c.cache:
  170. break s
  171. default:
  172. return 0, nil, io.EOF
  173. }
  174. case packet = <-c.cache:
  175. }
  176. return copy(p, packet.Payload.Bytes()), &net.UDPAddr{
  177. IP: packet.Source.Address.IP(),
  178. Port: int(packet.Source.Port),
  179. }, nil
  180. }
  181. func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {
  182. buffer := buf.New()
  183. raw := buffer.Extend(buf.Size)
  184. n := copy(raw, p)
  185. buffer.Resize(0, int32(n))
  186. destination := net.DestinationFromAddr(addr)
  187. buffer.UDP = &destination
  188. c.dispatcher.Dispatch(c.ctx, destination, buffer)
  189. return n, nil
  190. }
  191. func (c *dispatcherConn) Close() error {
  192. return c.done.Close()
  193. }
  194. func (c *dispatcherConn) LocalAddr() net.Addr {
  195. return &net.UDPAddr{
  196. IP: []byte{0, 0, 0, 0},
  197. Port: 0,
  198. }
  199. }
  200. func (c *dispatcherConn) SetDeadline(t time.Time) error {
  201. return nil
  202. }
  203. func (c *dispatcherConn) SetReadDeadline(t time.Time) error {
  204. return nil
  205. }
  206. func (c *dispatcherConn) SetWriteDeadline(t time.Time) error {
  207. return nil
  208. }