hub.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. package splithttp
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "io"
  6. gonet "net"
  7. "net/http"
  8. "strconv"
  9. "strings"
  10. "sync"
  11. "time"
  12. "github.com/quic-go/quic-go"
  13. "github.com/quic-go/quic-go/http3"
  14. goreality "github.com/xtls/reality"
  15. "github.com/xtls/xray-core/common"
  16. "github.com/xtls/xray-core/common/errors"
  17. "github.com/xtls/xray-core/common/net"
  18. http_proto "github.com/xtls/xray-core/common/protocol/http"
  19. "github.com/xtls/xray-core/common/signal/done"
  20. "github.com/xtls/xray-core/transport/internet"
  21. "github.com/xtls/xray-core/transport/internet/reality"
  22. "github.com/xtls/xray-core/transport/internet/stat"
  23. v2tls "github.com/xtls/xray-core/transport/internet/tls"
  24. "golang.org/x/net/http2"
  25. "golang.org/x/net/http2/h2c"
  26. )
// requestHandler is the http.Handler installed on a Listener's HTTP server
// (the h2c-wrapped net/http server for TCP/unix, or the HTTP/3 server for
// QUIC). It validates incoming SplitHTTP requests and tracks sessions by ID.
type requestHandler struct {
	config    *Config
	host      string      // if non-empty, the Host value requests must match
	path      string      // normalized URL path prefix requests must start with
	ln        *Listener   // owning listener; new connections are passed to ln.addConn
	sessionMu *sync.Mutex // serializes session creation in upsertSession
	sessions  sync.Map    // session ID (string) -> *httpSession
	// NOTE(review): localAddr is not read by any code in this file — confirm
	// whether other files of the package rely on it.
	localAddr gonet.TCPAddr
}
// httpSession is the server-side state of one SplitHTTP session. It is shared
// between the session's upload (POST) requests and its download (GET) request.
type httpSession struct {
	// uploadQueue receives uploaded data (streamed request bodies or
	// sequenced packets) and is used as the reader of the session's
	// connection once the GET request arrives.
	uploadQueue *uploadQueue
	// for as long as the GET request is not opened by the client, this will be
	// open ("undone"), and the session may be expired within a certain TTL.
	// after the client connects, this becomes "done" and the session lives as
	// long as the GET request.
	isFullyConnected *done.Instance
}
  44. func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessionId string) {
  45. shouldReap := done.New()
  46. go func() {
  47. time.Sleep(30 * time.Second)
  48. shouldReap.Close()
  49. }()
  50. select {
  51. case <-isFullyConnected.Wait():
  52. return
  53. case <-shouldReap.Wait():
  54. h.sessions.Delete(sessionId)
  55. }
  56. }
  57. func (h *requestHandler) upsertSession(sessionId string) *httpSession {
  58. // fast path
  59. currentSessionAny, ok := h.sessions.Load(sessionId)
  60. if ok {
  61. return currentSessionAny.(*httpSession)
  62. }
  63. // slow path
  64. h.sessionMu.Lock()
  65. defer h.sessionMu.Unlock()
  66. currentSessionAny, ok = h.sessions.Load(sessionId)
  67. if ok {
  68. return currentSessionAny.(*httpSession)
  69. }
  70. s := &httpSession{
  71. uploadQueue: NewUploadQueue(int(h.ln.config.GetNormalizedScMaxConcurrentPosts().To)),
  72. isFullyConnected: done.New(),
  73. }
  74. h.sessions.Store(sessionId, s)
  75. go h.maybeReapSession(s.isFullyConnected, sessionId)
  76. return s
  77. }
  78. func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  79. if len(h.host) > 0 && !internet.IsValidHTTPHost(request.Host, h.host) {
  80. errors.LogInfo(context.Background(), "failed to validate host, request:", request.Host, ", config:", h.host)
  81. writer.WriteHeader(http.StatusNotFound)
  82. return
  83. }
  84. if !strings.HasPrefix(request.URL.Path, h.path) {
  85. errors.LogInfo(context.Background(), "failed to validate path, request:", request.URL.Path, ", config:", h.path)
  86. writer.WriteHeader(http.StatusNotFound)
  87. return
  88. }
  89. h.config.WriteResponseHeader(writer)
  90. validRange := h.config.GetNormalizedXPaddingBytes()
  91. x_padding := int32(len(request.URL.Query().Get("x_padding")))
  92. if validRange.To > 0 && (x_padding < validRange.From || x_padding > validRange.To) {
  93. errors.LogInfo(context.Background(), "invalid x_padding length:", x_padding)
  94. writer.WriteHeader(http.StatusBadRequest)
  95. return
  96. }
  97. sessionId := ""
  98. subpath := strings.Split(request.URL.Path[len(h.path):], "/")
  99. if len(subpath) > 0 {
  100. sessionId = subpath[0]
  101. }
  102. if sessionId == "" && h.config.Mode != "" && h.config.Mode != "auto" && h.config.Mode != "stream-one" && h.config.Mode != "stream-up" {
  103. errors.LogInfo(context.Background(), "stream-one mode is not allowed")
  104. writer.WriteHeader(http.StatusBadRequest)
  105. return
  106. }
  107. forwardedAddrs := http_proto.ParseXForwardedFor(request.Header)
  108. remoteAddr, err := gonet.ResolveTCPAddr("tcp", request.RemoteAddr)
  109. if err != nil {
  110. remoteAddr = &gonet.TCPAddr{}
  111. }
  112. if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() {
  113. remoteAddr = &net.TCPAddr{
  114. IP: forwardedAddrs[0].IP(),
  115. Port: int(0),
  116. }
  117. }
  118. var currentSession *httpSession
  119. if sessionId != "" {
  120. currentSession = h.upsertSession(sessionId)
  121. }
  122. scMaxEachPostBytes := int(h.ln.config.GetNormalizedScMaxEachPostBytes().To)
  123. if request.Method == "POST" && sessionId != "" {
  124. seq := ""
  125. if len(subpath) > 1 {
  126. seq = subpath[1]
  127. }
  128. if seq == "" {
  129. if h.config.Mode != "" && h.config.Mode != "auto" && h.config.Mode != "stream-up" {
  130. errors.LogInfo(context.Background(), "stream-up mode is not allowed")
  131. writer.WriteHeader(http.StatusBadRequest)
  132. return
  133. }
  134. err = currentSession.uploadQueue.Push(Packet{
  135. Reader: request.Body,
  136. })
  137. if err != nil {
  138. errors.LogInfoInner(context.Background(), err, "failed to upload (PushReader)")
  139. writer.WriteHeader(http.StatusConflict)
  140. } else {
  141. writer.WriteHeader(http.StatusOK)
  142. <-request.Context().Done()
  143. }
  144. return
  145. }
  146. if h.config.Mode != "" && h.config.Mode != "auto" && h.config.Mode != "packet-up" {
  147. errors.LogInfo(context.Background(), "packet-up mode is not allowed")
  148. writer.WriteHeader(http.StatusBadRequest)
  149. return
  150. }
  151. payload, err := io.ReadAll(request.Body)
  152. if len(payload) > scMaxEachPostBytes {
  153. errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request had size ", len(payload), ". Adjust scMaxEachPostBytes on the server to be at least as large as client.")
  154. writer.WriteHeader(http.StatusRequestEntityTooLarge)
  155. return
  156. }
  157. if err != nil {
  158. errors.LogInfoInner(context.Background(), err, "failed to upload (ReadAll)")
  159. writer.WriteHeader(http.StatusInternalServerError)
  160. return
  161. }
  162. seqInt, err := strconv.ParseUint(seq, 10, 64)
  163. if err != nil {
  164. errors.LogInfoInner(context.Background(), err, "failed to upload (ParseUint)")
  165. writer.WriteHeader(http.StatusInternalServerError)
  166. return
  167. }
  168. err = currentSession.uploadQueue.Push(Packet{
  169. Payload: payload,
  170. Seq: seqInt,
  171. })
  172. if err != nil {
  173. errors.LogInfoInner(context.Background(), err, "failed to upload (PushPayload)")
  174. writer.WriteHeader(http.StatusInternalServerError)
  175. return
  176. }
  177. writer.WriteHeader(http.StatusOK)
  178. } else if request.Method == "GET" || sessionId == "" {
  179. responseFlusher, ok := writer.(http.Flusher)
  180. if !ok {
  181. panic("expected http.ResponseWriter to be an http.Flusher")
  182. }
  183. if sessionId != "" {
  184. // after GET is done, the connection is finished. disable automatic
  185. // session reaping, and handle it in defer
  186. currentSession.isFullyConnected.Close()
  187. defer h.sessions.Delete(sessionId)
  188. }
  189. // magic header instructs nginx + apache to not buffer response body
  190. writer.Header().Set("X-Accel-Buffering", "no")
  191. // A web-compliant header telling all middleboxes to disable caching.
  192. // Should be able to prevent overloading the cache, or stop CDNs from
  193. // teeing the response stream into their cache, causing slowdowns.
  194. writer.Header().Set("Cache-Control", "no-store")
  195. if !h.config.NoSSEHeader {
  196. // magic header to make the HTTP middle box consider this as SSE to disable buffer
  197. writer.Header().Set("Content-Type", "text/event-stream")
  198. }
  199. writer.WriteHeader(http.StatusOK)
  200. responseFlusher.Flush()
  201. downloadDone := done.New()
  202. conn := splitConn{
  203. writer: &httpResponseBodyWriter{
  204. responseWriter: writer,
  205. downloadDone: downloadDone,
  206. responseFlusher: responseFlusher,
  207. },
  208. reader: request.Body,
  209. remoteAddr: remoteAddr,
  210. }
  211. if sessionId != "" {
  212. conn.reader = currentSession.uploadQueue
  213. }
  214. h.ln.addConn(stat.Connection(&conn))
  215. // "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
  216. select {
  217. case <-request.Context().Done():
  218. case <-downloadDone.Wait():
  219. }
  220. conn.Close()
  221. } else {
  222. errors.LogInfo(context.Background(), "unsupported method: ", request.Method)
  223. writer.WriteHeader(http.StatusMethodNotAllowed)
  224. }
  225. }
