splithttp_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. package splithttp_test
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/rand"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "runtime"
  10. "testing"
  11. "time"
  12. "github.com/google/go-cmp/cmp"
  13. "github.com/xtls/xray-core/common"
  14. "github.com/xtls/xray-core/common/buf"
  15. "github.com/xtls/xray-core/common/net"
  16. "github.com/xtls/xray-core/common/protocol/tls/cert"
  17. "github.com/xtls/xray-core/testing/servers/tcp"
  18. "github.com/xtls/xray-core/testing/servers/udp"
  19. "github.com/xtls/xray-core/transport/internet"
  20. . "github.com/xtls/xray-core/transport/internet/splithttp"
  21. "github.com/xtls/xray-core/transport/internet/stat"
  22. "github.com/xtls/xray-core/transport/internet/tls"
  23. )
  24. func Test_ListenXHAndDial(t *testing.T) {
  25. listenPort := tcp.PickPort()
  26. listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
  27. ProtocolName: "splithttp",
  28. ProtocolSettings: &Config{
  29. Path: "/sh",
  30. },
  31. }, func(conn stat.Connection) {
  32. go func(c stat.Connection) {
  33. defer c.Close()
  34. var b [1024]byte
  35. c.SetReadDeadline(time.Now().Add(2 * time.Second))
  36. _, err := c.Read(b[:])
  37. if err != nil {
  38. return
  39. }
  40. common.Must2(c.Write([]byte("Response")))
  41. }(conn)
  42. })
  43. common.Must(err)
  44. ctx := context.Background()
  45. streamSettings := &internet.MemoryStreamConfig{
  46. ProtocolName: "splithttp",
  47. ProtocolSettings: &Config{Path: "sh"},
  48. }
  49. conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  50. common.Must(err)
  51. _, err = conn.Write([]byte("Test connection 1"))
  52. common.Must(err)
  53. var b [1024]byte
  54. fmt.Println("test2")
  55. n, _ := io.ReadFull(conn, b[:])
  56. fmt.Println("string is", n)
  57. if string(b[:n]) != "Response" {
  58. t.Error("response: ", string(b[:n]))
  59. }
  60. common.Must(conn.Close())
  61. conn, err = Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  62. common.Must(err)
  63. _, err = conn.Write([]byte("Test connection 2"))
  64. common.Must(err)
  65. n, _ = io.ReadFull(conn, b[:])
  66. common.Must(err)
  67. if string(b[:n]) != "Response" {
  68. t.Error("response: ", string(b[:n]))
  69. }
  70. common.Must(conn.Close())
  71. common.Must(listen.Close())
  72. }
  73. func TestDialWithRemoteAddr(t *testing.T) {
  74. listenPort := tcp.PickPort()
  75. listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
  76. ProtocolName: "splithttp",
  77. ProtocolSettings: &Config{
  78. Path: "sh",
  79. },
  80. }, func(conn stat.Connection) {
  81. go func(c stat.Connection) {
  82. defer c.Close()
  83. var b [1024]byte
  84. _, err := c.Read(b[:])
  85. // common.Must(err)
  86. if err != nil {
  87. return
  88. }
  89. _, err = c.Write([]byte(c.RemoteAddr().String()))
  90. common.Must(err)
  91. }(conn)
  92. })
  93. common.Must(err)
  94. conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), &internet.MemoryStreamConfig{
  95. ProtocolName: "splithttp",
  96. ProtocolSettings: &Config{Path: "sh", Headers: map[string]string{"X-Forwarded-For": "1.1.1.1"}},
  97. })
  98. common.Must(err)
  99. _, err = conn.Write([]byte("Test connection 1"))
  100. common.Must(err)
  101. var b [1024]byte
  102. n, _ := io.ReadFull(conn, b[:])
  103. if string(b[:n]) != "1.1.1.1:0" {
  104. t.Error("response: ", string(b[:n]))
  105. }
  106. common.Must(listen.Close())
  107. }
  108. func Test_ListenXHAndDial_TLS(t *testing.T) {
  109. if runtime.GOARCH == "arm64" {
  110. return
  111. }
  112. listenPort := tcp.PickPort()
  113. start := time.Now()
  114. streamSettings := &internet.MemoryStreamConfig{
  115. ProtocolName: "splithttp",
  116. ProtocolSettings: &Config{
  117. Path: "shs",
  118. },
  119. SecurityType: "tls",
  120. SecuritySettings: &tls.Config{
  121. AllowInsecure: true,
  122. Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
  123. },
  124. }
  125. listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
  126. go func() {
  127. defer conn.Close()
  128. var b [1024]byte
  129. conn.SetReadDeadline(time.Now().Add(2 * time.Second))
  130. _, err := conn.Read(b[:])
  131. if err != nil {
  132. return
  133. }
  134. common.Must2(conn.Write([]byte("Response")))
  135. }()
  136. })
  137. common.Must(err)
  138. defer listen.Close()
  139. conn, err := Dial(context.Background(), net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  140. common.Must(err)
  141. _, err = conn.Write([]byte("Test connection 1"))
  142. common.Must(err)
  143. var b [1024]byte
  144. n, _ := io.ReadFull(conn, b[:])
  145. if string(b[:n]) != "Response" {
  146. t.Error("response: ", string(b[:n]))
  147. }
  148. end := time.Now()
  149. if !end.Before(start.Add(time.Second * 5)) {
  150. t.Error("end: ", end, " start: ", start)
  151. }
  152. }
  153. func Test_ListenXHAndDial_H2C(t *testing.T) {
  154. if runtime.GOARCH == "arm64" {
  155. return
  156. }
  157. listenPort := tcp.PickPort()
  158. streamSettings := &internet.MemoryStreamConfig{
  159. ProtocolName: "splithttp",
  160. ProtocolSettings: &Config{
  161. Path: "shs",
  162. },
  163. }
  164. listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
  165. go func() {
  166. _ = conn.Close()
  167. }()
  168. })
  169. common.Must(err)
  170. defer listen.Close()
  171. protocols := new(http.Protocols)
  172. protocols.SetUnencryptedHTTP2(true)
  173. client := http.Client{
  174. Transport: &http.Transport{
  175. Protocols: protocols,
  176. },
  177. }
  178. resp, err := client.Get("http://" + net.LocalHostIP.String() + ":" + listenPort.String())
  179. common.Must(err)
  180. if resp.StatusCode != 404 {
  181. t.Error("Expected 404 but got:", resp.StatusCode)
  182. }
  183. if resp.ProtoMajor != 2 {
  184. t.Error("Expected h2 but got:", resp.ProtoMajor)
  185. }
  186. }
  187. func Test_ListenXHAndDial_QUIC(t *testing.T) {
  188. if runtime.GOARCH == "arm64" {
  189. return
  190. }
  191. listenPort := udp.PickPort()
  192. start := time.Now()
  193. streamSettings := &internet.MemoryStreamConfig{
  194. ProtocolName: "splithttp",
  195. ProtocolSettings: &Config{
  196. Path: "shs",
  197. },
  198. SecurityType: "tls",
  199. SecuritySettings: &tls.Config{
  200. AllowInsecure: true,
  201. Certificate: []*tls.Certificate{tls.ParseCertificate(cert.MustGenerate(nil, cert.CommonName("localhost")))},
  202. NextProtocol: []string{"h3"},
  203. },
  204. }
  205. serverClosed := false
  206. listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
  207. go func() {
  208. defer conn.Close()
  209. b := buf.New()
  210. defer b.Release()
  211. for {
  212. b.Clear()
  213. if _, err := b.ReadFrom(conn); err != nil {
  214. break
  215. }
  216. common.Must2(conn.Write(b.Bytes()))
  217. }
  218. serverClosed = true
  219. }()
  220. })
  221. common.Must(err)
  222. defer listen.Close()
  223. time.Sleep(time.Second)
  224. conn, err := Dial(context.Background(), net.UDPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  225. common.Must(err)
  226. const N = 1024
  227. b1 := make([]byte, N)
  228. common.Must2(rand.Read(b1))
  229. b2 := buf.New()
  230. common.Must2(conn.Write(b1))
  231. b2.Clear()
  232. common.Must2(b2.ReadFullFrom(conn, N))
  233. if r := cmp.Diff(b2.Bytes(), b1); r != "" {
  234. t.Error(r)
  235. }
  236. common.Must2(conn.Write(b1))
  237. b2.Clear()
  238. common.Must2(b2.ReadFullFrom(conn, N))
  239. if r := cmp.Diff(b2.Bytes(), b1); r != "" {
  240. t.Error(r)
  241. }
  242. conn.Close()
  243. time.Sleep(100 * time.Millisecond)
  244. if !serverClosed {
  245. t.Error("server did not get closed")
  246. }
  247. end := time.Now()
  248. if !end.Before(start.Add(time.Second * 5)) {
  249. t.Error("end: ", end, " start: ", start)
  250. }
  251. }
  252. func Test_ListenXHAndDial_Unix(t *testing.T) {
  253. tempDir := t.TempDir()
  254. tempSocket := tempDir + "/server.sock"
  255. listen, err := ListenXH(context.Background(), net.DomainAddress(tempSocket), 0, &internet.MemoryStreamConfig{
  256. ProtocolName: "splithttp",
  257. ProtocolSettings: &Config{
  258. Path: "/sh",
  259. },
  260. }, func(conn stat.Connection) {
  261. go func(c stat.Connection) {
  262. defer c.Close()
  263. var b [1024]byte
  264. c.SetReadDeadline(time.Now().Add(2 * time.Second))
  265. _, err := c.Read(b[:])
  266. if err != nil {
  267. return
  268. }
  269. common.Must2(c.Write([]byte("Response")))
  270. }(conn)
  271. })
  272. common.Must(err)
  273. ctx := context.Background()
  274. streamSettings := &internet.MemoryStreamConfig{
  275. ProtocolName: "splithttp",
  276. ProtocolSettings: &Config{
  277. Host: "example.com",
  278. Path: "sh",
  279. },
  280. }
  281. conn, err := Dial(ctx, net.UnixDestination(net.DomainAddress(tempSocket)), streamSettings)
  282. common.Must(err)
  283. _, err = conn.Write([]byte("Test connection 1"))
  284. common.Must(err)
  285. var b [1024]byte
  286. fmt.Println("test2")
  287. n, _ := io.ReadFull(conn, b[:])
  288. fmt.Println("string is", n)
  289. if string(b[:n]) != "Response" {
  290. t.Error("response: ", string(b[:n]))
  291. }
  292. common.Must(conn.Close())
  293. conn, err = Dial(ctx, net.UnixDestination(net.DomainAddress(tempSocket)), streamSettings)
  294. common.Must(err)
  295. _, err = conn.Write([]byte("Test connection 2"))
  296. common.Must(err)
  297. n, _ = io.ReadFull(conn, b[:])
  298. common.Must(err)
  299. if string(b[:n]) != "Response" {
  300. t.Error("response: ", string(b[:n]))
  301. }
  302. common.Must(conn.Close())
  303. common.Must(listen.Close())
  304. }
  305. func Test_queryString(t *testing.T) {
  306. listenPort := tcp.PickPort()
  307. listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, &internet.MemoryStreamConfig{
  308. ProtocolName: "splithttp",
  309. ProtocolSettings: &Config{
  310. // this querystring does not have any effect, but sometimes people blindly copy it from websocket config. make sure the outbound doesn't break
  311. Path: "/sh?ed=2048",
  312. },
  313. }, func(conn stat.Connection) {
  314. go func(c stat.Connection) {
  315. defer c.Close()
  316. var b [1024]byte
  317. c.SetReadDeadline(time.Now().Add(2 * time.Second))
  318. _, err := c.Read(b[:])
  319. if err != nil {
  320. return
  321. }
  322. common.Must2(c.Write([]byte("Response")))
  323. }(conn)
  324. })
  325. common.Must(err)
  326. ctx := context.Background()
  327. streamSettings := &internet.MemoryStreamConfig{
  328. ProtocolName: "splithttp",
  329. ProtocolSettings: &Config{Path: "sh?ed=2048"},
  330. }
  331. conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  332. common.Must(err)
  333. _, err = conn.Write([]byte("Test connection 1"))
  334. common.Must(err)
  335. var b [1024]byte
  336. fmt.Println("test2")
  337. n, _ := io.ReadFull(conn, b[:])
  338. fmt.Println("string is", n)
  339. if string(b[:n]) != "Response" {
  340. t.Error("response: ", string(b[:n]))
  341. }
  342. common.Must(conn.Close())
  343. common.Must(listen.Close())
  344. }
  345. func Test_maxUpload(t *testing.T) {
  346. listenPort := tcp.PickPort()
  347. streamSettings := &internet.MemoryStreamConfig{
  348. ProtocolName: "splithttp",
  349. ProtocolSettings: &Config{
  350. Path: "/sh",
  351. ScMaxEachPostBytes: &RangeConfig{
  352. From: 10000,
  353. To: 10000,
  354. },
  355. },
  356. }
  357. uploadReceived := make([]byte, 10001)
  358. listen, err := ListenXH(context.Background(), net.LocalHostIP, listenPort, streamSettings, func(conn stat.Connection) {
  359. go func(c stat.Connection) {
  360. defer c.Close()
  361. c.SetReadDeadline(time.Now().Add(2 * time.Second))
  362. io.ReadFull(c, uploadReceived)
  363. common.Must2(c.Write([]byte("Response")))
  364. }(conn)
  365. })
  366. common.Must(err)
  367. ctx := context.Background()
  368. conn, err := Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), listenPort), streamSettings)
  369. common.Must(err)
  370. // send a slightly too large upload
  371. upload := make([]byte, 10001)
  372. rand.Read(upload)
  373. _, err = conn.Write(upload)
  374. common.Must(err)
  375. var b [10240]byte
  376. n, _ := io.ReadFull(conn, b[:])
  377. fmt.Println("string is", n)
  378. if string(b[:n]) != "Response" {
  379. t.Error("response: ", string(b[:n]))
  380. }
  381. common.Must(conn.Close())
  382. if !bytes.Equal(upload, uploadReceived) {
  383. t.Error("incorrect upload", upload, uploadReceived)
  384. }
  385. common.Must(listen.Close())
  386. }