command_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. package command_test
  2. import (
  3. "context"
  4. "testing"
  5. "time"
  6. "github.com/golang/mock/gomock"
  7. "github.com/google/go-cmp/cmp"
  8. "github.com/google/go-cmp/cmp/cmpopts"
  9. "github.com/xtls/xray-core/app/router"
  10. . "github.com/xtls/xray-core/app/router/command"
  11. "github.com/xtls/xray-core/app/stats"
  12. "github.com/xtls/xray-core/common"
  13. "github.com/xtls/xray-core/common/net"
  14. "github.com/xtls/xray-core/features/routing"
  15. "github.com/xtls/xray-core/testing/mocks"
  16. "google.golang.org/grpc"
  17. "google.golang.org/grpc/test/bufconn"
  18. )
  19. func TestServiceSubscribeRoutingStats(t *testing.T) {
  20. c := stats.NewChannel(&stats.ChannelConfig{
  21. SubscriberLimit: 1,
  22. BufferSize: 0,
  23. Blocking: true,
  24. })
  25. common.Must(c.Start())
  26. defer c.Close()
  27. lis := bufconn.Listen(1024 * 1024)
  28. bufDialer := func(context.Context, string) (net.Conn, error) {
  29. return lis.Dial()
  30. }
  31. testCases := []*RoutingContext{
  32. {InboundTag: "in", OutboundTag: "out"},
  33. {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
  34. {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
  35. {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
  36. {Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
  37. {Protocol: "bittorrent", OutboundTag: "blocked"},
  38. {User: "[email protected]", OutboundTag: "out"},
  39. {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
  40. }
  41. errCh := make(chan error)
  42. // Server goroutine
  43. go func() {
  44. server := grpc.NewServer()
  45. RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
  46. errCh <- server.Serve(lis)
  47. }()
  48. // Publisher goroutine
  49. go func() {
  50. publishTestCases := func() error {
  51. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  52. defer cancel()
  53. for { // Wait until there's one subscriber in routing stats channel
  54. if len(c.Subscribers()) > 0 {
  55. break
  56. }
  57. if ctx.Err() != nil {
  58. return ctx.Err()
  59. }
  60. }
  61. for _, tc := range testCases {
  62. c.Publish(context.Background(), AsRoutingRoute(tc))
  63. time.Sleep(time.Millisecond)
  64. }
  65. return nil
  66. }
  67. if err := publishTestCases(); err != nil {
  68. errCh <- err
  69. }
  70. }()
  71. // Client goroutine
  72. go func() {
  73. defer lis.Close()
  74. conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
  75. if err != nil {
  76. errCh <- err
  77. return
  78. }
  79. defer conn.Close()
  80. client := NewRoutingServiceClient(conn)
  81. // Test retrieving all fields
  82. testRetrievingAllFields := func() error {
  83. streamCtx, streamClose := context.WithCancel(context.Background())
  84. // Test the unsubscription of stream works well
  85. defer func() {
  86. streamClose()
  87. timeOutCtx, timeout := context.WithTimeout(context.Background(), time.Second)
  88. defer timeout()
  89. for { // Wait until there's no subscriber in routing stats channel
  90. if len(c.Subscribers()) == 0 {
  91. break
  92. }
  93. if timeOutCtx.Err() != nil {
  94. t.Error("unexpected subscribers not decreased in channel", timeOutCtx.Err())
  95. }
  96. }
  97. }()
  98. stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{})
  99. if err != nil {
  100. return err
  101. }
  102. for _, tc := range testCases {
  103. msg, err := stream.Recv()
  104. if err != nil {
  105. return err
  106. }
  107. if r := cmp.Diff(msg, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  108. t.Error(r)
  109. }
  110. }
  111. // Test that double subscription will fail
  112. errStream, err := client.SubscribeRoutingStats(context.Background(), &SubscribeRoutingStatsRequest{
  113. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  114. })
  115. if err != nil {
  116. return err
  117. }
  118. if _, err := errStream.Recv(); err == nil {
  119. t.Error("unexpected successful subscription")
  120. }
  121. return nil
  122. }
  123. if err := testRetrievingAllFields(); err != nil {
  124. errCh <- err
  125. }
  126. errCh <- nil // Client passed all tests successfully
  127. }()
  128. // Wait for goroutines to complete
  129. select {
  130. case <-time.After(2 * time.Second):
  131. t.Fatal("Test timeout after 2s")
  132. case err := <-errCh:
  133. if err != nil {
  134. t.Fatal(err)
  135. }
  136. }
  137. }
  138. func TestServiceSubscribeSubsetOfFields(t *testing.T) {
  139. c := stats.NewChannel(&stats.ChannelConfig{
  140. SubscriberLimit: 1,
  141. BufferSize: 0,
  142. Blocking: true,
  143. })
  144. common.Must(c.Start())
  145. defer c.Close()
  146. lis := bufconn.Listen(1024 * 1024)
  147. bufDialer := func(context.Context, string) (net.Conn, error) {
  148. return lis.Dial()
  149. }
  150. testCases := []*RoutingContext{
  151. {InboundTag: "in", OutboundTag: "out"},
  152. {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
  153. {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
  154. {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
  155. {Network: net.Network_UDP, OutboundGroupTags: []string{"outergroup", "innergroup"}, OutboundTag: "out"},
  156. {Protocol: "bittorrent", OutboundTag: "blocked"},
  157. {User: "[email protected]", OutboundTag: "out"},
  158. {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
  159. }
  160. errCh := make(chan error)
  161. // Server goroutine
  162. go func() {
  163. server := grpc.NewServer()
  164. RegisterRoutingServiceServer(server, NewRoutingServer(nil, c))
  165. errCh <- server.Serve(lis)
  166. }()
  167. // Publisher goroutine
  168. go func() {
  169. publishTestCases := func() error {
  170. ctx, cancel := context.WithTimeout(context.Background(), time.Second)
  171. defer cancel()
  172. for { // Wait until there's one subscriber in routing stats channel
  173. if len(c.Subscribers()) > 0 {
  174. break
  175. }
  176. if ctx.Err() != nil {
  177. return ctx.Err()
  178. }
  179. }
  180. for _, tc := range testCases {
  181. c.Publish(context.Background(), AsRoutingRoute(tc))
  182. time.Sleep(time.Millisecond)
  183. }
  184. return nil
  185. }
  186. if err := publishTestCases(); err != nil {
  187. errCh <- err
  188. }
  189. }()
  190. // Client goroutine
  191. go func() {
  192. defer lis.Close()
  193. conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
  194. if err != nil {
  195. errCh <- err
  196. return
  197. }
  198. defer conn.Close()
  199. client := NewRoutingServiceClient(conn)
  200. // Test retrieving only a subset of fields
  201. testRetrievingSubsetOfFields := func() error {
  202. streamCtx, streamClose := context.WithCancel(context.Background())
  203. defer streamClose()
  204. stream, err := client.SubscribeRoutingStats(streamCtx, &SubscribeRoutingStatsRequest{
  205. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  206. })
  207. if err != nil {
  208. return err
  209. }
  210. for _, tc := range testCases {
  211. msg, err := stream.Recv()
  212. if err != nil {
  213. return err
  214. }
  215. stat := &RoutingContext{ // Only a subset of stats is retrieved
  216. SourceIPs: tc.SourceIPs,
  217. TargetIPs: tc.TargetIPs,
  218. SourcePort: tc.SourcePort,
  219. TargetPort: tc.TargetPort,
  220. TargetDomain: tc.TargetDomain,
  221. OutboundGroupTags: tc.OutboundGroupTags,
  222. OutboundTag: tc.OutboundTag,
  223. }
  224. if r := cmp.Diff(msg, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  225. t.Error(r)
  226. }
  227. }
  228. return nil
  229. }
  230. if err := testRetrievingSubsetOfFields(); err != nil {
  231. errCh <- err
  232. }
  233. errCh <- nil // Client passed all tests successfully
  234. }()
  235. // Wait for goroutines to complete
  236. select {
  237. case <-time.After(2 * time.Second):
  238. t.Fatal("Test timeout after 2s")
  239. case err := <-errCh:
  240. if err != nil {
  241. t.Fatal(err)
  242. }
  243. }
  244. }
  245. func TestSerivceTestRoute(t *testing.T) {
  246. c := stats.NewChannel(&stats.ChannelConfig{
  247. SubscriberLimit: 1,
  248. BufferSize: 16,
  249. Blocking: true,
  250. })
  251. common.Must(c.Start())
  252. defer c.Close()
  253. r := new(router.Router)
  254. mockCtl := gomock.NewController(t)
  255. defer mockCtl.Finish()
  256. common.Must(r.Init(context.TODO(), &router.Config{
  257. Rule: []*router.RoutingRule{
  258. {
  259. InboundTag: []string{"in"},
  260. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  261. },
  262. {
  263. Protocol: []string{"bittorrent"},
  264. TargetTag: &router.RoutingRule_Tag{Tag: "blocked"},
  265. },
  266. {
  267. PortList: &net.PortList{Range: []*net.PortRange{{From: 8080, To: 8080}}},
  268. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  269. },
  270. {
  271. SourcePortList: &net.PortList{Range: []*net.PortRange{{From: 9999, To: 9999}}},
  272. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  273. },
  274. {
  275. Domain: []*router.Domain{{Type: router.Domain_Domain, Value: "com"}},
  276. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  277. },
  278. {
  279. SourceGeoip: []*router.GeoIP{{CountryCode: "private", Cidr: []*router.CIDR{{Ip: []byte{127, 0, 0, 0}, Prefix: 8}}}},
  280. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  281. },
  282. {
  283. UserEmail: []string{"[email protected]"},
  284. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  285. },
  286. {
  287. Networks: []net.Network{net.Network_UDP, net.Network_TCP},
  288. TargetTag: &router.RoutingRule_Tag{Tag: "out"},
  289. },
  290. },
  291. }, mocks.NewDNSClient(mockCtl), mocks.NewOutboundManager(mockCtl)))
  292. lis := bufconn.Listen(1024 * 1024)
  293. bufDialer := func(context.Context, string) (net.Conn, error) {
  294. return lis.Dial()
  295. }
  296. errCh := make(chan error)
  297. // Server goroutine
  298. go func() {
  299. server := grpc.NewServer()
  300. RegisterRoutingServiceServer(server, NewRoutingServer(r, c))
  301. errCh <- server.Serve(lis)
  302. }()
  303. // Client goroutine
  304. go func() {
  305. defer lis.Close()
  306. conn, err := grpc.DialContext(context.Background(), "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
  307. if err != nil {
  308. errCh <- err
  309. }
  310. defer conn.Close()
  311. client := NewRoutingServiceClient(conn)
  312. testCases := []*RoutingContext{
  313. {InboundTag: "in", OutboundTag: "out"},
  314. {TargetIPs: [][]byte{{1, 2, 3, 4}}, TargetPort: 8080, OutboundTag: "out"},
  315. {TargetDomain: "example.com", TargetPort: 443, OutboundTag: "out"},
  316. {SourcePort: 9999, TargetPort: 9999, OutboundTag: "out"},
  317. {Network: net.Network_UDP, Protocol: "bittorrent", OutboundTag: "blocked"},
  318. {User: "[email protected]", OutboundTag: "out"},
  319. {SourceIPs: [][]byte{{127, 0, 0, 1}}, Attributes: map[string]string{"attr": "value"}, OutboundTag: "out"},
  320. }
  321. // Test simple TestRoute
  322. testSimple := func() error {
  323. for _, tc := range testCases {
  324. route, err := client.TestRoute(context.Background(), &TestRouteRequest{RoutingContext: tc})
  325. if err != nil {
  326. return err
  327. }
  328. if r := cmp.Diff(route, tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  329. t.Error(r)
  330. }
  331. }
  332. return nil
  333. }
  334. // Test TestRoute with special options
  335. testOptions := func() error {
  336. sub, err := c.Subscribe()
  337. if err != nil {
  338. return err
  339. }
  340. for _, tc := range testCases {
  341. route, err := client.TestRoute(context.Background(), &TestRouteRequest{
  342. RoutingContext: tc,
  343. FieldSelectors: []string{"ip", "port", "domain", "outbound"},
  344. PublishResult: true,
  345. })
  346. if err != nil {
  347. return err
  348. }
  349. stat := &RoutingContext{ // Only a subset of stats is retrieved
  350. SourceIPs: tc.SourceIPs,
  351. TargetIPs: tc.TargetIPs,
  352. SourcePort: tc.SourcePort,
  353. TargetPort: tc.TargetPort,
  354. TargetDomain: tc.TargetDomain,
  355. OutboundGroupTags: tc.OutboundGroupTags,
  356. OutboundTag: tc.OutboundTag,
  357. }
  358. if r := cmp.Diff(route, stat, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  359. t.Error(r)
  360. }
  361. select { // Check that routing result has been published to statistics channel
  362. case msg, received := <-sub:
  363. if route, ok := msg.(routing.Route); received && ok {
  364. if r := cmp.Diff(AsProtobufMessage(nil)(route), tc, cmpopts.IgnoreUnexported(RoutingContext{})); r != "" {
  365. t.Error(r)
  366. }
  367. } else {
  368. t.Error("unexpected failure in receiving published routing result for testcase", tc)
  369. }
  370. case <-time.After(100 * time.Millisecond):
  371. t.Error("unexpected failure in receiving published routing result", tc)
  372. }
  373. }
  374. return nil
  375. }
  376. if err := testSimple(); err != nil {
  377. errCh <- err
  378. }
  379. if err := testOptions(); err != nil {
  380. errCh <- err
  381. }
  382. errCh <- nil // Client passed all tests successfully
  383. }()
  384. // Wait for goroutines to complete
  385. select {
  386. case <-time.After(2 * time.Second):
  387. t.Fatal("Test timeout after 2s")
  388. case err := <-errCh:
  389. if err != nil {
  390. t.Fatal(err)
  391. }
  392. }
  393. }