dispatcher.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. package udp
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "sync"
  7. "time"
  8. "github.com/xtls/xray-core/common"
  9. "github.com/xtls/xray-core/common/buf"
  10. "github.com/xtls/xray-core/common/net"
  11. "github.com/xtls/xray-core/common/protocol/udp"
  12. "github.com/xtls/xray-core/common/session"
  13. "github.com/xtls/xray-core/common/signal"
  14. "github.com/xtls/xray-core/common/signal/done"
  15. "github.com/xtls/xray-core/features/routing"
  16. "github.com/xtls/xray-core/transport"
  17. )
  18. type ResponseCallback func(ctx context.Context, packet *udp.Packet)
  19. type connEntry struct {
  20. link *transport.Link
  21. timer signal.ActivityUpdater
  22. cancel context.CancelFunc
  23. }
  24. type Dispatcher struct {
  25. sync.RWMutex
  26. conn *connEntry
  27. dispatcher routing.Dispatcher
  28. callback ResponseCallback
  29. callClose func() error
  30. }
  31. func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
  32. return &Dispatcher{
  33. dispatcher: dispatcher,
  34. callback: callback,
  35. }
  36. }
  37. func (v *Dispatcher) RemoveRay() {
  38. v.Lock()
  39. defer v.Unlock()
  40. if v.conn != nil {
  41. common.Interrupt(v.conn.link.Reader)
  42. common.Close(v.conn.link.Writer)
  43. v.conn = nil
  44. }
  45. }
  46. func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*connEntry, error) {
  47. v.Lock()
  48. defer v.Unlock()
  49. if v.conn != nil {
  50. return v.conn, nil
  51. }
  52. newError("establishing new connection for ", dest).WriteToLog()
  53. ctx, cancel := context.WithCancel(ctx)
  54. removeRay := func() {
  55. cancel()
  56. v.RemoveRay()
  57. }
  58. timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
  59. link, err := v.dispatcher.Dispatch(ctx, dest)
  60. if err != nil {
  61. return nil, newError("failed to dispatch request to ", dest).Base(err)
  62. }
  63. entry := &connEntry{
  64. link: link,
  65. timer: timer,
  66. cancel: removeRay,
  67. }
  68. v.conn = entry
  69. go handleInput(ctx, entry, dest, v.callback, v.callClose)
  70. return entry, nil
  71. }
  72. func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) {
  73. // TODO: Add user to destString
  74. newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx))
  75. conn, err := v.getInboundRay(ctx, destination)
  76. if err != nil {
  77. newError("failed to get inbound").Base(err).WriteToLog(session.ExportIDToError(ctx))
  78. return
  79. }
  80. outputStream := conn.link.Writer
  81. if outputStream != nil {
  82. if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
  83. newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
  84. conn.cancel()
  85. return
  86. }
  87. }
  88. }
  89. func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
  90. defer func() {
  91. conn.cancel()
  92. if callClose != nil {
  93. callClose()
  94. }
  95. }()
  96. input := conn.link.Reader
  97. timer := conn.timer
  98. for {
  99. select {
  100. case <-ctx.Done():
  101. return
  102. default:
  103. }
  104. mb, err := input.ReadMultiBuffer()
  105. if err != nil {
  106. if !errors.Is(err, io.EOF) {
  107. newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx))
  108. }
  109. return
  110. }
  111. timer.Update()
  112. for _, b := range mb {
  113. if b.UDP != nil {
  114. dest = *b.UDP
  115. }
  116. callback(ctx, &udp.Packet{
  117. Payload: b,
  118. Source: dest,
  119. })
  120. }
  121. }
  122. }
  123. type dispatcherConn struct {
  124. dispatcher *Dispatcher
  125. cache chan *udp.Packet
  126. done *done.Instance
  127. ctx context.Context
  128. }
  129. func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) {
  130. c := &dispatcherConn{
  131. cache: make(chan *udp.Packet, 16),
  132. done: done.New(),
  133. ctx: ctx,
  134. }
  135. d := &Dispatcher{
  136. dispatcher: dispatcher,
  137. callback: c.callback,
  138. callClose: c.Close,
  139. }
  140. c.dispatcher = d
  141. return c, nil
  142. }
  143. func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) {
  144. select {
  145. case <-c.done.Wait():
  146. packet.Payload.Release()
  147. return
  148. case c.cache <- packet:
  149. default:
  150. packet.Payload.Release()
  151. return
  152. }
  153. }
  154. func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) {
  155. var packet *udp.Packet
  156. s:
  157. select {
  158. case <-c.done.Wait():
  159. select {
  160. case packet = <-c.cache:
  161. break s
  162. default:
  163. return 0, nil, io.EOF
  164. }
  165. case packet = <-c.cache:
  166. }
  167. return copy(p, packet.Payload.Bytes()), &net.UDPAddr{
  168. IP: packet.Source.Address.IP(),
  169. Port: int(packet.Source.Port),
  170. }, nil
  171. }
  172. func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {
  173. buffer := buf.New()
  174. raw := buffer.Extend(buf.Size)
  175. n := copy(raw, p)
  176. buffer.Resize(0, int32(n))
  177. destination := net.DestinationFromAddr(addr)
  178. buffer.UDP = &destination
  179. c.dispatcher.Dispatch(c.ctx, destination, buffer)
  180. return n, nil
  181. }
  182. func (c *dispatcherConn) Close() error {
  183. return c.done.Close()
  184. }
  185. func (c *dispatcherConn) LocalAddr() net.Addr {
  186. return &net.UDPAddr{
  187. IP: []byte{0, 0, 0, 0},
  188. Port: 0,
  189. }
  190. }
  191. func (c *dispatcherConn) SetDeadline(t time.Time) error {
  192. return nil
  193. }
  194. func (c *dispatcherConn) SetReadDeadline(t time.Time) error {
  195. return nil
  196. }
  197. func (c *dispatcherConn) SetWriteDeadline(t time.Time) error {
  198. return nil
  199. }