ws_test.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. package websocket_test
  2. import (
  3. "context"
  4. "runtime"
  5. "testing"
  6. "time"
  7. "github.com/xtls/xray-core/common"
  8. "github.com/xtls/xray-core/common/net"
  9. "github.com/xtls/xray-core/common/protocol/tls/cert"
  10. "github.com/xtls/xray-core/testing/servers/tcp"
  11. "github.com/xtls/xray-core/transport/internet"
  12. "github.com/xtls/xray-core/transport/internet/stat"
  13. "github.com/xtls/xray-core/transport/internet/tls"
  14. . "github.com/xtls/xray-core/transport/internet/websocket"
  15. )
  16. func Test_listenWSAndDial(t *testing.T) {
  17. listenPort := tcp.PickPort()
  18. listen, err := ListenWS(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
  19. ProtocolName: "websocket",
  20. ProtocolSettings: &Config{
  21. Path: "ws",
  22. },
  23. }, func(conn stat.Connection) {
  24. go func(c stat.Connection) {
  25. defer c.Close()
  26. var b [1024]byte
  27. c.SetReadDeadline(time.Now().Add(2 * time.Second))
  28. _, err := c.Read(b[:])
  29. if err != nil {
  30. return
  31. }
  32. common.Must2(c.Write([]byte("Response")))
  33. }(conn)
  34. })
  35. common.Must(err)
  36. ctx := context.Background()
  37. streamSettings := &internet.MemoryStreamConfig{
  38. ProtocolName: "websocket",
  39. ProtocolSettings: &Config{Path: "ws"},
  40. }
  41. conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  42. common.Must(err)
  43. _, err = conn.Write([]byte("Test connection 1"))
  44. common.Must(err)
  45. var b [1024]byte
  46. n, err := conn.Read(b[:])
  47. common.Must(err)
  48. if string(b[:n]) != "Response" {
  49. t.Error("response: ", string(b[:n]))
  50. }
  51. common.Must(conn.Close())
  52. <-time.After(time.Second * 5)
  53. conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  54. common.Must(err)
  55. _, err = conn.Write([]byte("Test connection 2"))
  56. common.Must(err)
  57. n, err = conn.Read(b[:])
  58. common.Must(err)
  59. if string(b[:n]) != "Response" {
  60. t.Error("response: ", string(b[:n]))
  61. }
  62. common.Must(conn.Close())
  63. common.Must(listen.Close())
  64. }
  65. func TestDialWithRemoteAddr(t *testing.T) {
  66. listenPort := tcp.PickPort()
  67. listen, err := ListenWS(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
  68. ProtocolName: "websocket",
  69. ProtocolSettings: &Config{
  70. Path: "ws",
  71. },
  72. }, func(conn stat.Connection) {
  73. go func(c stat.Connection) {
  74. defer c.Close()
  75. var b [1024]byte
  76. _, err := c.Read(b[:])
  77. // common.Must(err)
  78. if err != nil {
  79. return
  80. }
  81. _, err = c.Write([]byte(c.RemoteAddr().String()))
  82. common.Must(err)
  83. }(conn)
  84. })
  85. common.Must(err)
  86. conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), &internet.MemoryStreamConfig{
  87. ProtocolName: "websocket",
  88. ProtocolSettings: &Config{Path: "ws", Header: map[string]string{"X-Forwarded-For": "1.1.1.1"}},
  89. })
  90. common.Must(err)
  91. _, err = conn.Write([]byte("Test connection 1"))
  92. common.Must(err)
  93. var b [1024]byte
  94. n, err := conn.Read(b[:])
  95. common.Must(err)
  96. if string(b[:n]) != "1.1.1.1:0" {
  97. t.Error("response: ", string(b[:n]))
  98. }
  99. common.Must(listen.Close())
  100. }
  101. func Test_listenWSAndDial_TLS(t *testing.T) {
  102. listenPort := tcp.PickPort()
  103. if runtime.GOARCH == "arm64" {
  104. return
  105. }
  106. start := time.Now()
  107. streamSettings := &internet.MemoryStreamConfig{
  108. ProtocolName: "websocket",
  109. ProtocolSettings: &Config{
  110. Path: "wss",
  111. },
  112. SecurityType: "tls",
  113. SecuritySettings: &tls.Config{
  114. AllowInsecure: true,
  115. Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
  116. },
  117. }
  118. listen, err := ListenWS(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
  119. go func() {
  120. _ = conn.Close()
  121. }()
  122. })
  123. common.Must(err)
  124. defer listen.Close()
  125. conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  126. common.Must(err)
  127. _ = conn.Close()
  128. end := time.Now()
  129. if !end.Before(start.Add(time.Second * 5)) {
  130. t.Error("end: ", end, " start: ", start)
  131. }
  132. }