hub.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. package splithttp
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "io"
  6. gonet "net"
  7. "net/http"
  8. "strconv"
  9. "strings"
  10. "sync"
  11. "time"
  12. "github.com/quic-go/quic-go"
  13. "github.com/quic-go/quic-go/http3"
  14. "github.com/xtls/xray-core/common"
  15. "github.com/xtls/xray-core/common/errors"
  16. "github.com/xtls/xray-core/common/net"
  17. http_proto "github.com/xtls/xray-core/common/protocol/http"
  18. "github.com/xtls/xray-core/common/signal/done"
  19. "github.com/xtls/xray-core/transport/internet"
  20. "github.com/xtls/xray-core/transport/internet/stat"
  21. v2tls "github.com/xtls/xray-core/transport/internet/tls"
  22. "golang.org/x/net/http2"
  23. "golang.org/x/net/http2/h2c"
  24. )
  25. type requestHandler struct {
  26. config *Config
  27. host string
  28. path string
  29. ln *Listener
  30. sessionMu *sync.Mutex
  31. sessions sync.Map
  32. localAddr gonet.TCPAddr
  33. }
  34. type httpSession struct {
  35. uploadQueue *uploadQueue
  36. // for as long as the GET request is not opened by the client, this will be
  37. // open ("undone"), and the session may be expired within a certain TTL.
  38. // after the client connects, this becomes "done" and the session lives as
  39. // long as the GET request.
  40. isFullyConnected *done.Instance
  41. }
  42. func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessionId string) {
  43. shouldReap := done.New()
  44. go func() {
  45. time.Sleep(30 * time.Second)
  46. shouldReap.Close()
  47. }()
  48. select {
  49. case <-isFullyConnected.Wait():
  50. return
  51. case <-shouldReap.Wait():
  52. h.sessions.Delete(sessionId)
  53. }
  54. }
  55. func (h *requestHandler) upsertSession(sessionId string) *httpSession {
  56. // fast path
  57. currentSessionAny, ok := h.sessions.Load(sessionId)
  58. if ok {
  59. return currentSessionAny.(*httpSession)
  60. }
  61. // slow path
  62. h.sessionMu.Lock()
  63. defer h.sessionMu.Unlock()
  64. currentSessionAny, ok = h.sessions.Load(sessionId)
  65. if ok {
  66. return currentSessionAny.(*httpSession)
  67. }
  68. s := &httpSession{
  69. uploadQueue: NewUploadQueue(int(h.ln.config.GetNormalizedScMaxConcurrentPosts().To)),
  70. isFullyConnected: done.New(),
  71. }
  72. h.sessions.Store(sessionId, s)
  73. go h.maybeReapSession(s.isFullyConnected, sessionId)
  74. return s
  75. }
  76. func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  77. if len(h.host) > 0 && !internet.IsValidHTTPHost(request.Host, h.host) {
  78. errors.LogInfo(context.Background(), "failed to validate host, request:", request.Host, ", config:", h.host)
  79. writer.WriteHeader(http.StatusNotFound)
  80. return
  81. }
  82. if !strings.HasPrefix(request.URL.Path, h.path) {
  83. errors.LogInfo(context.Background(), "failed to validate path, request:", request.URL.Path, ", config:", h.path)
  84. writer.WriteHeader(http.StatusNotFound)
  85. return
  86. }
  87. sessionId := ""
  88. subpath := strings.Split(request.URL.Path[len(h.path):], "/")
  89. if len(subpath) > 0 {
  90. sessionId = subpath[0]
  91. }
  92. if sessionId == "" {
  93. errors.LogInfo(context.Background(), "no sessionid on request:", request.URL.Path)
  94. writer.WriteHeader(http.StatusBadRequest)
  95. return
  96. }
  97. forwardedAddrs := http_proto.ParseXForwardedFor(request.Header)
  98. remoteAddr, err := gonet.ResolveTCPAddr("tcp", request.RemoteAddr)
  99. if err != nil {
  100. remoteAddr = &gonet.TCPAddr{}
  101. }
  102. if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() {
  103. remoteAddr = &net.TCPAddr{
  104. IP: forwardedAddrs[0].IP(),
  105. Port: int(0),
  106. }
  107. }
  108. currentSession := h.upsertSession(sessionId)
  109. scMaxEachPostBytes := int(h.ln.config.GetNormalizedScMaxEachPostBytes().To)
  110. if request.Method == "POST" {
  111. seq := ""
  112. if len(subpath) > 1 {
  113. seq = subpath[1]
  114. }
  115. if seq == "" {
  116. errors.LogInfo(context.Background(), "no seq on request:", request.URL.Path)
  117. writer.WriteHeader(http.StatusBadRequest)
  118. return
  119. }
  120. payload, err := io.ReadAll(request.Body)
  121. if len(payload) > scMaxEachPostBytes {
  122. errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request had size ", len(payload), ". Adjust scMaxEachPostBytes on the server to be at least as large as client.")
  123. writer.WriteHeader(http.StatusRequestEntityTooLarge)
  124. return
  125. }
  126. if err != nil {
  127. errors.LogInfoInner(context.Background(), err, "failed to upload")
  128. writer.WriteHeader(http.StatusInternalServerError)
  129. return
  130. }
  131. seqInt, err := strconv.ParseUint(seq, 10, 64)
  132. if err != nil {
  133. errors.LogInfoInner(context.Background(), err, "failed to upload")
  134. writer.WriteHeader(http.StatusInternalServerError)
  135. return
  136. }
  137. err = currentSession.uploadQueue.Push(Packet{
  138. Payload: payload,
  139. Seq: seqInt,
  140. })
  141. if err != nil {
  142. errors.LogInfoInner(context.Background(), err, "failed to upload")
  143. writer.WriteHeader(http.StatusInternalServerError)
  144. return
  145. }
  146. h.config.WriteResponseHeader(writer)
  147. writer.WriteHeader(http.StatusOK)
  148. } else if request.Method == "GET" {
  149. responseFlusher, ok := writer.(http.Flusher)
  150. if !ok {
  151. panic("expected http.ResponseWriter to be an http.Flusher")
  152. }
  153. // after GET is done, the connection is finished. disable automatic
  154. // session reaping, and handle it in defer
  155. currentSession.isFullyConnected.Close()
  156. defer h.sessions.Delete(sessionId)
  157. // magic header instructs nginx + apache to not buffer response body
  158. writer.Header().Set("X-Accel-Buffering", "no")
  159. // A web-compliant header telling all middleboxes to disable caching.
  160. // Should be able to prevent overloading the cache, or stop CDNs from
  161. // teeing the response stream into their cache, causing slowdowns.
  162. writer.Header().Set("Cache-Control", "no-store")
  163. if !h.config.NoSSEHeader {
  164. // magic header to make the HTTP middle box consider this as SSE to disable buffer
  165. writer.Header().Set("Content-Type", "text/event-stream")
  166. }
  167. h.config.WriteResponseHeader(writer)
  168. writer.WriteHeader(http.StatusOK)
  169. responseFlusher.Flush()
  170. downloadDone := done.New()
  171. conn := splitConn{
  172. writer: &httpResponseBodyWriter{
  173. responseWriter: writer,
  174. downloadDone: downloadDone,
  175. responseFlusher: responseFlusher,
  176. },
  177. reader: currentSession.uploadQueue,
  178. remoteAddr: remoteAddr,
  179. }
  180. h.ln.addConn(stat.Connection(&conn))
  181. // "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
  182. select {
  183. case <-request.Context().Done():
  184. case <-downloadDone.Wait():
  185. }
  186. conn.Close()
  187. } else {
  188. writer.WriteHeader(http.StatusMethodNotAllowed)
  189. }
  190. }
  191. type httpResponseBodyWriter struct {
  192. sync.Mutex
  193. responseWriter http.ResponseWriter
  194. responseFlusher http.Flusher
  195. downloadDone *done.Instance
  196. }
  197. func (c *httpResponseBodyWriter) Write(b []byte) (int, error) {
  198. c.Lock()
  199. defer c.Unlock()
  200. if c.downloadDone.Done() {
  201. return 0, io.ErrClosedPipe
  202. }
  203. n, err := c.responseWriter.Write(b)
  204. if err == nil {
  205. c.responseFlusher.Flush()
  206. }
  207. return n, err
  208. }
  209. func (c *httpResponseBodyWriter) Close() error {
  210. c.Lock()
  211. defer c.Unlock()
  212. c.downloadDone.Close()
  213. return nil
  214. }
  215. type Listener struct {
  216. sync.Mutex
  217. server http.Server
  218. h3server *http3.Server
  219. listener net.Listener
  220. h3listener *quic.EarlyListener
  221. config *Config
  222. addConn internet.ConnHandler
  223. isH3 bool
  224. }
  225. func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
  226. l := &Listener{
  227. addConn: addConn,
  228. }
  229. shSettings := streamSettings.ProtocolSettings.(*Config)
  230. l.config = shSettings
  231. if l.config != nil {
  232. if streamSettings.SocketSettings == nil {
  233. streamSettings.SocketSettings = &internet.SocketConfig{}
  234. }
  235. }
  236. var listener net.Listener
  237. var err error
  238. var localAddr = gonet.TCPAddr{}
  239. handler := &requestHandler{
  240. config: shSettings,
  241. host: shSettings.Host,
  242. path: shSettings.GetNormalizedPath(),
  243. ln: l,
  244. sessionMu: &sync.Mutex{},
  245. sessions: sync.Map{},
  246. localAddr: localAddr,
  247. }
  248. tlsConfig := getTLSConfig(streamSettings)
  249. l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
  250. if port == net.Port(0) { // unix
  251. listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
  252. Name: address.Domain(),
  253. Net: "unix",
  254. }, streamSettings.SocketSettings)
  255. if err != nil {
  256. return nil, errors.New("failed to listen unix domain socket(for SH) on ", address).Base(err)
  257. }
  258. errors.LogInfo(ctx, "listening unix domain socket(for SH) on ", address)
  259. } else if l.isH3 { // quic
  260. Conn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
  261. IP: address.IP(),
  262. Port: int(port),
  263. }, streamSettings.SocketSettings)
  264. if err != nil {
  265. return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
  266. }
  267. h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
  268. if err != nil {
  269. return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err)
  270. }
  271. l.h3listener = h3listener
  272. errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port)
  273. l.h3server = &http3.Server{
  274. Handler: handler,
  275. }
  276. go func() {
  277. if err := l.h3server.ServeListener(l.h3listener); err != nil {
  278. errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp")
  279. }
  280. }()
  281. } else { // tcp
  282. localAddr = gonet.TCPAddr{
  283. IP: address.IP(),
  284. Port: int(port),
  285. }
  286. listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
  287. IP: address.IP(),
  288. Port: int(port),
  289. }, streamSettings.SocketSettings)
  290. if err != nil {
  291. return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err)
  292. }
  293. errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port)
  294. }
  295. // tcp/unix (h1/h2)
  296. if listener != nil {
  297. if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
  298. if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
  299. listener = tls.NewListener(listener, tlsConfig)
  300. }
  301. }
  302. // h2cHandler can handle both plaintext HTTP/1.1 and h2c
  303. h2cHandler := h2c.NewHandler(handler, &http2.Server{})
  304. l.listener = listener
  305. l.server = http.Server{
  306. Handler: h2cHandler,
  307. ReadHeaderTimeout: time.Second * 4,
  308. MaxHeaderBytes: 8192,
  309. }
  310. go func() {
  311. if err := l.server.Serve(l.listener); err != nil {
  312. errors.LogWarningInner(ctx, err, "failed to serve http for splithttp")
  313. }
  314. }()
  315. }
  316. return l, err
  317. }
  318. // Addr implements net.Listener.Addr().
  319. func (ln *Listener) Addr() net.Addr {
  320. return ln.listener.Addr()
  321. }
  322. // Close implements net.Listener.Close().
  323. func (ln *Listener) Close() error {
  324. if ln.h3server != nil {
  325. if err := ln.h3server.Close(); err != nil {
  326. return err
  327. }
  328. } else if ln.listener != nil {
  329. return ln.listener.Close()
  330. }
  331. return errors.New("listener does not have an HTTP/3 server or a net.listener")
  332. }
  333. func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config {
  334. config := v2tls.ConfigFromStreamSettings(streamSettings)
  335. if config == nil {
  336. return &tls.Config{}
  337. }
  338. return config.GetTLSConfig()
  339. }
  340. func init() {
  341. common.Must(internet.RegisterTransportListener(protocolName, ListenSH))
  342. }