command_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. package command_test
  2. import (
  3. "context"
  4. "testing"
  5. "time"
  6. "github.com/golang/mock/gomock"
  7. "github.com/google/go-cmp/cmp"
  8. "github.com/google/go-cmp/cmp/cmpopts"
  9. "github.com/xtls/xray-core/v1/app/router"
  10. . "github.com/xtls/xray-core/v1/app/router/command"
  11. "github.com/xtls/xray-core/v1/app/stats"
  12. "github.com/xtls/xray-core/v1/common"
  13. "github.com/xtls/xray-core/v1/common/net"
  14. "github.com/xtls/xray-core/v1/features/routing"
  15. "github.com/xtls/xray-core/v1/testing/mocks"
  16. "google.golang.org/grpc"
  17. "google.golang.org/grpc/test/bufconn"
  18. )
  19. func TestServiceSubscribeRoutingStats(t *testing.T) {
  20. c := stats.NewChannel(&stats.ChannelConfig{
  21. SubscriberLimit: 1,
  22. BufferSize: 0,
  23. Blocking: true,
  24. })
  25. common.Must(c.Start())
  26. defer c.Close()
  27. lis := bufconn.Listen(1024 * 1024)
  28. bufDialer := func(context.Context, string) (net.Conn, error) {
  29. return lis.Dial()
  30. }
  31. testCases := []*RoutingContext{
  32. {InboundTag: "in", OutboundTag: "out"},
  33. {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
  34. {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
  35. {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
  36. {Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
  37. {Protocol: "bittorrent", OutboundTag: "blocked"},
  38. {User: "[email protected]", OutboundTag: "out"},
  39. {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
  40. }
  41. errCh := make(chan error)
  42. nextPub := make(chan struct{})
  43. // Server goroutine
  44. go func() {
  45. server := grpc.NewServer()
  46. RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
  47. errCh <- server.Serve(lis)
  48. }()
  49. // Publisher goroutine
  50. go func() {
  51. publishTestCases := func() error {
  52. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  53. defer cancel()
  54. for { // Wait until there's one subscriber in routing stats channel
  55. if len(c.Subscribers()) > 0 {
  56. break
  57. }
  58. if ctx.Err() != nil {
  59. return ctx.Err()
  60. }
  61. }
  62. for _, tc := range testCases {
  63. c.Publish(context.Background(), AsRoutingRoute(tc))
  64. time.Sleep(time.Millisecond)
  65. }
  66. return nil
  67. }
  68. if err := publishTestCases(); err != nil {
  69. errCh <- err
  70. }
  71. // Wait for next round of publishing
  72. <-nextPub
  73. if err := publishTestCases(); err != nil {
  74. errCh <- err
  75. }
  76. }()
  77. // Client goroutine
  78. go func() {
  79. defer lis.Close()
  80. conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
  81. if err != nil {
  82. errCh <- err
  83. return
  84. }
  85. defer conn.Close()
  86. client := NewRoutingServiceClient(conn)
  87. // Test retrieving all fields
  88. testRetrievingAllFields := func() error {
  89. streamCtx, streamClose := context.WithCancel(context.Background())
  90. // Test the unsubscription of stream works well
  91. defer func() {
  92. streamClose()
  93. timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second)
  94. defer timeout()
  95. for { // Wait until there's no subscriber in routing stats channel
  96. if len(c.Subscribers()) == 0 {
  97. break
  98. }
  99. if timeOutCtx.Err() != nil {
  100. t.Error("unexpected subscribers not decreased in channel", timeOutCtx.Err())
  101. }
  102. }
  103. }()
  104. stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{})
  105. if err != nil {
  106. return err
  107. }
  108. for _, tc := range testCases {
  109. msg, err := stream.Recv()
  110. if err != nil {
  111. return err
  112. }
  113. if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  114. t.Error(r)
  115. }
  116. }
  117. // Test that double subscription will fail
  118. errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{
  119. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  120. })
  121. if err != nil {
  122. return err
  123. }
  124. if _, err := errStream.Recv(); err == nil {
  125. t.Error("unexpected successful subscription")
  126. }
  127. return nil
  128. }
  129. // Test retrieving only a subset of fields
  130. testRetrievingSubsetOfFields := func() error {
  131. streamCtx, streamClose := context.WithCancel(context.Background())
  132. defer streamClose()
  133. stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{
  134. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  135. })
  136. if err != nil {
  137. return err
  138. }
  139. // Send nextPub signal to start next round of publishing
  140. close(nextPub)
  141. for _, tc := range testCases {
  142. msg, err := stream.Recv()
  143. if err != nil {
  144. return err
  145. }
  146. stat := &RoutingContext{ // Only a subset of stats is retrieved
  147. SourceIPs: tc.SourceIPs,
  148. TargetIPs: tc.TargetIPs,
  149. SourcePort: tc.SourcePort,
  150. TargetPort: tc.TargetPort,
  151. TargetDomain: tc.TargetDomain,
  152. OutboundGroupTags: tc.OutboundGroupTags,
  153. OutboundTag: tc.OutboundTag,
  154. }
  155. if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  156. t.Error(r)
  157. }
  158. }
  159. return nil
  160. }
  161. if err := testRetrievingAllFields(); err != nil {
  162. errCh <- err
  163. }
  164. if err := testRetrievingSubsetOfFields(); err != nil {
  165. errCh <- err
  166. }
  167. errCh <- nil // Client passed all tests successfully
  168. }()
  169. // Wait for goroutines to complete
  170. select {
  171. case <-time.After(2 * time.Second):
  172. t.Fatal("Test timeout after 2s")
  173. case err := <-errCh:
  174. if err != nil {
  175. t.Fatal(err)
  176. }
  177. }
  178. }
  179. func TestSerivceTestRoute(t *testing.T) {
  180. c := stats.NewChannel(&stats.ChannelConfig{
  181. SubscriberLimit: 1,
  182. BufferSize: 16,
  183. Blocking: true,
  184. })
  185. common.Must(c.Start())
  186. defer c.Close()
  187. r := new(router.Router)
  188. mockCtl := gomock.NewController(t)
  189. defer mockCtl.Finish()
  190. common.Must(r.Init(&router.Config{
  191. Rule: []*router.RoutingRule{
  192. {
  193. InboundTag: []string{"in"},
  194. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  195. },
  196. {
  197. Protocol: []string{"bittorrent"},
  198. TargetTag: &router.RoutingRule_Tag{Tag: "blocked"},
  199. },
  200. {
  201. PortList: &net.PortList{Range: []*net.PortRange{{From: 8080, To: 8080}}},
  202. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  203. },
  204. {
  205. SourcePortList: &net.PortList{Range: []*net.PortRange{{From: 9999, To: 9999}}},
  206. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  207. },
  208. {
  209. Domain: []*router.Domain{{Type: router.Domain_Domain, Value: "com"}},
  210. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  211. },
  212. {
  213. SourceGeoip: []*router.GeoIP{{CountryCode: "private", Cidr: []*router.CIDR{{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}},
  214. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  215. },
  216. {
  217. UserEmail: []string{"[email protected]"},
  218. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  219. },
  220. {
  221. Networks: []net.Network{net.Network_UDP, net.Network_TCP},
  222. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  223. },
  224. },
  225. }, mocks.NewDNSClient(mockCtl), mocks.NewOutboundManager(mockCtl)))
  226. lis := bufconn.Listen(1024 * 1024)
  227. bufDialer := func(context.Context, string) (net.Conn, error) {
  228. return lis.Dial()
  229. }
  230. errCh := make(chan error)
  231. // Server goroutine
  232. go func() {
  233. server := grpc.NewServer()
  234. RegisterRoutingServiceServer(server, NewRoutingServer(r, c))
  235. errCh <- server.Serve(lis)
  236. }()
  237. // Client goroutine
  238. go func() {
  239. defer lis.Close()
  240. conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
  241. if err != nil {
  242. errCh <- err
  243. }
  244. defer conn.Close()
  245. client := NewRoutingServiceClient(conn)
  246. testCases := []*RoutingContext{
  247. {InboundTag: "in", OutboundTag: "out"},
  248. {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
  249. {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
  250. {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
  251. {Network: net.Network_UDP, Protocol: "bittorrent", OutboundTag: "blocked"},
  252. {User: "[email protected]", OutboundTag: "out"},
  253. {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
  254. }
  255. // Test simple TestRoute
  256. testSimple := func() error {
  257. for _, tc := range testCases {
  258. route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc})
  259. if err != nil {
  260. return err
  261. }
  262. if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  263. t.Error(r)
  264. }
  265. }
  266. return nil
  267. }
  268. // Test TestRoute with special options
  269. testOptions := func() error {
  270. sub, err := c.Subscribe()
  271. if err != nil {
  272. return err
  273. }
  274. for _, tc := range testCases {
  275. route, err := client.TestRoute(context.Background(), &TestRouteRequest{
  276. RoutingContext: tc,
  277. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  278. PublishResult: true,
  279. })
  280. if err != nil {
  281. return err
  282. }
  283. stat := &RoutingContext{ // Only a subset of stats is retrieved
  284. SourceIPs: tc.SourceIPs,
  285. TargetIPs: tc.TargetIPs,
  286. SourcePort: tc.SourcePort,
  287. TargetPort: tc.TargetPort,
  288. TargetDomain: tc.TargetDomain,
  289. OutboundGroupTags: tc.OutboundGroupTags,
  290. OutboundTag: tc.OutboundTag,
  291. }
  292. if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  293. t.Error(r)
  294. }
  295. select { // Check that routing result has been published to statistics channel
  296. case msg, received := <-sub:
  297. if route, ok := msg.(routing.Route); received && ok {
  298. if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  299. t.Error(r)
  300. }
  301. } else {
  302. t.Error("unexpected failure in receiving published routing result for testcase", tc)
  303. }
  304. case <-time.After(100 * time.Millisecond):
  305. t.Error("unexpected failure in receiving published routing result", tc)
  306. }
  307. }
  308. return nil
  309. }
  310. if err := testSimple(); err != nil {
  311. errCh <- err
  312. }
  313. if err := testOptions(); err != nil {
  314. errCh <- err
  315. }
  316. errCh <- nil // Client passed all tests successfully
  317. }()
  318. // Wait for goroutines to complete
  319. select {
  320. case <-time.After(2 * time.Second):
  321. t.Fatal("Test timeout after 2s")
  322. case err := <-errCh:
  323. if err != nil {
  324. t.Fatal(err)
  325. }
  326. }
  327. }