command_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. package command_test
  2. import (
  3. "context"
  4. "testing"
  5. "time"
  6. "github.com/golang/mock/gomock"
  7. "github.com/google/go-cmp/cmp"
  8. "github.com/google/go-cmp/cmp/cmpopts"
  9. "github.com/xtls/xray-core/app/router"
  10. . "github.com/xtls/xray-core/app/router/command"
  11. "github.com/xtls/xray-core/app/stats"
  12. "github.com/xtls/xray-core/common"
  13. "github.com/xtls/xray-core/common/net"
  14. "github.com/xtls/xray-core/features/routing"
  15. "github.com/xtls/xray-core/testing/mocks"
  16. "google.golang.org/grpc"
  17. "google.golang.org/grpc/credentials/insecure"
  18. "google.golang.org/grpc/test/bufconn"
  19. )
  20. func TestServiceSubscribeRoutingStats(t *testing.T) {
  21. c := stats.NewChannel(&stats.ChannelConfig{
  22. SubscriberLimit: 1,
  23. BufferSize: 0,
  24. Blocking: true,
  25. })
  26. common.Must(c.Start())
  27. defer c.Close()
  28. lis := bufconn.Listen(1024 * 1024)
  29. bufDialer := func(context.Context, string) (net.Conn, error) {
  30. return lis.Dial()
  31. }
  32. testCases := []*RoutingContext{
  33. {InboundTag: "in", OutboundTag: "out"},
  34. {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
  35. {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
  36. {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
  37. {Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
  38. {Protocol: "bittorrent", OutboundTag: "blocked"},
  39. {User: "[email protected]", OutboundTag: "out"},
  40. {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
  41. }
  42. errCh := make(chan error)
  43. // Server goroutine
  44. go func() {
  45. server := grpc.NewServer()
  46. RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
  47. errCh <- server.Serve(lis)
  48. }()
  49. // Publisher goroutine
  50. go func() {
  51. publishTestCases := func() error {
  52. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  53. defer cancel()
  54. for { // Wait until there's one subscriber in routing stats channel
  55. if len(c.Subscribers()) > 0 {
  56. break
  57. }
  58. if ctx.Err() != nil {
  59. return ctx.Err()
  60. }
  61. }
  62. for _, tc := range testCases {
  63. c.Publish(context.Background(), AsRoutingRoute(tc))
  64. time.Sleep(time.Millisecond)
  65. }
  66. return nil
  67. }
  68. if err := publishTestCases(); err != nil {
  69. errCh <- err
  70. }
  71. }()
  72. // Client goroutine
  73. go func() {
  74. defer lis.Close()
  75. conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
  76. if err != nil {
  77. errCh <- err
  78. return
  79. }
  80. defer conn.Close()
  81. client := NewRoutingServiceClient(conn)
  82. // Test retrieving all fields
  83. testRetrievingAllFields := func() error {
  84. streamCtx, streamClose := context.WithCancel(context.Background())
  85. // Test the unsubscription of stream works well
  86. defer func() {
  87. streamClose()
  88. timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second)
  89. defer timeout()
  90. for { // Wait until there's no subscriber in routing stats channel
  91. if len(c.Subscribers()) == 0 {
  92. break
  93. }
  94. if timeOutCtx.Err() != nil {
  95. t.Error("unexpected subscribers not decreased in channel", timeOutCtx.Err())
  96. }
  97. }
  98. }()
  99. stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{})
  100. if err != nil {
  101. return err
  102. }
  103. for _, tc := range testCases {
  104. msg, err := stream.Recv()
  105. if err != nil {
  106. return err
  107. }
  108. if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  109. t.Error(r)
  110. }
  111. }
  112. // Test that double subscription will fail
  113. errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{
  114. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  115. })
  116. if err != nil {
  117. return err
  118. }
  119. if _, err := errStream.Recv(); err == nil {
  120. t.Error("unexpected successful subscription")
  121. }
  122. return nil
  123. }
  124. if err := testRetrievingAllFields(); err != nil {
  125. errCh <- err
  126. }
  127. errCh <- nil // Client passed all tests successfully
  128. }()
  129. // Wait for goroutines to complete
  130. select {
  131. case <-time.After(2 * time.Second):
  132. t.Fatal("Test timeout after 2s")
  133. case err := <-errCh:
  134. if err != nil {
  135. t.Fatal(err)
  136. }
  137. }
  138. }
  139. func TestServiceSubscribeSubsetOfFields(t *testing.T) {
  140. c := stats.NewChannel(&stats.ChannelConfig{
  141. SubscriberLimit: 1,
  142. BufferSize: 0,
  143. Blocking: true,
  144. })
  145. common.Must(c.Start())
  146. defer c.Close()
  147. lis := bufconn.Listen(1024 * 1024)
  148. bufDialer := func(context.Context, string) (net.Conn, error) {
  149. return lis.Dial()
  150. }
  151. testCases := []*RoutingContext{
  152. {InboundTag: "in", OutboundTag: "out"},
  153. {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
  154. {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
  155. {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
  156. {Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
  157. {Protocol: "bittorrent", OutboundTag: "blocked"},
  158. {User: "[email protected]", OutboundTag: "out"},
  159. {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
  160. }
  161. errCh := make(chan error)
  162. // Server goroutine
  163. go func() {
  164. server := grpc.NewServer()
  165. RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
  166. errCh <- server.Serve(lis)
  167. }()
  168. // Publisher goroutine
  169. go func() {
  170. publishTestCases := func() error {
  171. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  172. defer cancel()
  173. for { // Wait until there's one subscriber in routing stats channel
  174. if len(c.Subscribers()) > 0 {
  175. break
  176. }
  177. if ctx.Err() != nil {
  178. return ctx.Err()
  179. }
  180. }
  181. for _, tc := range testCases {
  182. c.Publish(context.Background(), AsRoutingRoute(tc))
  183. time.Sleep(time.Millisecond)
  184. }
  185. return nil
  186. }
  187. if err := publishTestCases(); err != nil {
  188. errCh <- err
  189. }
  190. }()
  191. // Client goroutine
  192. go func() {
  193. defer lis.Close()
  194. conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
  195. if err != nil {
  196. errCh <- err
  197. return
  198. }
  199. defer conn.Close()
  200. client := NewRoutingServiceClient(conn)
  201. // Test retrieving only a subset of fields
  202. testRetrievingSubsetOfFields := func() error {
  203. streamCtx, streamClose := context.WithCancel(context.Background())
  204. defer streamClose()
  205. stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{
  206. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  207. })
  208. if err != nil {
  209. return err
  210. }
  211. for _, tc := range testCases {
  212. msg, err := stream.Recv()
  213. if err != nil {
  214. return err
  215. }
  216. stat := &RoutingContext{ // Only a subset of stats is retrieved
  217. SourceIPs: tc.SourceIPs,
  218. TargetIPs: tc.TargetIPs,
  219. SourcePort: tc.SourcePort,
  220. TargetPort: tc.TargetPort,
  221. TargetDomain: tc.TargetDomain,
  222. OutboundGroupTags: tc.OutboundGroupTags,
  223. OutboundTag: tc.OutboundTag,
  224. }
  225. if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  226. t.Error(r)
  227. }
  228. }
  229. return nil
  230. }
  231. if err := testRetrievingSubsetOfFields(); err != nil {
  232. errCh <- err
  233. }
  234. errCh <- nil // Client passed all tests successfully
  235. }()
  236. // Wait for goroutines to complete
  237. select {
  238. case <-time.After(2 * time.Second):
  239. t.Fatal("Test timeout after 2s")
  240. case err := <-errCh:
  241. if err != nil {
  242. t.Fatal(err)
  243. }
  244. }
  245. }
  246. func TestServiceTestRoute(t *testing.T) {
  247. c := stats.NewChannel(&stats.ChannelConfig{
  248. SubscriberLimit: 1,
  249. BufferSize: 16,
  250. Blocking: true,
  251. })
  252. common.Must(c.Start())
  253. defer c.Close()
  254. r := new(router.Router)
  255. mockCtl := gomock.NewController(t)
  256. defer mockCtl.Finish()
  257. common.Must(r.Init(context.TODO(), &router.Config{
  258. Rule: []*router.RoutingRule{
  259. {
  260. InboundTag: []string{"in"},
  261. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  262. },
  263. {
  264. Protocol: []string{"bittorrent"},
  265. TargetTag: &router.RoutingRule_Tag{Tag: "blocked"},
  266. },
  267. {
  268. PortList: &net.PortList{Range: []*net.PortRange{{From: 8080, To: 8080}}},
  269. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  270. },
  271. {
  272. SourcePortList: &net.PortList{Range: []*net.PortRange{{From: 9999, To: 9999}}},
  273. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  274. },
  275. {
  276. Domain: []*router.Domain{{Type: router.Domain_Domain, Value: "com"}},
  277. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  278. },
  279. {
  280. SourceGeoip: []*router.GeoIP{{CountryCode: "private", Cidr: []*router.CIDR{{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}},
  281. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  282. },
  283. {
  284. UserEmail: []string{"[email protected]"},
  285. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  286. },
  287. {
  288. Networks: []net.Network{net.Network_UDP, net.Network_TCP},
  289. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  290. },
  291. },
  292. }, mocks.NewDNSClient(mockCtl), mocks.NewOutboundManager(mockCtl), nil))
  293. lis := bufconn.Listen(1024 * 1024)
  294. bufDialer := func(context.Context, string) (net.Conn, error) {
  295. return lis.Dial()
  296. }
  297. errCh := make(chan error)
  298. // Server goroutine
  299. go func() {
  300. server := grpc.NewServer()
  301. RegisterRoutingServiceServer(server, NewRoutingServer(r, c))
  302. errCh <- server.Serve(lis)
  303. }()
  304. // Client goroutine
  305. go func() {
  306. defer lis.Close()
  307. conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithTransportCredentials(insecure.NewCredentials()))
  308. if err != nil {
  309. errCh <- err
  310. }
  311. defer conn.Close()
  312. client := NewRoutingServiceClient(conn)
  313. testCases := []*RoutingContext{
  314. {InboundTag: "in", OutboundTag: "out"},
  315. {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
  316. {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
  317. {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
  318. {Network: net.Network_UDP, Protocol: "bittorrent", OutboundTag: "blocked"},
  319. {User: "[email protected]", OutboundTag: "out"},
  320. {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
  321. }
  322. // Test simple TestRoute
  323. testSimple := func() error {
  324. for _, tc := range testCases {
  325. route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc})
  326. if err != nil {
  327. return err
  328. }
  329. if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  330. t.Error(r)
  331. }
  332. }
  333. return nil
  334. }
  335. // Test TestRoute with special options
  336. testOptions := func() error {
  337. sub, err := c.Subscribe()
  338. if err != nil {
  339. return err
  340. }
  341. for _, tc := range testCases {
  342. route, err := client.TestRoute(context.Background(), &TestRouteRequest{
  343. RoutingContext: tc,
  344. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  345. PublishResult: true,
  346. })
  347. if err != nil {
  348. return err
  349. }
  350. stat := &RoutingContext{ // Only a subset of stats is retrieved
  351. SourceIPs: tc.SourceIPs,
  352. TargetIPs: tc.TargetIPs,
  353. SourcePort: tc.SourcePort,
  354. TargetPort: tc.TargetPort,
  355. TargetDomain: tc.TargetDomain,
  356. OutboundGroupTags: tc.OutboundGroupTags,
  357. OutboundTag: tc.OutboundTag,
  358. }
  359. if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  360. t.Error(r)
  361. }
  362. select { // Check that routing result has been published to statistics channel
  363. case msg, received := <-sub:
  364. if route, ok := msg.(routing.Route); received && ok {
  365. if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  366. t.Error(r)
  367. }
  368. } else {
  369. t.Error("unexpected failure in receiving published routing result for testcase", tc)
  370. }
  371. case <-time.After(100 * time.Millisecond):
  372. t.Error("unexpected failure in receiving published routing result", tc)
  373. }
  374. }
  375. return nil
  376. }
  377. if err := testSimple(); err != nil {
  378. errCh <- err
  379. }
  380. if err := testOptions(); err != nil {
  381. errCh <- err
  382. }
  383. errCh <- nil // Client passed all tests successfully
  384. }()
  385. // Wait for goroutines to complete
  386. select {
  387. case <-time.After(2 * time.Second):
  388. t.Fatal("Test timeout after 2s")
  389. case err := <-errCh:
  390. if err != nil {
  391. t.Fatal(err)
  392. }
  393. }
  394. }