portal.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. package reverse
  2. import (
  3. "context"
  4. "sync"
  5. "time"
  6. "github.com/xtls/xray-core/common"
  7. "github.com/xtls/xray-core/common/buf"
  8. "github.com/xtls/xray-core/common/errors"
  9. "github.com/xtls/xray-core/common/mux"
  10. "github.com/xtls/xray-core/common/net"
  11. "github.com/xtls/xray-core/common/serial"
  12. "github.com/xtls/xray-core/common/session"
  13. "github.com/xtls/xray-core/common/signal"
  14. "github.com/xtls/xray-core/common/task"
  15. "github.com/xtls/xray-core/features/outbound"
  16. "github.com/xtls/xray-core/transport"
  17. "github.com/xtls/xray-core/transport/pipe"
  18. "google.golang.org/protobuf/proto"
  19. )
  20. type Portal struct {
  21. ohm outbound.Manager
  22. tag string
  23. domain string
  24. picker *StaticMuxPicker
  25. client *mux.ClientManager
  26. }
  27. func NewPortal(config *PortalConfig, ohm outbound.Manager) (*Portal, error) {
  28. if config.Tag == "" {
  29. return nil, errors.New("portal tag is empty")
  30. }
  31. if config.Domain == "" {
  32. return nil, errors.New("portal domain is empty")
  33. }
  34. picker, err := NewStaticMuxPicker()
  35. if err != nil {
  36. return nil, err
  37. }
  38. return &Portal{
  39. ohm: ohm,
  40. tag: config.Tag,
  41. domain: config.Domain,
  42. picker: picker,
  43. client: &mux.ClientManager{
  44. Picker: picker,
  45. },
  46. }, nil
  47. }
  48. func (p *Portal) Start() error {
  49. return p.ohm.AddHandler(context.Background(), &Outbound{
  50. portal: p,
  51. tag: p.tag,
  52. })
  53. }
  54. func (p *Portal) Close() error {
  55. return p.ohm.RemoveHandler(context.Background(), p.tag)
  56. }
  57. func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error {
  58. outbounds := session.OutboundsFromContext(ctx)
  59. ob := outbounds[len(outbounds)-1]
  60. if ob == nil {
  61. return errors.New("outbound metadata not found").AtError()
  62. }
  63. if isDomain(ob.Target, p.domain) {
  64. opts := pipe.OptionsFromContext(ctx)
  65. uplinkReader, uplinkWriter := pipe.New(opts...)
  66. downlinkReader, downlinkWriter := pipe.New(opts...)
  67. muxClient, err := mux.NewClientWorker(transport.Link{
  68. Reader: uplinkReader,
  69. Writer: downlinkWriter,
  70. }, mux.ClientStrategy{})
  71. if err != nil {
  72. return errors.New("failed to create mux client worker").Base(err).AtWarning()
  73. }
  74. worker, err := NewPortalWorker(muxClient)
  75. if err != nil {
  76. return errors.New("failed to create portal worker").Base(err)
  77. }
  78. p.picker.AddWorker(worker)
  79. inboundLink := &transport.Link{Reader: downlinkReader, Writer: uplinkWriter}
  80. requestDone := func() error {
  81. if err := buf.Copy(link.Reader, inboundLink.Writer); err != nil {
  82. return errors.New("failed to transfer request").Base(err)
  83. }
  84. return nil
  85. }
  86. responseDone := func() error {
  87. if err := buf.Copy(inboundLink.Reader, link.Writer); err != nil {
  88. return err
  89. }
  90. return nil
  91. }
  92. requestDonePost := task.OnSuccess(requestDone, task.Close(inboundLink.Writer))
  93. if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
  94. common.Interrupt(inboundLink.Reader)
  95. common.Interrupt(inboundLink.Writer)
  96. return errors.New("connection ends").Base(err)
  97. }
  98. return nil
  99. }
  100. if ob.Target.Network == net.Network_UDP && ob.OriginalTarget.Address != nil && ob.OriginalTarget.Address != ob.Target.Address {
  101. link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
  102. link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address}
  103. }
  104. return p.client.Dispatch(ctx, link)
  105. }
  106. type Outbound struct {
  107. portal *Portal
  108. tag string
  109. }
  110. func (o *Outbound) Tag() string {
  111. return o.tag
  112. }
  113. func (o *Outbound) Dispatch(ctx context.Context, link *transport.Link) {
  114. if err := o.portal.HandleConnection(ctx, link); err != nil {
  115. errors.LogInfoInner(ctx, err, "failed to process reverse connection")
  116. common.Interrupt(link.Writer)
  117. common.Interrupt(link.Reader)
  118. }
  119. }
  120. func (o *Outbound) Start() error {
  121. return nil
  122. }
  123. func (o *Outbound) Close() error {
  124. return nil
  125. }
  126. // SenderSettings implements outbound.Handler.
  127. func (o *Outbound) SenderSettings() *serial.TypedMessage {
  128. return nil
  129. }
  130. // ProxySettings implements outbound.Handler.
  131. func (o *Outbound) ProxySettings() *serial.TypedMessage {
  132. return nil
  133. }
  134. type StaticMuxPicker struct {
  135. access sync.Mutex
  136. workers []*PortalWorker
  137. cTask *task.Periodic
  138. }
  139. func NewStaticMuxPicker() (*StaticMuxPicker, error) {
  140. p := &StaticMuxPicker{}
  141. p.cTask = &task.Periodic{
  142. Execute: p.cleanup,
  143. Interval: time.Second * 30,
  144. }
  145. p.cTask.Start()
  146. return p, nil
  147. }
  148. func (p *StaticMuxPicker) cleanup() error {
  149. p.access.Lock()
  150. defer p.access.Unlock()
  151. var activeWorkers []*PortalWorker
  152. for _, w := range p.workers {
  153. if !w.Closed() {
  154. activeWorkers = append(activeWorkers, w)
  155. } else {
  156. w.timer.SetTimeout(0)
  157. }
  158. }
  159. if len(activeWorkers) != len(p.workers) {
  160. p.workers = activeWorkers
  161. }
  162. return nil
  163. }
  164. func (p *StaticMuxPicker) PickAvailable() (*mux.ClientWorker, error) {
  165. p.access.Lock()
  166. defer p.access.Unlock()
  167. if len(p.workers) == 0 {
  168. return nil, errors.New("empty worker list")
  169. }
  170. var minIdx int = -1
  171. var minConn uint32 = 9999
  172. for i, w := range p.workers {
  173. if w.draining {
  174. continue
  175. }
  176. if w.IsFull() {
  177. continue
  178. }
  179. if w.client.ActiveConnections() < minConn {
  180. minConn = w.client.ActiveConnections()
  181. minIdx = i
  182. }
  183. }
  184. if minIdx == -1 {
  185. for i, w := range p.workers {
  186. if w.IsFull() {
  187. continue
  188. }
  189. if w.client.ActiveConnections() < minConn {
  190. minConn = w.client.ActiveConnections()
  191. minIdx = i
  192. }
  193. }
  194. }
  195. if minIdx != -1 {
  196. return p.workers[minIdx].client, nil
  197. }
  198. return nil, errors.New("no mux client worker available")
  199. }
  200. func (p *StaticMuxPicker) AddWorker(worker *PortalWorker) {
  201. p.access.Lock()
  202. defer p.access.Unlock()
  203. p.workers = append(p.workers, worker)
  204. }
  205. type PortalWorker struct {
  206. client *mux.ClientWorker
  207. control *task.Periodic
  208. writer buf.Writer
  209. reader buf.Reader
  210. draining bool
  211. counter uint32
  212. timer *signal.ActivityTimer
  213. }
  214. func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
  215. opt := []pipe.Option{pipe.WithSizeLimit(16 * 1024)}
  216. uplinkReader, uplinkWriter := pipe.New(opt...)
  217. downlinkReader, downlinkWriter := pipe.New(opt...)
  218. ctx := context.Background()
  219. outbounds := []*session.Outbound{{
  220. Target: net.UDPDestination(net.DomainAddress(internalDomain), 0),
  221. }}
  222. ctx = session.ContextWithOutbounds(ctx, outbounds)
  223. f := client.Dispatch(ctx, &transport.Link{
  224. Reader: uplinkReader,
  225. Writer: downlinkWriter,
  226. })
  227. if !f {
  228. return nil, errors.New("unable to dispatch control connection")
  229. }
  230. terminate := func() {
  231. client.Close()
  232. }
  233. w := &PortalWorker{
  234. client: client,
  235. reader: downlinkReader,
  236. writer: uplinkWriter,
  237. timer: signal.CancelAfterInactivity(ctx, terminate, 24*time.Hour), // // prevent leak
  238. }
  239. w.control = &task.Periodic{
  240. Execute: w.heartbeat,
  241. Interval: time.Second * 2,
  242. }
  243. w.control.Start()
  244. return w, nil
  245. }
  // heartbeat is the Execute callback of w.control (scheduled every 2s by
  // NewPortalWorker). It periodically writes a Control message on the control
  // connection, and switches the worker to draining once its mux client has
  // handled more than 256 connections in total.
  //
  // It returns an error once the client is closed or the worker has already
  // started draining (writer disposed); presumably task.Periodic then stops
  // rescheduling — confirm against common/task.
  //
  // NOTE(review): w.draining is written here without StaticMuxPicker.access,
  // while PickAvailable reads it under that lock — a data race.
  246. func (w *PortalWorker) heartbeat() error {
  247. if w.Closed() {
  248. return errors.New("client worker stopped")
  249. }
  // After draining begins the writer is closed and nil'ed, so nothing more
  // can be sent.
  250. if w.draining || w.writer == nil {
  251. return errors.New("already disposed")
  252. }
  // FillInRandom is defined elsewhere; presumably it adds random padding —
  // confirm.
  253. msg := &Control{}
  254. msg.FillInRandom()
  255. if w.client.TotalConnections() > 256 {
  // Enter drain: mark the worker and send a DRAIN control message. The
  // deferred teardown runs after the WriteMultiBuffer call below has
  // produced its result, so the DRAIN message is written before the
  // writer is closed and cleared.
  256. w.draining = true
  257. msg.State = Control_DRAIN
  258. defer func() {
  259. common.Close(w.writer)
  260. common.Interrupt(w.reader)
  261. w.writer = nil
  262. }()
  263. }
  // counter cycles 0..4; a regular message goes out when it hits 1, i.e.
  // every 5th call (about every 10s at a 2s interval). A draining call
  // always writes.
  264. w.counter = (w.counter + 1) % 5
  265. if w.draining || w.counter == 1 {
  266. b, err := proto.Marshal(msg)
  267. common.Must(err)
  268. mb := buf.MergeBytes(nil, b)
  // Refresh the 24h inactivity timer set in NewPortalWorker.
  269. w.timer.Update()
  270. return w.writer.WriteMultiBuffer(mb)
  271. }
  272. return nil
  273. }
  274. func (w *PortalWorker) IsFull() bool {
  275. return w.client.IsFull()
  276. }
  277. func (w *PortalWorker) Closed() bool {
  278. return w.client.Closed()
  279. }