ws_test.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. package websocket_test
  2. import (
  3. "context"
  4. "runtime"
  5. "testing"
  6. "time"
  7. "github.com/xtls/xray-core/common"
  8. "github.com/xtls/xray-core/common/net"
  9. "github.com/xtls/xray-core/common/protocol/tls/cert"
  10. "github.com/xtls/xray-core/testing/servers/tcp"
  11. "github.com/xtls/xray-core/transport/internet"
  12. "github.com/xtls/xray-core/transport/internet/stat"
  13. "github.com/xtls/xray-core/transport/internet/tls"
  14. . "github.com/xtls/xray-core/transport/internet/websocket"
  15. )
  16. func Test_listenWSAndDial(t *testing.T) {
  17. listenPort := tcp.PickPort()
  18. listen, err := ListenWS(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
  19. ProtocolName: "websocket",
  20. ProtocolSettings: &Config{
  21. Path: "ws",
  22. },
  23. }, func(conn stat.Connection) {
  24. go func(c stat.Connection) {
  25. defer c.Close()
  26. var b [1024]byte
  27. c.SetReadDeadline(time.Now().Add(2 * time.Second))
  28. _, err := c.Read(b[:])
  29. if err != nil {
  30. return
  31. }
  32. common.Must2(c.Write([]byte("Response")))
  33. }(conn)
  34. })
  35. common.Must(err)
  36. ctx := context.Background()
  37. streamSettings := &internet.MemoryStreamConfig{
  38. ProtocolName: "websocket",
  39. ProtocolSettings: &Config{Path: "ws"},
  40. }
  41. conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  42. common.Must(err)
  43. _, err = conn.Write([]byte("Test connection 1"))
  44. common.Must(err)
  45. var b [1024]byte
  46. n, err := conn.Read(b[:])
  47. common.Must(err)
  48. if string(b[:n]) != "Response" {
  49. t.Error("response: ", string(b[:n]))
  50. }
  51. common.Must(conn.Close())
  52. conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  53. common.Must(err)
  54. _, err = conn.Write([]byte("Test connection 2"))
  55. common.Must(err)
  56. n, err = conn.Read(b[:])
  57. common.Must(err)
  58. if string(b[:n]) != "Response" {
  59. t.Error("response: ", string(b[:n]))
  60. }
  61. common.Must(conn.Close())
  62. common.Must(listen.Close())
  63. }
  64. func TestDialWithRemoteAddr(t *testing.T) {
  65. listenPort := tcp.PickPort()
  66. listen, err := ListenWS(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
  67. ProtocolName: "websocket",
  68. ProtocolSettings: &Config{
  69. Path: "ws",
  70. },
  71. }, func(conn stat.Connection) {
  72. go func(c stat.Connection) {
  73. defer c.Close()
  74. var b [1024]byte
  75. _, err := c.Read(b[:])
  76. // common.Must(err)
  77. if err != nil {
  78. return
  79. }
  80. _, err = c.Write([]byte(c.RemoteAddr().String()))
  81. common.Must(err)
  82. }(conn)
  83. })
  84. common.Must(err)
  85. conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), &internet.MemoryStreamConfig{
  86. ProtocolName: "websocket",
  87. ProtocolSettings: &Config{Path: "ws", Header: map[string]string{"X-Forwarded-For": "1.1.1.1"}},
  88. })
  89. common.Must(err)
  90. _, err = conn.Write([]byte("Test connection 1"))
  91. common.Must(err)
  92. var b [1024]byte
  93. n, err := conn.Read(b[:])
  94. common.Must(err)
  95. if string(b[:n]) != "1.1.1.1:0" {
  96. t.Error("response: ", string(b[:n]))
  97. }
  98. common.Must(listen.Close())
  99. }
  100. func Test_listenWSAndDial_TLS(t *testing.T) {
  101. listenPort := tcp.PickPort()
  102. if runtime.GOARCH == "arm64" {
  103. return
  104. }
  105. start := time.Now()
  106. streamSettings := &internet.MemoryStreamConfig{
  107. ProtocolName: "websocket",
  108. ProtocolSettings: &Config{
  109. Path: "wss",
  110. },
  111. SecurityType: "tls",
  112. SecuritySettings: &tls.Config{
  113. AllowInsecure: true,
  114. Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
  115. },
  116. }
  117. listen, err := ListenWS(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
  118. go func() {
  119. _ = conn.Close()
  120. }()
  121. })
  122. common.Must(err)
  123. defer listen.Close()
  124. conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  125. common.Must(err)
  126. _ = conn.Close()
  127. end := time.Now()
  128. if !end.Before(start.Add(time.Second * 5)) {
  129. t.Error("end: ", end, " start: ", start)
  130. }
  131. }