worker.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534
  1. package inbound
  2. import (
  3. "context"
  4. "sync"
  5. "sync/atomic"
  6. "time"
  7. "github.com/xtls/xray-core/app/proxyman"
  8. "github.com/xtls/xray-core/common"
  9. "github.com/xtls/xray-core/common/buf"
  10. c "github.com/xtls/xray-core/common/ctx"
  11. "github.com/xtls/xray-core/common/errors"
  12. "github.com/xtls/xray-core/common/net"
  13. "github.com/xtls/xray-core/common/serial"
  14. "github.com/xtls/xray-core/common/session"
  15. "github.com/xtls/xray-core/common/signal/done"
  16. "github.com/xtls/xray-core/common/task"
  17. "github.com/xtls/xray-core/features/routing"
  18. "github.com/xtls/xray-core/features/stats"
  19. "github.com/xtls/xray-core/proxy"
  20. "github.com/xtls/xray-core/transport/internet"
  21. "github.com/xtls/xray-core/transport/internet/stat"
  22. "github.com/xtls/xray-core/transport/internet/tcp"
  23. "github.com/xtls/xray-core/transport/internet/udp"
  24. "github.com/xtls/xray-core/transport/pipe"
  25. )
  26. type worker interface {
  27. Start() error
  28. Close() error
  29. Port() net.Port
  30. Proxy() proxy.Inbound
  31. }
  32. type tcpWorker struct {
  33. address net.Address
  34. port net.Port
  35. proxy proxy.Inbound
  36. stream *internet.MemoryStreamConfig
  37. recvOrigDest bool
  38. tag string
  39. dispatcher routing.Dispatcher
  40. sniffingConfig *proxyman.SniffingConfig
  41. uplinkCounter stats.Counter
  42. downlinkCounter stats.Counter
  43. hub internet.Listener
  44. ctx context.Context
  45. }
  46. func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyMode {
  47. if s == nil || s.SocketSettings == nil {
  48. return internet.SocketConfig_Off
  49. }
  50. return s.SocketSettings.Tproxy
  51. }
  52. func (w *tcpWorker) callback(conn stat.Connection) {
  53. ctx, cancel := context.WithCancel(w.ctx)
  54. sid := session.NewID()
  55. ctx = c.ContextWithID(ctx, sid)
  56. outbounds := []*session.Outbound{{}}
  57. if w.recvOrigDest {
  58. var dest net.Destination
  59. switch getTProxyType(w.stream) {
  60. case internet.SocketConfig_Redirect:
  61. d, err := tcp.GetOriginalDestination(conn)
  62. if err != nil {
  63. errors.LogInfoInner(ctx, err, "failed to get original destination")
  64. } else {
  65. dest = d
  66. }
  67. case internet.SocketConfig_TProxy:
  68. dest = net.DestinationFromAddr(conn.LocalAddr())
  69. }
  70. if dest.IsValid() {
  71. outbounds[0].Target = dest
  72. }
  73. }
  74. ctx = session.ContextWithOutbounds(ctx, outbounds)
  75. if w.uplinkCounter != nil || w.downlinkCounter != nil {
  76. conn = &stat.CounterConnection{
  77. Connection: conn,
  78. ReadCounter: w.uplinkCounter,
  79. WriteCounter: w.downlinkCounter,
  80. }
  81. }
  82. ctx = session.ContextWithInbound(ctx, &session.Inbound{
  83. Source: net.DestinationFromAddr(conn.RemoteAddr()),
  84. Gateway: net.TCPDestination(w.address, w.port),
  85. Tag: w.tag,
  86. Conn: conn,
  87. })
  88. content := new(session.Content)
  89. if w.sniffingConfig != nil {
  90. content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
  91. content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
  92. content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded
  93. content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
  94. content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly
  95. }
  96. ctx = session.ContextWithContent(ctx, content)
  97. if err := w.proxy.Process(ctx, net.Network_TCP, conn, w.dispatcher); err != nil {
  98. errors.LogInfoInner(ctx, err, "connection ends")
  99. }
  100. cancel()
  101. conn.Close()
  102. }
  103. func (w *tcpWorker) Proxy() proxy.Inbound {
  104. return w.proxy
  105. }
  106. func (w *tcpWorker) Start() error {
  107. ctx := context.Background()
  108. hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) {
  109. go w.callback(conn)
  110. })
  111. if err != nil {
  112. return errors.New("failed to listen TCP on ", w.port).AtWarning().Base(err)
  113. }
  114. w.hub = hub
  115. return nil
  116. }
  117. func (w *tcpWorker) Close() error {
  118. var errs []interface{}
  119. if w.hub != nil {
  120. if err := common.Close(w.hub); err != nil {
  121. errs = append(errs, err)
  122. }
  123. if err := common.Close(w.proxy); err != nil {
  124. errs = append(errs, err)
  125. }
  126. }
  127. if len(errs) > 0 {
  128. return errors.New("failed to close all resources").Base(errors.New(serial.Concat(errs...)))
  129. }
  130. return nil
  131. }
  132. func (w *tcpWorker) Port() net.Port {
  133. return w.port
  134. }
  135. type udpConn struct {
  136. lastActivityTime int64 // in seconds
  137. reader buf.Reader
  138. writer buf.Writer
  139. output func([]byte) (int, error)
  140. remote net.Addr
  141. local net.Addr
  142. done *done.Instance
  143. uplink stats.Counter
  144. downlink stats.Counter
  145. inactive bool
  146. cancel context.CancelFunc
  147. }
  148. func (c *udpConn) setInactive() {
  149. c.inactive = true
  150. }
  151. func (c *udpConn) updateActivity() {
  152. atomic.StoreInt64(&c.lastActivityTime, time.Now().Unix())
  153. }
  154. // ReadMultiBuffer implements buf.Reader
  155. func (c *udpConn) ReadMultiBuffer() (buf.MultiBuffer, error) {
  156. mb, err := c.reader.ReadMultiBuffer()
  157. if err != nil {
  158. return nil, err
  159. }
  160. c.updateActivity()
  161. if c.uplink != nil {
  162. c.uplink.Add(int64(mb.Len()))
  163. }
  164. return mb, nil
  165. }
  166. func (c *udpConn) Read(buf []byte) (int, error) {
  167. panic("not implemented")
  168. }
  169. // Write implements io.Writer.
  170. func (c *udpConn) Write(buf []byte) (int, error) {
  171. n, err := c.output(buf)
  172. if c.downlink != nil {
  173. c.downlink.Add(int64(n))
  174. }
  175. if err == nil {
  176. c.updateActivity()
  177. }
  178. return n, err
  179. }
  180. func (c *udpConn) Close() error {
  181. if c.cancel != nil {
  182. c.cancel()
  183. }
  184. common.Must(c.done.Close())
  185. common.Must(common.Close(c.writer))
  186. return nil
  187. }
  188. func (c *udpConn) RemoteAddr() net.Addr {
  189. return c.remote
  190. }
  191. func (c *udpConn) LocalAddr() net.Addr {
  192. return c.local
  193. }
  194. func (*udpConn) SetDeadline(time.Time) error {
  195. return nil
  196. }
  197. func (*udpConn) SetReadDeadline(time.Time) error {
  198. return nil
  199. }
  200. func (*udpConn) SetWriteDeadline(time.Time) error {
  201. return nil
  202. }
  203. type connID struct {
  204. src net.Destination
  205. dest net.Destination
  206. }
  207. type udpWorker struct {
  208. sync.RWMutex
  209. proxy proxy.Inbound
  210. hub *udp.Hub
  211. address net.Address
  212. port net.Port
  213. tag string
  214. stream *internet.MemoryStreamConfig
  215. dispatcher routing.Dispatcher
  216. sniffingConfig *proxyman.SniffingConfig
  217. uplinkCounter stats.Counter
  218. downlinkCounter stats.Counter
  219. checker *task.Periodic
  220. activeConn map[connID]*udpConn
  221. ctx context.Context
  222. cone bool
  223. }
  224. func (w *udpWorker) getConnection(id connID) (*udpConn, bool) {
  225. w.Lock()
  226. defer w.Unlock()
  227. if conn, found := w.activeConn[id]; found && !conn.done.Done() {
  228. conn.updateActivity()
  229. return conn, true
  230. }
  231. pReader, pWriter := pipe.New(pipe.DiscardOverflow(), pipe.WithSizeLimit(16*1024))
  232. conn := &udpConn{
  233. reader: pReader,
  234. writer: pWriter,
  235. output: func(b []byte) (int, error) {
  236. return w.hub.WriteTo(b, id.src)
  237. },
  238. remote: &net.UDPAddr{
  239. IP: id.src.Address.IP(),
  240. Port: int(id.src.Port),
  241. },
  242. local: &net.UDPAddr{
  243. IP: w.address.IP(),
  244. Port: int(w.port),
  245. },
  246. done: done.New(),
  247. uplink: w.uplinkCounter,
  248. downlink: w.downlinkCounter,
  249. }
  250. w.activeConn[id] = conn
  251. conn.updateActivity()
  252. return conn, false
  253. }
  254. func (w *udpWorker) callback(b *buf.Buffer, source net.Destination, originalDest net.Destination) {
  255. id := connID{
  256. src: source,
  257. }
  258. if originalDest.IsValid() {
  259. if !w.cone {
  260. id.dest = originalDest
  261. }
  262. b.UDP = &originalDest
  263. }
  264. conn, existing := w.getConnection(id)
  265. // payload will be discarded in pipe is full.
  266. conn.writer.WriteMultiBuffer(buf.MultiBuffer{b})
  267. if !existing {
  268. common.Must(w.checker.Start())
  269. go func() {
  270. ctx, cancel := context.WithCancel(w.ctx)
  271. conn.cancel = cancel
  272. sid := session.NewID()
  273. ctx = c.ContextWithID(ctx, sid)
  274. outbounds := []*session.Outbound{{}}
  275. if originalDest.IsValid() {
  276. outbounds[0].Target = originalDest
  277. }
  278. ctx = session.ContextWithOutbounds(ctx, outbounds)
  279. ctx = session.ContextWithInbound(ctx, &session.Inbound{
  280. Source: source,
  281. Gateway: net.UDPDestination(w.address, w.port),
  282. Tag: w.tag,
  283. })
  284. content := new(session.Content)
  285. if w.sniffingConfig != nil {
  286. content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
  287. content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
  288. content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded
  289. content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
  290. content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly
  291. }
  292. ctx = session.ContextWithContent(ctx, content)
  293. if err := w.proxy.Process(ctx, net.Network_UDP, conn, w.dispatcher); err != nil {
  294. errors.LogInfoInner(ctx, err, "connection ends")
  295. }
  296. conn.Close()
  297. // conn not removed by checker TODO may be lock worker here is better
  298. if !conn.inactive {
  299. conn.setInactive()
  300. w.removeConn(id)
  301. }
  302. }()
  303. }
  304. }
  305. func (w *udpWorker) removeConn(id connID) {
  306. w.Lock()
  307. delete(w.activeConn, id)
  308. w.Unlock()
  309. }
  310. func (w *udpWorker) handlePackets() {
  311. receive := w.hub.Receive()
  312. for payload := range receive {
  313. w.callback(payload.Payload, payload.Source, payload.Target)
  314. }
  315. }
  316. func (w *udpWorker) clean() error {
  317. nowSec := time.Now().Unix()
  318. w.Lock()
  319. defer w.Unlock()
  320. if len(w.activeConn) == 0 {
  321. return errors.New("no more connections. stopping...")
  322. }
  323. for addr, conn := range w.activeConn {
  324. if nowSec-atomic.LoadInt64(&conn.lastActivityTime) > 2*60 {
  325. if !conn.inactive {
  326. conn.setInactive()
  327. delete(w.activeConn, addr)
  328. }
  329. conn.Close()
  330. }
  331. }
  332. if len(w.activeConn) == 0 {
  333. w.activeConn = make(map[connID]*udpConn, 16)
  334. }
  335. return nil
  336. }
  337. func (w *udpWorker) Start() error {
  338. w.activeConn = make(map[connID]*udpConn, 16)
  339. ctx := context.Background()
  340. h, err := udp.ListenUDP(ctx, w.address, w.port, w.stream, udp.HubCapacity(256))
  341. if err != nil {
  342. return err
  343. }
  344. w.cone = w.ctx.Value("cone").(bool)
  345. w.checker = &task.Periodic{
  346. Interval: time.Minute,
  347. Execute: w.clean,
  348. }
  349. w.hub = h
  350. go w.handlePackets()
  351. return nil
  352. }
  353. func (w *udpWorker) Close() error {
  354. w.Lock()
  355. defer w.Unlock()
  356. var errs []interface{}
  357. if w.hub != nil {
  358. if err := w.hub.Close(); err != nil {
  359. errs = append(errs, err)
  360. }
  361. }
  362. if w.checker != nil {
  363. if err := w.checker.Close(); err != nil {
  364. errs = append(errs, err)
  365. }
  366. }
  367. if err := common.Close(w.proxy); err != nil {
  368. errs = append(errs, err)
  369. }
  370. if len(errs) > 0 {
  371. return errors.New("failed to close all resources").Base(errors.New(serial.Concat(errs...)))
  372. }
  373. return nil
  374. }
  375. func (w *udpWorker) Port() net.Port {
  376. return w.port
  377. }
  378. func (w *udpWorker) Proxy() proxy.Inbound {
  379. return w.proxy
  380. }
  381. type dsWorker struct {
  382. address net.Address
  383. proxy proxy.Inbound
  384. stream *internet.MemoryStreamConfig
  385. tag string
  386. dispatcher routing.Dispatcher
  387. sniffingConfig *proxyman.SniffingConfig
  388. uplinkCounter stats.Counter
  389. downlinkCounter stats.Counter
  390. hub internet.Listener
  391. ctx context.Context
  392. }
  393. func (w *dsWorker) callback(conn stat.Connection) {
  394. ctx, cancel := context.WithCancel(w.ctx)
  395. sid := session.NewID()
  396. ctx = c.ContextWithID(ctx, sid)
  397. if w.uplinkCounter != nil || w.downlinkCounter != nil {
  398. conn = &stat.CounterConnection{
  399. Connection: conn,
  400. ReadCounter: w.uplinkCounter,
  401. WriteCounter: w.downlinkCounter,
  402. }
  403. }
  404. ctx = session.ContextWithInbound(ctx, &session.Inbound{
  405. Source: net.DestinationFromAddr(conn.RemoteAddr()),
  406. Gateway: net.UnixDestination(w.address),
  407. Tag: w.tag,
  408. Conn: conn,
  409. })
  410. content := new(session.Content)
  411. if w.sniffingConfig != nil {
  412. content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
  413. content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
  414. content.SniffingRequest.ExcludeForDomain = w.sniffingConfig.DomainsExcluded
  415. content.SniffingRequest.MetadataOnly = w.sniffingConfig.MetadataOnly
  416. content.SniffingRequest.RouteOnly = w.sniffingConfig.RouteOnly
  417. }
  418. ctx = session.ContextWithContent(ctx, content)
  419. if err := w.proxy.Process(ctx, net.Network_UNIX, conn, w.dispatcher); err != nil {
  420. errors.LogInfoInner(ctx, err, "connection ends")
  421. }
  422. cancel()
  423. if err := conn.Close(); err != nil {
  424. errors.LogInfoInner(ctx, err, "failed to close connection")
  425. }
  426. }
  427. func (w *dsWorker) Proxy() proxy.Inbound {
  428. return w.proxy
  429. }
  430. func (w *dsWorker) Port() net.Port {
  431. return net.Port(0)
  432. }
  433. func (w *dsWorker) Start() error {
  434. ctx := context.Background()
  435. hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn stat.Connection) {
  436. go w.callback(conn)
  437. })
  438. if err != nil {
  439. return errors.New("failed to listen Unix Domain Socket on ", w.address).AtWarning().Base(err)
  440. }
  441. w.hub = hub
  442. return nil
  443. }
  444. func (w *dsWorker) Close() error {
  445. var errs []interface{}
  446. if w.hub != nil {
  447. if err := common.Close(w.hub); err != nil {
  448. errs = append(errs, err)
  449. }
  450. if err := common.Close(w.proxy); err != nil {
  451. errs = append(errs, err)
  452. }
  453. }
  454. if len(errs) > 0 {
  455. return errors.New("failed to close all resources").Base(errors.New(serial.Concat(errs...)))
  456. }
  457. return nil
  458. }