hub.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. package splithttp
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "io"
  6. gonet "net"
  7. "net/http"
  8. "strconv"
  9. "strings"
  10. "sync"
  11. "time"
  12. "github.com/quic-go/quic-go"
  13. "github.com/quic-go/quic-go/http3"
  14. "github.com/xtls/xray-core/common"
  15. "github.com/xtls/xray-core/common/errors"
  16. "github.com/xtls/xray-core/common/net"
  17. http_proto "github.com/xtls/xray-core/common/protocol/http"
  18. "github.com/xtls/xray-core/common/signal/done"
  19. "github.com/xtls/xray-core/transport/internet"
  20. "github.com/xtls/xray-core/transport/internet/stat"
  21. v2tls "github.com/xtls/xray-core/transport/internet/tls"
  22. "golang.org/x/net/http2"
  23. "golang.org/x/net/http2/h2c"
  24. )
  25. type requestHandler struct {
  26. host string
  27. path string
  28. ln *Listener
  29. sessionMu *sync.Mutex
  30. sessions sync.Map
  31. localAddr gonet.TCPAddr
  32. }
  33. type httpSession struct {
  34. uploadQueue *uploadQueue
  35. // for as long as the GET request is not opened by the client, this will be
  36. // open ("undone"), and the session may be expired within a certain TTL.
  37. // after the client connects, this becomes "done" and the session lives as
  38. // long as the GET request.
  39. isFullyConnected *done.Instance
  40. }
  41. func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessionId string) {
  42. shouldReap := done.New()
  43. go func() {
  44. time.Sleep(30 * time.Second)
  45. shouldReap.Close()
  46. }()
  47. select {
  48. case <-isFullyConnected.Wait():
  49. return
  50. case <-shouldReap.Wait():
  51. h.sessions.Delete(sessionId)
  52. }
  53. }
  54. func (h *requestHandler) upsertSession(sessionId string) *httpSession {
  55. // fast path
  56. currentSessionAny, ok := h.sessions.Load(sessionId)
  57. if ok {
  58. return currentSessionAny.(*httpSession)
  59. }
  60. // slow path
  61. h.sessionMu.Lock()
  62. defer h.sessionMu.Unlock()
  63. currentSessionAny, ok = h.sessions.Load(sessionId)
  64. if ok {
  65. return currentSessionAny.(*httpSession)
  66. }
  67. s := &httpSession{
  68. uploadQueue: NewUploadQueue(int(h.ln.config.GetNormalizedMaxConcurrentUploads(true).To)),
  69. isFullyConnected: done.New(),
  70. }
  71. h.sessions.Store(sessionId, s)
  72. go h.maybeReapSession(s.isFullyConnected, sessionId)
  73. return s
  74. }
  75. func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  76. if len(h.host) > 0 && !internet.IsValidHTTPHost(request.Host, h.host) {
  77. errors.LogInfo(context.Background(), "failed to validate host, request:", request.Host, ", config:", h.host)
  78. writer.WriteHeader(http.StatusNotFound)
  79. return
  80. }
  81. if !strings.HasPrefix(request.URL.Path, h.path) {
  82. errors.LogInfo(context.Background(), "failed to validate path, request:", request.URL.Path, ", config:", h.path)
  83. writer.WriteHeader(http.StatusNotFound)
  84. return
  85. }
  86. sessionId := ""
  87. subpath := strings.Split(request.URL.Path[len(h.path):], "/")
  88. if len(subpath) > 0 {
  89. sessionId = subpath[0]
  90. }
  91. if sessionId == "" {
  92. errors.LogInfo(context.Background(), "no sessionid on request:", request.URL.Path)
  93. writer.WriteHeader(http.StatusBadRequest)
  94. return
  95. }
  96. forwardedAddrs := http_proto.ParseXForwardedFor(request.Header)
  97. remoteAddr, err := gonet.ResolveTCPAddr("tcp", request.RemoteAddr)
  98. if err != nil {
  99. remoteAddr = &gonet.TCPAddr{}
  100. }
  101. if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() {
  102. remoteAddr = &net.TCPAddr{
  103. IP: forwardedAddrs[0].IP(),
  104. Port: int(0),
  105. }
  106. }
  107. currentSession := h.upsertSession(sessionId)
  108. maxUploadSize := int(h.ln.config.GetNormalizedMaxUploadSize(true).To)
  109. if request.Method == "POST" {
  110. seq := ""
  111. if len(subpath) > 1 {
  112. seq = subpath[1]
  113. }
  114. if seq == "" {
  115. errors.LogInfo(context.Background(), "no seq on request:", request.URL.Path)
  116. writer.WriteHeader(http.StatusBadRequest)
  117. return
  118. }
  119. payload, err := io.ReadAll(request.Body)
  120. if len(payload) > maxUploadSize {
  121. errors.LogInfo(context.Background(), "Too large upload. maxUploadSize is set to", maxUploadSize, "but request had size", len(payload), ". Adjust maxUploadSize on the server to be at least as large as client.")
  122. writer.WriteHeader(http.StatusRequestEntityTooLarge)
  123. return
  124. }
  125. if err != nil {
  126. errors.LogInfoInner(context.Background(), err, "failed to upload")
  127. writer.WriteHeader(http.StatusInternalServerError)
  128. return
  129. }
  130. seqInt, err := strconv.ParseUint(seq, 10, 64)
  131. if err != nil {
  132. errors.LogInfoInner(context.Background(), err, "failed to upload")
  133. writer.WriteHeader(http.StatusInternalServerError)
  134. return
  135. }
  136. err = currentSession.uploadQueue.Push(Packet{
  137. Payload: payload,
  138. Seq: seqInt,
  139. })
  140. if err != nil {
  141. errors.LogInfoInner(context.Background(), err, "failed to upload")
  142. writer.WriteHeader(http.StatusInternalServerError)
  143. return
  144. }
  145. writer.WriteHeader(http.StatusOK)
  146. } else if request.Method == "GET" {
  147. responseFlusher, ok := writer.(http.Flusher)
  148. if !ok {
  149. panic("expected http.ResponseWriter to be an http.Flusher")
  150. }
  151. // after GET is done, the connection is finished. disable automatic
  152. // session reaping, and handle it in defer
  153. currentSession.isFullyConnected.Close()
  154. defer h.sessions.Delete(sessionId)
  155. // magic header instructs nginx + apache to not buffer response body
  156. writer.Header().Set("X-Accel-Buffering", "no")
  157. // magic header to make the HTTP middle box consider this as SSE to disable buffer
  158. writer.Header().Set("Content-Type", "text/event-stream")
  159. writer.WriteHeader(http.StatusOK)
  160. // send a chunk immediately to enable CDN streaming.
  161. // many CDN buffer the response headers until the origin starts sending
  162. // the body, with no way to turn it off.
  163. writer.Write([]byte("ok"))
  164. responseFlusher.Flush()
  165. downloadDone := done.New()
  166. conn := splitConn{
  167. writer: &httpResponseBodyWriter{
  168. responseWriter: writer,
  169. downloadDone: downloadDone,
  170. responseFlusher: responseFlusher,
  171. },
  172. reader: currentSession.uploadQueue,
  173. remoteAddr: remoteAddr,
  174. }
  175. h.ln.addConn(stat.Connection(&conn))
  176. // "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
  177. <-downloadDone.Wait()
  178. } else {
  179. writer.WriteHeader(http.StatusMethodNotAllowed)
  180. }
  181. }
  182. type httpResponseBodyWriter struct {
  183. sync.Mutex
  184. responseWriter http.ResponseWriter
  185. responseFlusher http.Flusher
  186. downloadDone *done.Instance
  187. }
  188. func (c *httpResponseBodyWriter) Write(b []byte) (int, error) {
  189. c.Lock()
  190. defer c.Unlock()
  191. if c.downloadDone.Done() {
  192. return 0, io.ErrClosedPipe
  193. }
  194. n, err := c.responseWriter.Write(b)
  195. if err == nil {
  196. c.responseFlusher.Flush()
  197. }
  198. return n, err
  199. }
  200. func (c *httpResponseBodyWriter) Close() error {
  201. c.Lock()
  202. defer c.Unlock()
  203. c.downloadDone.Close()
  204. return nil
  205. }
  206. type Listener struct {
  207. sync.Mutex
  208. server http.Server
  209. h3server *http3.Server
  210. listener net.Listener
  211. h3listener *quic.EarlyListener
  212. config *Config
  213. addConn internet.ConnHandler
  214. isH3 bool
  215. }
  216. func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
  217. l := &Listener{
  218. addConn: addConn,
  219. }
  220. shSettings := streamSettings.ProtocolSettings.(*Config)
  221. l.config = shSettings
  222. if l.config != nil {
  223. if streamSettings.SocketSettings == nil {
  224. streamSettings.SocketSettings = &internet.SocketConfig{}
  225. }
  226. }
  227. var listener net.Listener
  228. var err error
  229. var localAddr = gonet.TCPAddr{}
  230. handler := &requestHandler{
  231. host: shSettings.Host,
  232. path: shSettings.GetNormalizedPath("", false),
  233. ln: l,
  234. sessionMu: &sync.Mutex{},
  235. sessions: sync.Map{},
  236. localAddr: localAddr,
  237. }
  238. tlsConfig := getTLSConfig(streamSettings)
  239. l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
  240. if port == net.Port(0) { // unix
  241. listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
  242. Name: address.Domain(),
  243. Net: "unix",
  244. }, streamSettings.SocketSettings)
  245. if err != nil {
  246. return nil, errors.New("failed to listen unix domain socket(for SH) on ", address).Base(err)
  247. }
  248. errors.LogInfo(ctx, "listening unix domain socket(for SH) on ", address)
  249. } else if l.isH3 { // quic
  250. Conn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
  251. IP: address.IP(),
  252. Port: int(port),
  253. }, streamSettings.SocketSettings)
  254. if err != nil {
  255. return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
  256. }
  257. h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
  258. if err != nil {
  259. return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err)
  260. }
  261. l.h3listener = h3listener
  262. errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port)
  263. l.h3server = &http3.Server{
  264. Handler: handler,
  265. }
  266. go func() {
  267. if err := l.h3server.ServeListener(l.h3listener); err != nil {
  268. errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp")
  269. }
  270. }()
  271. } else { // tcp
  272. localAddr = gonet.TCPAddr{
  273. IP: address.IP(),
  274. Port: int(port),
  275. }
  276. listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
  277. IP: address.IP(),
  278. Port: int(port),
  279. }, streamSettings.SocketSettings)
  280. if err != nil {
  281. return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err)
  282. }
  283. errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port)
  284. }
  285. // tcp/unix (h1/h2)
  286. if listener != nil {
  287. if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
  288. if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
  289. listener = tls.NewListener(listener, tlsConfig)
  290. }
  291. }
  292. // h2cHandler can handle both plaintext HTTP/1.1 and h2c
  293. h2cHandler := h2c.NewHandler(handler, &http2.Server{})
  294. l.listener = listener
  295. l.server = http.Server{
  296. Handler: h2cHandler,
  297. ReadHeaderTimeout: time.Second * 4,
  298. MaxHeaderBytes: 8192,
  299. }
  300. go func() {
  301. if err := l.server.Serve(l.listener); err != nil {
  302. errors.LogWarningInner(ctx, err, "failed to serve http for splithttp")
  303. }
  304. }()
  305. }
  306. return l, err
  307. }
  308. // Addr implements net.Listener.Addr().
  309. func (ln *Listener) Addr() net.Addr {
  310. return ln.listener.Addr()
  311. }
  312. // Close implements net.Listener.Close().
  313. func (ln *Listener) Close() error {
  314. if ln.h3server != nil {
  315. if err := ln.h3server.Close(); err != nil {
  316. return err
  317. }
  318. } else if ln.listener != nil {
  319. return ln.listener.Close()
  320. }
  321. return errors.New("listener does not have an HTTP/3 server or a net.listener")
  322. }
  323. func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config {
  324. config := v2tls.ConfigFromStreamSettings(streamSettings)
  325. if config == nil {
  326. return &tls.Config{}
  327. }
  328. return config.GetTLSConfig()
  329. }
  330. func init() {
  331. common.Must(internet.RegisterTransportListener(protocolName, ListenSH))
  332. }