command.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package command
  2. import (
  3. "context"
  4. "time"
  5. "github.com/xtls/xray-core/common"
  6. "github.com/xtls/xray-core/common/errors"
  7. "github.com/xtls/xray-core/core"
  8. "github.com/xtls/xray-core/features/routing"
  9. "github.com/xtls/xray-core/features/stats"
  10. "google.golang.org/grpc"
  11. )
  12. // routingServer is an implementation of RoutingService.
  13. type routingServer struct {
  14. router routing.Router
  15. routingStats stats.Channel
  16. }
  17. func (s *routingServer) GetBalancerInfo(ctx context.Context, request *GetBalancerInfoRequest) (*GetBalancerInfoResponse, error) {
  18. var ret GetBalancerInfoResponse
  19. ret.Balancer = &BalancerMsg{}
  20. if bo, ok := s.router.(routing.BalancerOverrider); ok {
  21. {
  22. res, err := bo.GetOverrideTarget(request.GetTag())
  23. if err != nil {
  24. return nil, err
  25. }
  26. ret.Balancer.Override = &OverrideInfo{
  27. Target: res,
  28. }
  29. }
  30. }
  31. if pt, ok := s.router.(routing.BalancerPrincipleTarget); ok {
  32. {
  33. res, err := pt.GetPrincipleTarget(request.GetTag())
  34. if err != nil {
  35. errors.LogInfoInner(ctx, err, "unable to obtain principle target")
  36. } else {
  37. ret.Balancer.PrincipleTarget = &PrincipleTargetInfo{Tag: res}
  38. }
  39. }
  40. }
  41. return &ret, nil
  42. }
  43. func (s *routingServer) OverrideBalancerTarget(ctx context.Context, request *OverrideBalancerTargetRequest) (*OverrideBalancerTargetResponse, error) {
  44. if bo, ok := s.router.(routing.BalancerOverrider); ok {
  45. return &OverrideBalancerTargetResponse{}, bo.SetOverrideTarget(request.BalancerTag, request.Target)
  46. }
  47. return nil, errors.New("unsupported router implementation")
  48. }
  49. func (s *routingServer) AddRule(ctx context.Context, request *AddRuleRequest) (*AddRuleResponse, error) {
  50. if bo, ok := s.router.(routing.Router); ok {
  51. return &AddRuleResponse{}, bo.AddRule(request.Config, request.ShouldAppend)
  52. }
  53. return nil, errors.New("unsupported router implementation")
  54. }
  55. func (s *routingServer) RemoveRule(ctx context.Context, request *RemoveRuleRequest) (*RemoveRuleResponse, error) {
  56. if bo, ok := s.router.(routing.Router); ok {
  57. return &RemoveRuleResponse{}, bo.RemoveRule(request.RuleTag)
  58. }
  59. return nil, errors.New("unsupported router implementation")
  60. }
  61. func (s *routingServer) ListRule(ctx context.Context, request *ListRuleRequest) (*ListRuleResponse, error) {
  62. if bo, ok := s.router.(routing.Router); ok {
  63. response := &ListRuleResponse{}
  64. for _, v := range bo.ListRule() {
  65. response.Rules = append(response.Rules, &ListRuleItem{
  66. Tag: v.GetOutboundTag(),
  67. RuleTag: v.GetRuleTag(),
  68. })
  69. }
  70. return response, nil
  71. }
  72. return nil, errors.New("unsupported router implementation")
  73. }
  74. // NewRoutingServer creates a statistics service with statistics manager.
  75. func NewRoutingServer(router routing.Router, routingStats stats.Channel) RoutingServiceServer {
  76. return &routingServer{
  77. router: router,
  78. routingStats: routingStats,
  79. }
  80. }
  81. func (s *routingServer) TestRoute(ctx context.Context, request *TestRouteRequest) (*RoutingContext, error) {
  82. if request.RoutingContext == nil {
  83. return nil, errors.New("Invalid routing request.")
  84. }
  85. route, err := s.router.PickRoute(AsRoutingContext(request.RoutingContext))
  86. if err != nil {
  87. return nil, err
  88. }
  89. if request.PublishResult && s.routingStats != nil {
  90. ctx, _ := context.WithTimeout(context.Background(), 4*time.Second)
  91. s.routingStats.Publish(ctx, route)
  92. }
  93. return AsProtobufMessage(request.FieldSelectors)(route), nil
  94. }
  95. func (s *routingServer) SubscribeRoutingStats(request *SubscribeRoutingStatsRequest, stream RoutingService_SubscribeRoutingStatsServer) error {
  96. if s.routingStats == nil {
  97. return errors.New("Routing statistics not enabled.")
  98. }
  99. genMessage := AsProtobufMessage(request.FieldSelectors)
  100. subscriber, err := stats.SubscribeRunnableChannel(s.routingStats)
  101. if err != nil {
  102. return err
  103. }
  104. defer stats.UnsubscribeClosableChannel(s.routingStats, subscriber)
  105. for {
  106. select {
  107. case value, ok := <-subscriber:
  108. if !ok {
  109. return errors.New("Upstream closed the subscriber channel.")
  110. }
  111. route, ok := value.(routing.Route)
  112. if !ok {
  113. return errors.New("Upstream sent malformed statistics.")
  114. }
  115. err := stream.Send(genMessage(route))
  116. if err != nil {
  117. return err
  118. }
  119. case <-stream.Context().Done():
  120. return stream.Context().Err()
  121. }
  122. }
  123. }
  124. func (s *routingServer) mustEmbedUnimplementedRoutingServiceServer() {}
  125. type service struct {
  126. v *core.Instance
  127. }
  128. func (s *service) Register(server *grpc.Server) {
  129. common.Must(s.v.RequireFeatures(func(router routing.Router, stats stats.Manager) {
  130. rs := NewRoutingServer(router, nil)
  131. RegisterRoutingServiceServer(server, rs)
  132. // For compatibility purposes
  133. vCoreDesc := RoutingService_ServiceDesc
  134. vCoreDesc.ServiceName = "v2ray.core.app.router.command.RoutingService"
  135. server.RegisterService(&vCoreDesc, rs)
  136. }, false))
  137. }
  138. func init() {
  139. common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, cfg interface{}) (interface{}, error) {
  140. s := core.MustFromContext(ctx)
  141. return &service{v: s}, nil
  142. }))
  143. }