// httpResponseBodyWriter is the write side of a download connection. It writes
// into the HTTP response body and flushes after every successful write. The
// embedded Mutex serializes Write and Close.
type httpResponseBodyWriter struct {
	sync.Mutex
	responseWriter  http.ResponseWriter
	responseFlusher http.Flusher   // the same response writer, asserted to http.Flusher
	downloadDone    *done.Instance // closed by Close; Write fails once it is done
}
  232. func (c *httpResponseBodyWriter) Write(b []byte) (int, error) {
  233. c.Lock()
  234. defer c.Unlock()
  235. if c.downloadDone.Done() {
  236. return 0, io.ErrClosedPipe
  237. }
  238. n, err := c.responseWriter.Write(b)
  239. if err == nil {
  240. c.responseFlusher.Flush()
  241. }
  242. return n, err
  243. }
  244. func (c *httpResponseBodyWriter) Close() error {
  245. c.Lock()
  246. defer c.Unlock()
  247. c.downloadDone.Close()
  248. return nil
  249. }
// Listener is the SplitHTTP transport listener created by ListenSH. It serves
// requestHandler either through net/http (with h2c) on a TCP or unix listener,
// or through an HTTP/3 server on a QUIC listener.
type Listener struct {
	sync.Mutex
	server     http.Server          // net/http server, used for TCP/unix listeners
	h3server   *http3.Server        // set only when serving HTTP/3
	listener   net.Listener         // TCP/unix listener, possibly TLS- or REALITY-wrapped
	h3listener *quic.EarlyListener  // set only when serving HTTP/3
	config     *Config
	addConn    internet.ConnHandler // receives each established SplitHTTP connection
	isH3       bool                 // true when the TLS ALPN list is exactly ["h3"]
}
  260. func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
  261. l := &Listener{
  262. addConn: addConn,
  263. }
  264. shSettings := streamSettings.ProtocolSettings.(*Config)
  265. l.config = shSettings
  266. if l.config != nil {
  267. if streamSettings.SocketSettings == nil {
  268. streamSettings.SocketSettings = &internet.SocketConfig{}
  269. }
  270. }
  271. var listener net.Listener
  272. var err error
  273. var localAddr = gonet.TCPAddr{}
  274. handler := &requestHandler{
  275. config: shSettings,
  276. host: shSettings.Host,
  277. path: shSettings.GetNormalizedPath(),
  278. ln: l,
  279. sessionMu: &sync.Mutex{},
  280. sessions: sync.Map{},
  281. localAddr: localAddr,
  282. }
  283. tlsConfig := getTLSConfig(streamSettings)
  284. l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
  285. if port == net.Port(0) { // unix
  286. listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
  287. Name: address.Domain(),
  288. Net: "unix",
  289. }, streamSettings.SocketSettings)
  290. if err != nil {
  291. return nil, errors.New("failed to listen unix domain socket(for SH) on ", address).Base(err)
  292. }
  293. errors.LogInfo(ctx, "listening unix domain socket(for SH) on ", address)
  294. } else if l.isH3 { // quic
  295. Conn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
  296. IP: address.IP(),
  297. Port: int(port),
  298. }, streamSettings.SocketSettings)
  299. if err != nil {
  300. return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
  301. }
  302. h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
  303. if err != nil {
  304. return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err)
  305. }
  306. l.h3listener = h3listener
  307. errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port)
  308. l.h3server = &http3.Server{
  309. Handler: handler,
  310. }
  311. go func() {
  312. if err := l.h3server.ServeListener(l.h3listener); err != nil {
  313. errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp")
  314. }
  315. }()
  316. } else { // tcp
  317. localAddr = gonet.TCPAddr{
  318. IP: address.IP(),
  319. Port: int(port),
  320. }
  321. listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
  322. IP: address.IP(),
  323. Port: int(port),
  324. }, streamSettings.SocketSettings)
  325. if err != nil {
  326. return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err)
  327. }
  328. errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port)
  329. }
  330. // tcp/unix (h1/h2)
  331. if listener != nil {
  332. if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
  333. if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
  334. listener = tls.NewListener(listener, tlsConfig)
  335. }
  336. }
  337. if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
  338. listener = goreality.NewListener(listener, config.GetREALITYConfig())
  339. }
  340. // h2cHandler can handle both plaintext HTTP/1.1 and h2c
  341. h2cHandler := h2c.NewHandler(handler, &http2.Server{})
  342. l.listener = listener
  343. l.server = http.Server{
  344. Handler: h2cHandler,
  345. ReadHeaderTimeout: time.Second * 4,
  346. MaxHeaderBytes: 8192,
  347. }
  348. go func() {
  349. if err := l.server.Serve(l.listener); err != nil {
  350. errors.LogWarningInner(ctx, err, "failed to serve http for splithttp")
  351. }
  352. }()
  353. }
  354. return l, err
  355. }
  356. // Addr implements net.Listener.Addr().
  357. func (ln *Listener) Addr() net.Addr {
  358. if ln.h3listener != nil {
  359. return ln.h3listener.Addr()
  360. }
  361. if ln.listener != nil {
  362. return ln.listener.Addr()
  363. }
  364. return nil
  365. }
  366. // Close implements net.Listener.Close().
  367. func (ln *Listener) Close() error {
  368. if ln.h3server != nil {
  369. if err := ln.h3server.Close(); err != nil {
  370. return err
  371. }
  372. } else if ln.listener != nil {
  373. return ln.listener.Close()
  374. }
  375. return errors.New("listener does not have an HTTP/3 server or a net.listener")
  376. }
  377. func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config {
  378. config := v2tls.ConfigFromStreamSettings(streamSettings)
  379. if config == nil {
  380. return &tls.Config{}
  381. }
  382. return config.GetTLSConfig()
  383. }
  384. func init() {
  385. common.Must(internet.RegisterTransportListener(protocolName, ListenSH))
  386. }