router.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. package router
  2. //go:generate go run github.com/xtls/xray-core/common/errors/errorgen
  3. import (
  4. "context"
  5. "github.com/xtls/xray-core/common"
  6. "github.com/xtls/xray-core/core"
  7. "github.com/xtls/xray-core/features/dns"
  8. "github.com/xtls/xray-core/features/outbound"
  9. "github.com/xtls/xray-core/features/routing"
  10. routing_dns "github.com/xtls/xray-core/features/routing/dns"
  11. )
  12. // Router is an implementation of routing.Router.
  13. type Router struct {
  14. domainStrategy Config_DomainStrategy
  15. rules []*Rule
  16. balancers map[string]*Balancer
  17. dns dns.Client
  18. }
  19. // Route is an implementation of routing.Route.
  20. type Route struct {
  21. routing.Context
  22. outboundGroupTags []string
  23. outboundTag string
  24. }
  25. // Init initializes the Router.
  26. func (r *Router) Init(ctx context.Context, config *Config, d dns.Client, ohm outbound.Manager) error {
  27. r.domainStrategy = config.DomainStrategy
  28. r.dns = d
  29. r.balancers = make(map[string]*Balancer, len(config.BalancingRule))
  30. for _, rule := range config.BalancingRule {
  31. balancer, err := rule.Build(ohm)
  32. if err != nil {
  33. return err
  34. }
  35. balancer.InjectContext(ctx)
  36. r.balancers[rule.Tag] = balancer
  37. }
  38. r.rules = make([]*Rule, 0, len(config.Rule))
  39. for _, rule := range config.Rule {
  40. cond, err := rule.BuildCondition()
  41. if err != nil {
  42. return err
  43. }
  44. rr := &Rule{
  45. Condition: cond,
  46. Tag: rule.GetTag(),
  47. }
  48. btag := rule.GetBalancingTag()
  49. if len(btag) > 0 {
  50. brule, found := r.balancers[btag]
  51. if !found {
  52. return newError("balancer ", btag, " not found")
  53. }
  54. rr.Balancer = brule
  55. }
  56. r.rules = append(r.rules, rr)
  57. }
  58. return nil
  59. }
  60. // PickRoute implements routing.Router.
  61. func (r *Router) PickRoute(ctx routing.Context) (routing.Route, error) {
  62. rule, ctx, err := r.pickRouteInternal(ctx)
  63. if err != nil {
  64. return nil, err
  65. }
  66. tag, err := rule.GetTag()
  67. if err != nil {
  68. return nil, err
  69. }
  70. return &Route{Context: ctx, outboundTag: tag}, nil
  71. }
  72. func (r *Router) pickRouteInternal(ctx routing.Context) (*Rule, routing.Context, error) {
  73. // SkipDNSResolve is set from DNS module.
  74. // the DOH remote server maybe a domain name,
  75. // this prevents cycle resolving dead loop
  76. skipDNSResolve := ctx.GetSkipDNSResolve()
  77. if r.domainStrategy == Config_IpOnDemand && !skipDNSResolve {
  78. ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
  79. }
  80. for _, rule := range r.rules {
  81. if rule.Apply(ctx) {
  82. return rule, ctx, nil
  83. }
  84. }
  85. if r.domainStrategy != Config_IpIfNonMatch || len(ctx.GetTargetDomain()) == 0 || skipDNSResolve {
  86. return nil, ctx, common.ErrNoClue
  87. }
  88. ctx = routing_dns.ContextWithDNSClient(ctx, r.dns)
  89. // Try applying rules again if we have IPs.
  90. for _, rule := range r.rules {
  91. if rule.Apply(ctx) {
  92. return rule, ctx, nil
  93. }
  94. }
  95. return nil, ctx, common.ErrNoClue
  96. }
  97. // Start implements common.Runnable.
  98. func (*Router) Start() error {
  99. return nil
  100. }
  101. // Close implements common.Closable.
  102. func (*Router) Close() error {
  103. return nil
  104. }
  105. // Type implement common.HasType.
  106. func (*Router) Type() interface{} {
  107. return routing.RouterType()
  108. }
  109. // GetOutboundGroupTags implements routing.Route.
  110. func (r *Route) GetOutboundGroupTags() []string {
  111. return r.outboundGroupTags
  112. }
  113. // GetOutboundTag implements routing.Route.
  114. func (r *Route) GetOutboundTag() string {
  115. return r.outboundTag
  116. }
  117. func init() {
  118. common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
  119. r := new(Router)
  120. if err := core.RequireFeatures(ctx, func(d dns.Client, ohm outbound.Manager) error {
  121. return r.Init(ctx, config.(*Config), d, ohm)
  122. }); err != nil {
  123. return nil, err
  124. }
  125. return r, nil
  126. }))
  127. }