default.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. package dispatcher
  2. import (
  3. "context"
  4. "regexp"
  5. "strings"
  6. "sync"
  7. "time"
  8. "github.com/xtls/xray-core/common"
  9. "github.com/xtls/xray-core/common/buf"
  10. "github.com/xtls/xray-core/common/errors"
  11. "github.com/xtls/xray-core/common/log"
  12. "github.com/xtls/xray-core/common/net"
  13. "github.com/xtls/xray-core/common/protocol"
  14. "github.com/xtls/xray-core/common/session"
  15. "github.com/xtls/xray-core/core"
  16. "github.com/xtls/xray-core/features/dns"
  17. "github.com/xtls/xray-core/features/outbound"
  18. "github.com/xtls/xray-core/features/policy"
  19. "github.com/xtls/xray-core/features/routing"
  20. routing_session "github.com/xtls/xray-core/features/routing/session"
  21. "github.com/xtls/xray-core/features/stats"
  22. "github.com/xtls/xray-core/transport"
  23. "github.com/xtls/xray-core/transport/pipe"
  24. )
  // errSniffingTimeout is what the content-sniffing loop in sniffer reports
  // when it gives up without a verdict: after two inconclusive attempts
  // (an empty read or no protocol match), or once its read deadline is spent.
  var (
  	errSniffingTimeout = errors.New("timeout on sniffing")
  )
  26. type cachedReader struct {
  27. sync.Mutex
  28. reader buf.TimeoutReader // *pipe.Reader or *buf.TimeoutWrapperReader
  29. cache buf.MultiBuffer
  30. }
  31. func (r *cachedReader) Cache(b *buf.Buffer, deadline time.Duration) error {
  32. mb, err := r.reader.ReadMultiBufferTimeout(deadline)
  33. if err != nil {
  34. return err
  35. }
  36. r.Lock()
  37. if !mb.IsEmpty() {
  38. r.cache, _ = buf.MergeMulti(r.cache, mb)
  39. }
  40. b.Clear()
  41. rawBytes := b.Extend(min(r.cache.Len(), b.Cap()))
  42. n := r.cache.Copy(rawBytes)
  43. b.Resize(0, int32(n))
  44. r.Unlock()
  45. return nil
  46. }
  47. func (r *cachedReader) readInternal() buf.MultiBuffer {
  48. r.Lock()
  49. defer r.Unlock()
  50. if r.cache != nil && !r.cache.IsEmpty() {
  51. mb := r.cache
  52. r.cache = nil
  53. return mb
  54. }
  55. return nil
  56. }
  57. func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
  58. mb := r.readInternal()
  59. if mb != nil {
  60. return mb, nil
  61. }
  62. return r.reader.ReadMultiBuffer()
  63. }
  64. func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
  65. mb := r.readInternal()
  66. if mb != nil {
  67. return mb, nil
  68. }
  69. return r.reader.ReadMultiBufferTimeout(timeout)
  70. }
  71. func (r *cachedReader) Interrupt() {
  72. r.Lock()
  73. if r.cache != nil {
  74. r.cache = buf.ReleaseMulti(r.cache)
  75. }
  76. r.Unlock()
  77. if p, ok := r.reader.(*pipe.Reader); ok {
  78. p.Interrupt()
  79. }
  80. }
  81. // DefaultDispatcher is a default implementation of Dispatcher.
  82. type DefaultDispatcher struct {
  83. ohm outbound.Manager
  84. router routing.Router
  85. policy policy.Manager
  86. stats stats.Manager
  87. fdns dns.FakeDNSEngine
  88. }
  89. func init() {
  90. common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
  91. d := new(DefaultDispatcher)
  92. if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error {
  93. core.OptionalFeatures(ctx, func(fdns dns.FakeDNSEngine) {
  94. d.fdns = fdns
  95. })
  96. return d.Init(config.(*Config), om, router, pm, sm)
  97. }); err != nil {
  98. return nil, err
  99. }
  100. return d, nil
  101. }))
  102. }
  103. // Init initializes DefaultDispatcher.
  104. func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error {
  105. d.ohm = om
  106. d.router = router
  107. d.policy = pm
  108. d.stats = sm
  109. return nil
  110. }
  111. // Type implements common.HasType.
  112. func (*DefaultDispatcher) Type() interface{} {
  113. return routing.DispatcherType()
  114. }
  115. // Start implements common.Runnable.
  116. func (*DefaultDispatcher) Start() error {
  117. return nil
  118. }
  119. // Close implements common.Closable.
  120. func (*DefaultDispatcher) Close() error { return nil }
  121. func (d *DefaultDispatcher) getLink(ctx context.Context) (*transport.Link, *transport.Link) {
  122. opt := pipe.OptionsFromContext(ctx)
  123. uplinkReader, uplinkWriter := pipe.New(opt...)
  124. downlinkReader, downlinkWriter := pipe.New(opt...)
  125. inboundLink := &transport.Link{
  126. Reader: downlinkReader,
  127. Writer: uplinkWriter,
  128. }
  129. outboundLink := &transport.Link{
  130. Reader: uplinkReader,
  131. Writer: downlinkWriter,
  132. }
  133. sessionInbound := session.InboundFromContext(ctx)
  134. var user *protocol.MemoryUser
  135. if sessionInbound != nil {
  136. user = sessionInbound.User
  137. }
  138. if user != nil && len(user.Email) > 0 {
  139. p := d.policy.ForLevel(user.Level)
  140. if p.Stats.UserUplink {
  141. name := "user>>>" + user.Email + ">>>traffic>>>uplink"
  142. if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
  143. inboundLink.Writer = &SizeStatWriter{
  144. Counter: c,
  145. Writer: inboundLink.Writer,
  146. }
  147. }
  148. }
  149. if p.Stats.UserDownlink {
  150. name := "user>>>" + user.Email + ">>>traffic>>>downlink"
  151. if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
  152. outboundLink.Writer = &SizeStatWriter{
  153. Counter: c,
  154. Writer: outboundLink.Writer,
  155. }
  156. }
  157. }
  158. if p.Stats.UserOnline {
  159. name := "user>>>" + user.Email + ">>>online"
  160. if om, _ := stats.GetOrRegisterOnlineMap(d.stats, name); om != nil {
  161. sessionInbounds := session.InboundFromContext(ctx)
  162. userIP := sessionInbounds.Source.Address.String()
  163. om.AddIP(userIP)
  164. // log Online user with ips
  165. // errors.LogDebug(ctx, "user>>>" + user.Email + ">>>online", om.Count(), om.List())
  166. }
  167. }
  168. }
  169. return inboundLink, outboundLink
  170. }
  171. func (d *DefaultDispatcher) WrapLink(ctx context.Context, link *transport.Link) *transport.Link {
  172. sessionInbound := session.InboundFromContext(ctx)
  173. var user *protocol.MemoryUser
  174. if sessionInbound != nil {
  175. user = sessionInbound.User
  176. }
  177. link.Reader = &buf.TimeoutWrapperReader{Reader: link.Reader}
  178. if user != nil && len(user.Email) > 0 {
  179. p := d.policy.ForLevel(user.Level)
  180. if p.Stats.UserUplink {
  181. name := "user>>>" + user.Email + ">>>traffic>>>uplink"
  182. if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
  183. link.Reader.(*buf.TimeoutWrapperReader).Counter = c
  184. }
  185. }
  186. if p.Stats.UserDownlink {
  187. name := "user>>>" + user.Email + ">>>traffic>>>downlink"
  188. if c, _ := stats.GetOrRegisterCounter(d.stats, name); c != nil {
  189. link.Writer = &SizeStatWriter{
  190. Counter: c,
  191. Writer: link.Writer,
  192. }
  193. }
  194. }
  195. if p.Stats.UserOnline {
  196. name := "user>>>" + user.Email + ">>>online"
  197. if om, _ := stats.GetOrRegisterOnlineMap(d.stats, name); om != nil {
  198. sessionInbounds := session.InboundFromContext(ctx)
  199. userIP := sessionInbounds.Source.Address.String()
  200. om.AddIP(userIP)
  201. // log Online user with ips
  202. // errors.LogDebug(ctx, "user>>>" + user.Email + ">>>online", om.Count(), om.List())
  203. }
  204. }
  205. }
  206. return link
  207. }
  208. func (d *DefaultDispatcher) shouldOverride(ctx context.Context, result SniffResult, request session.SniffingRequest, destination net.Destination) bool {
  209. domain := result.Domain()
  210. if domain == "" {
  211. return false
  212. }
  213. for _, d := range request.ExcludeForDomain {
  214. if strings.HasPrefix(d, "regexp:") {
  215. pattern := d[7:]
  216. re, err := regexp.Compile(pattern)
  217. if err != nil {
  218. errors.LogInfo(ctx, "Unable to compile regex")
  219. continue
  220. }
  221. if re.MatchString(domain) {
  222. return false
  223. }
  224. } else {
  225. if strings.ToLower(domain) == d {
  226. return false
  227. }
  228. }
  229. }
  230. protocolString := result.Protocol()
  231. if resComp, ok := result.(SnifferResultComposite); ok {
  232. protocolString = resComp.ProtocolForDomainResult()
  233. }
  234. for _, p := range request.OverrideDestinationForProtocol {
  235. if strings.HasPrefix(protocolString, p) || strings.HasPrefix(p, protocolString) {
  236. return true
  237. }
  238. if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && protocolString != "bittorrent" && p == "fakedns" &&
  239. fkr0.IsIPInIPPool(destination.Address) {
  240. errors.LogInfo(ctx, "Using sniffer ", protocolString, " since the fake DNS missed")
  241. return true
  242. }
  243. if resultSubset, ok := result.(SnifferIsProtoSubsetOf); ok {
  244. if resultSubset.IsProtoSubsetOf(p) {
  245. return true
  246. }
  247. }
  248. }
  249. return false
  250. }
  251. // Dispatch implements routing.Dispatcher.
  252. func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destination) (*transport.Link, error) {
  253. if !destination.IsValid() {
  254. panic("Dispatcher: Invalid destination.")
  255. }
  256. outbounds := session.OutboundsFromContext(ctx)
  257. if len(outbounds) == 0 {
  258. outbounds = []*session.Outbound{{}}
  259. ctx = session.ContextWithOutbounds(ctx, outbounds)
  260. }
  261. ob := outbounds[len(outbounds)-1]
  262. ob.OriginalTarget = destination
  263. ob.Target = destination
  264. content := session.ContentFromContext(ctx)
  265. if content == nil {
  266. content = new(session.Content)
  267. ctx = session.ContextWithContent(ctx, content)
  268. }
  269. sniffingRequest := content.SniffingRequest
  270. inbound, outbound := d.getLink(ctx)
  271. if !sniffingRequest.Enabled {
  272. go d.routedDispatch(ctx, outbound, destination)
  273. } else {
  274. go func() {
  275. cReader := &cachedReader{
  276. reader: outbound.Reader.(*pipe.Reader),
  277. }
  278. outbound.Reader = cReader
  279. result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
  280. if err == nil {
  281. content.Protocol = result.Protocol()
  282. }
  283. if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
  284. domain := result.Domain()
  285. errors.LogInfo(ctx, "sniffed domain: ", domain)
  286. destination.Address = net.ParseAddress(domain)
  287. protocol := result.Protocol()
  288. if resComp, ok := result.(SnifferResultComposite); ok {
  289. protocol = resComp.ProtocolForDomainResult()
  290. }
  291. isFakeIP := false
  292. if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) {
  293. isFakeIP = true
  294. }
  295. if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
  296. ob.RouteTarget = destination
  297. } else {
  298. ob.Target = destination
  299. }
  300. }
  301. d.routedDispatch(ctx, outbound, destination)
  302. }()
  303. }
  304. return inbound, nil
  305. }
  306. // DispatchLink implements routing.Dispatcher.
  307. func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.Destination, outbound *transport.Link) error {
  308. if !destination.IsValid() {
  309. return errors.New("Dispatcher: Invalid destination.")
  310. }
  311. outbounds := session.OutboundsFromContext(ctx)
  312. if len(outbounds) == 0 {
  313. outbounds = []*session.Outbound{{}}
  314. ctx = session.ContextWithOutbounds(ctx, outbounds)
  315. }
  316. ob := outbounds[len(outbounds)-1]
  317. ob.OriginalTarget = destination
  318. ob.Target = destination
  319. content := session.ContentFromContext(ctx)
  320. if content == nil {
  321. content = new(session.Content)
  322. ctx = session.ContextWithContent(ctx, content)
  323. }
  324. outbound = d.WrapLink(ctx, outbound)
  325. sniffingRequest := content.SniffingRequest
  326. if !sniffingRequest.Enabled {
  327. d.routedDispatch(ctx, outbound, destination)
  328. } else {
  329. cReader := &cachedReader{
  330. reader: outbound.Reader.(buf.TimeoutReader),
  331. }
  332. outbound.Reader = cReader
  333. result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
  334. if err == nil {
  335. content.Protocol = result.Protocol()
  336. }
  337. if err == nil && d.shouldOverride(ctx, result, sniffingRequest, destination) {
  338. domain := result.Domain()
  339. errors.LogInfo(ctx, "sniffed domain: ", domain)
  340. destination.Address = net.ParseAddress(domain)
  341. protocol := result.Protocol()
  342. if resComp, ok := result.(SnifferResultComposite); ok {
  343. protocol = resComp.ProtocolForDomainResult()
  344. }
  345. isFakeIP := false
  346. if fkr0, ok := d.fdns.(dns.FakeDNSEngineRev0); ok && fkr0.IsIPInIPPool(ob.Target.Address) {
  347. isFakeIP = true
  348. }
  349. if sniffingRequest.RouteOnly && protocol != "fakedns" && protocol != "fakedns+others" && !isFakeIP {
  350. ob.RouteTarget = destination
  351. } else {
  352. ob.Target = destination
  353. }
  354. }
  355. d.routedDispatch(ctx, outbound, destination)
  356. }
  357. return nil
  358. }
  359. func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
  360. payload := buf.NewWithSize(32767)
  361. defer payload.Release()
  362. sniffer := NewSniffer(ctx)
  363. metaresult, metadataErr := sniffer.SniffMetadata(ctx)
  364. if metadataOnly {
  365. return metaresult, metadataErr
  366. }
  367. contentResult, contentErr := func() (SniffResult, error) {
  368. cacheDeadline := 200 * time.Millisecond
  369. totalAttempt := 0
  370. for {
  371. select {
  372. case <-ctx.Done():
  373. return nil, ctx.Err()
  374. default:
  375. cachingStartingTimeStamp := time.Now()
  376. err := cReader.Cache(payload, cacheDeadline)
  377. if err != nil {
  378. return nil, err
  379. }
  380. cachingTimeElapsed := time.Since(cachingStartingTimeStamp)
  381. cacheDeadline -= cachingTimeElapsed
  382. if !payload.IsEmpty() {
  383. result, err := sniffer.Sniff(ctx, payload.Bytes(), network)
  384. switch err {
  385. case common.ErrNoClue: // No Clue: protocol not matches, and sniffer cannot determine whether there will be a match or not
  386. totalAttempt++
  387. case protocol.ErrProtoNeedMoreData: // Protocol Need More Data: protocol matches, but need more data to complete sniffing
  388. // in this case, do not add totalAttempt(allow to read until timeout)
  389. default:
  390. return result, err
  391. }
  392. } else {
  393. totalAttempt++
  394. }
  395. if totalAttempt >= 2 || cacheDeadline <= 0 {
  396. return nil, errSniffingTimeout
  397. }
  398. }
  399. }
  400. }()
  401. if contentErr != nil && metadataErr == nil {
  402. return metaresult, nil
  403. }
  404. if contentErr == nil && metadataErr == nil {
  405. return CompositeResult(metaresult, contentResult), nil
  406. }
  407. return contentResult, contentErr
  408. }
  409. func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.Link, destination net.Destination) {
  410. outbounds := session.OutboundsFromContext(ctx)
  411. ob := outbounds[len(outbounds)-1]
  412. var handler outbound.Handler
  413. routingLink := routing_session.AsRoutingContext(ctx)
  414. inTag := routingLink.GetInboundTag()
  415. isPickRoute := 0
  416. if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
  417. ctx = session.SetForcedOutboundTagToContext(ctx, "")
  418. if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
  419. isPickRoute = 1
  420. errors.LogInfo(ctx, "taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]")
  421. handler = h
  422. } else {
  423. errors.LogError(ctx, "non existing tag for platform initialized detour: ", forcedOutboundTag)
  424. common.Close(link.Writer)
  425. common.Interrupt(link.Reader)
  426. return
  427. }
  428. } else if d.router != nil {
  429. if route, err := d.router.PickRoute(routingLink); err == nil {
  430. outTag := route.GetOutboundTag()
  431. if h := d.ohm.GetHandler(outTag); h != nil {
  432. isPickRoute = 2
  433. if route.GetRuleTag() == "" {
  434. errors.LogInfo(ctx, "taking detour [", outTag, "] for [", destination, "]")
  435. } else {
  436. errors.LogInfo(ctx, "Hit route rule: [", route.GetRuleTag(), "] so taking detour [", outTag, "] for [", destination, "]")
  437. }
  438. handler = h
  439. } else {
  440. errors.LogWarning(ctx, "non existing outTag: ", outTag)
  441. }
  442. } else {
  443. errors.LogInfo(ctx, "default route for ", destination)
  444. }
  445. }
  446. if handler == nil {
  447. handler = d.ohm.GetDefaultHandler()
  448. }
  449. if handler == nil {
  450. errors.LogInfo(ctx, "default outbound handler not exist")
  451. common.Close(link.Writer)
  452. common.Interrupt(link.Reader)
  453. return
  454. }
  455. ob.Tag = handler.Tag()
  456. if accessMessage := log.AccessMessageFromContext(ctx); accessMessage != nil {
  457. if tag := handler.Tag(); tag != "" {
  458. if inTag == "" {
  459. accessMessage.Detour = tag
  460. } else if isPickRoute == 1 {
  461. accessMessage.Detour = inTag + " ==> " + tag
  462. } else if isPickRoute == 2 {
  463. accessMessage.Detour = inTag + " -> " + tag
  464. } else {
  465. accessMessage.Detour = inTag + " >> " + tag
  466. }
  467. }
  468. log.Record(accessMessage)
  469. }
  470. handler.Dispatch(ctx, link)
  471. }