command.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package command
  2. //go:generate go run github.com/xtls/xray-core/common/errors/errorgen
  3. import (
  4. "context"
  5. "time"
  6. "github.com/xtls/xray-core/common"
  7. "github.com/xtls/xray-core/common/errors"
  8. "github.com/xtls/xray-core/core"
  9. "github.com/xtls/xray-core/features/routing"
  10. "github.com/xtls/xray-core/features/stats"
  11. "google.golang.org/grpc"
  12. )
  13. // routingServer is an implementation of RoutingService.
  14. type routingServer struct {
  15. router routing.Router
  16. routingStats stats.Channel
  17. }
  18. func (s *routingServer) GetBalancerInfo(ctx context.Context, request *GetBalancerInfoRequest) (*GetBalancerInfoResponse, error) {
  19. var ret GetBalancerInfoResponse
  20. ret.Balancer = &BalancerMsg{}
  21. if bo, ok := s.router.(routing.BalancerOverrider); ok {
  22. {
  23. res, err := bo.GetOverrideTarget(request.GetTag())
  24. if err != nil {
  25. return nil, err
  26. }
  27. ret.Balancer.Override = &OverrideInfo{
  28. Target: res,
  29. }
  30. }
  31. }
  32. if pt, ok := s.router.(routing.BalancerPrincipleTarget); ok {
  33. {
  34. res, err := pt.GetPrincipleTarget(request.GetTag())
  35. if err != nil {
  36. errors.LogInfoInner(ctx, err, "unable to obtain principle target")
  37. } else {
  38. ret.Balancer.PrincipleTarget = &PrincipleTargetInfo{Tag: res}
  39. }
  40. }
  41. }
  42. return &ret, nil
  43. }
  44. func (s *routingServer) OverrideBalancerTarget(ctx context.Context, request *OverrideBalancerTargetRequest) (*OverrideBalancerTargetResponse, error) {
  45. if bo, ok := s.router.(routing.BalancerOverrider); ok {
  46. return &OverrideBalancerTargetResponse{}, bo.SetOverrideTarget(request.BalancerTag, request.Target)
  47. }
  48. return nil, errors.New("unsupported router implementation")
  49. }
  50. func (s *routingServer) AddRule(ctx context.Context, request *AddRuleRequest) (*AddRuleResponse, error) {
  51. if bo, ok := s.router.(routing.Router); ok {
  52. return &AddRuleResponse{}, bo.AddRule(request.Config, request.ShouldAppend)
  53. }
  54. return nil, errors.New("unsupported router implementation")
  55. }
  56. func (s *routingServer) RemoveRule(ctx context.Context, request *RemoveRuleRequest) (*RemoveRuleResponse, error) {
  57. if bo, ok := s.router.(routing.Router); ok {
  58. return &RemoveRuleResponse{}, bo.RemoveRule(request.RuleTag)
  59. }
  60. return nil, errors.New("unsupported router implementation")
  61. }
  62. // NewRoutingServer creates a statistics service with statistics manager.
  63. func NewRoutingServer(router routing.Router, routingStats stats.Channel) RoutingServiceServer {
  64. return &routingServer{
  65. router: router,
  66. routingStats: routingStats,
  67. }
  68. }
  69. func (s *routingServer) TestRoute(ctx context.Context, request *TestRouteRequest) (*RoutingContext, error) {
  70. if request.RoutingContext == nil {
  71. return nil, errors.New("Invalid routing request.")
  72. }
  73. route, err := s.router.PickRoute(AsRoutingContext(request.RoutingContext))
  74. if err != nil {
  75. return nil, err
  76. }
  77. if request.PublishResult && s.routingStats != nil {
  78. ctx, _ := context.WithTimeout(context.Background(), 4*time.Second)
  79. s.routingStats.Publish(ctx, route)
  80. }
  81. return AsProtobufMessage(request.FieldSelectors)(route), nil
  82. }
  83. func (s *routingServer) SubscribeRoutingStats(request *SubscribeRoutingStatsRequest, stream RoutingService_SubscribeRoutingStatsServer) error {
  84. if s.routingStats == nil {
  85. return errors.New("Routing statistics not enabled.")
  86. }
  87. genMessage := AsProtobufMessage(request.FieldSelectors)
  88. subscriber, err := stats.SubscribeRunnableChannel(s.routingStats)
  89. if err != nil {
  90. return err
  91. }
  92. defer stats.UnsubscribeClosableChannel(s.routingStats, subscriber)
  93. for {
  94. select {
  95. case value, ok := <-subscriber:
  96. if !ok {
  97. return errors.New("Upstream closed the subscriber channel.")
  98. }
  99. route, ok := value.(routing.Route)
  100. if !ok {
  101. return errors.New("Upstream sent malformed statistics.")
  102. }
  103. err := stream.Send(genMessage(route))
  104. if err != nil {
  105. return err
  106. }
  107. case <-stream.Context().Done():
  108. return stream.Context().Err()
  109. }
  110. }
  111. }
  112. func (s *routingServer) mustEmbedUnimplementedRoutingServiceServer() {}
  113. type service struct {
  114. v *core.Instance
  115. }
  116. func (s *service) Register(server *grpc.Server) {
  117. common.Must(s.v.RequireFeatures(func(router routing.Router, stats stats.Manager) {
  118. rs := NewRoutingServer(router, nil)
  119. RegisterRoutingServiceServer(server, rs)
  120. // For compatibility purposes
  121. vCoreDesc := RoutingService_ServiceDesc
  122. vCoreDesc.ServiceName = "v2ray.core.app.router.command.RoutingService"
  123. server.RegisterService(&vCoreDesc, rs)
  124. }))
  125. }
  126. func init() {
  127. common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, cfg interface{}) (interface{}, error) {
  128. s := core.MustFromContext(ctx)
  129. return &service{v: s}, nil
  130. }))
  131. }