hub.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. package splithttp
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "io"
  6. gonet "net"
  7. "net/http"
  8. "strconv"
  9. "strings"
  10. "sync"
  11. "time"
  12. "github.com/quic-go/quic-go"
  13. "github.com/quic-go/quic-go/http3"
  14. goreality "github.com/xtls/reality"
  15. "github.com/xtls/xray-core/common"
  16. "github.com/xtls/xray-core/common/errors"
  17. "github.com/xtls/xray-core/common/net"
  18. http_proto "github.com/xtls/xray-core/common/protocol/http"
  19. "github.com/xtls/xray-core/common/signal/done"
  20. "github.com/xtls/xray-core/transport/internet"
  21. "github.com/xtls/xray-core/transport/internet/reality"
  22. "github.com/xtls/xray-core/transport/internet/stat"
  23. v2tls "github.com/xtls/xray-core/transport/internet/tls"
  24. "golang.org/x/net/http2"
  25. "golang.org/x/net/http2/h2c"
  26. )
  27. type requestHandler struct {
  28. config *Config
  29. host string
  30. path string
  31. ln *Listener
  32. sessionMu *sync.Mutex
  33. sessions sync.Map
  34. localAddr gonet.TCPAddr
  35. }
  36. type httpSession struct {
  37. uploadQueue *uploadQueue
  38. // for as long as the GET request is not opened by the client, this will be
  39. // open ("undone"), and the session may be expired within a certain TTL.
  40. // after the client connects, this becomes "done" and the session lives as
  41. // long as the GET request.
  42. isFullyConnected *done.Instance
  43. }
  44. func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessionId string) {
  45. shouldReap := done.New()
  46. go func() {
  47. time.Sleep(30 * time.Second)
  48. shouldReap.Close()
  49. }()
  50. select {
  51. case <-isFullyConnected.Wait():
  52. return
  53. case <-shouldReap.Wait():
  54. h.sessions.Delete(sessionId)
  55. }
  56. }
  57. func (h *requestHandler) upsertSession(sessionId string) *httpSession {
  58. // fast path
  59. currentSessionAny, ok := h.sessions.Load(sessionId)
  60. if ok {
  61. return currentSessionAny.(*httpSession)
  62. }
  63. // slow path
  64. h.sessionMu.Lock()
  65. defer h.sessionMu.Unlock()
  66. currentSessionAny, ok = h.sessions.Load(sessionId)
  67. if ok {
  68. return currentSessionAny.(*httpSession)
  69. }
  70. s := &httpSession{
  71. uploadQueue: NewUploadQueue(int(h.ln.config.GetNormalizedScMaxConcurrentPosts().To)),
  72. isFullyConnected: done.New(),
  73. }
  74. h.sessions.Store(sessionId, s)
  75. go h.maybeReapSession(s.isFullyConnected, sessionId)
  76. return s
  77. }
  78. func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
  79. if len(h.host) > 0 && !internet.IsValidHTTPHost(request.Host, h.host) {
  80. errors.LogInfo(context.Background(), "failed to validate host, request:", request.Host, ", config:", h.host)
  81. writer.WriteHeader(http.StatusNotFound)
  82. return
  83. }
  84. if !strings.HasPrefix(request.URL.Path, h.path) {
  85. errors.LogInfo(context.Background(), "failed to validate path, request:", request.URL.Path, ", config:", h.path)
  86. writer.WriteHeader(http.StatusNotFound)
  87. return
  88. }
  89. h.config.WriteResponseHeader(writer)
  90. sessionId := ""
  91. subpath := strings.Split(request.URL.Path[len(h.path):], "/")
  92. if len(subpath) > 0 {
  93. sessionId = subpath[0]
  94. }
  95. if sessionId == "" {
  96. errors.LogInfo(context.Background(), "no sessionid on request:", request.URL.Path)
  97. writer.WriteHeader(http.StatusBadRequest)
  98. return
  99. }
  100. forwardedAddrs := http_proto.ParseXForwardedFor(request.Header)
  101. remoteAddr, err := gonet.ResolveTCPAddr("tcp", request.RemoteAddr)
  102. if err != nil {
  103. remoteAddr = &gonet.TCPAddr{}
  104. }
  105. if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() {
  106. remoteAddr = &net.TCPAddr{
  107. IP: forwardedAddrs[0].IP(),
  108. Port: int(0),
  109. }
  110. }
  111. currentSession := h.upsertSession(sessionId)
  112. scMaxEachPostBytes := int(h.ln.config.GetNormalizedScMaxEachPostBytes().To)
  113. if request.Method == "POST" {
  114. seq := ""
  115. if len(subpath) > 1 {
  116. seq = subpath[1]
  117. }
  118. if seq == "" {
  119. if h.config.Mode == "packet-up" {
  120. errors.LogInfo(context.Background(), "stream-up mode is not allowed")
  121. writer.WriteHeader(http.StatusBadRequest)
  122. return
  123. }
  124. err = currentSession.uploadQueue.Push(Packet{
  125. Reader: request.Body,
  126. })
  127. if err != nil {
  128. errors.LogInfoInner(context.Background(), err, "failed to upload (PushReader)")
  129. writer.WriteHeader(http.StatusConflict)
  130. } else {
  131. writer.WriteHeader(http.StatusOK)
  132. <-request.Context().Done()
  133. }
  134. return
  135. }
  136. if h.config.Mode == "stream-up" {
  137. errors.LogInfo(context.Background(), "packet-up mode is not allowed")
  138. writer.WriteHeader(http.StatusBadRequest)
  139. return
  140. }
  141. payload, err := io.ReadAll(request.Body)
  142. if len(payload) > scMaxEachPostBytes {
  143. errors.LogInfo(context.Background(), "Too large upload. scMaxEachPostBytes is set to ", scMaxEachPostBytes, "but request had size ", len(payload), ". Adjust scMaxEachPostBytes on the server to be at least as large as client.")
  144. writer.WriteHeader(http.StatusRequestEntityTooLarge)
  145. return
  146. }
  147. if err != nil {
  148. errors.LogInfoInner(context.Background(), err, "failed to upload (ReadAll)")
  149. writer.WriteHeader(http.StatusInternalServerError)
  150. return
  151. }
  152. seqInt, err := strconv.ParseUint(seq, 10, 64)
  153. if err != nil {
  154. errors.LogInfoInner(context.Background(), err, "failed to upload (ParseUint)")
  155. writer.WriteHeader(http.StatusInternalServerError)
  156. return
  157. }
  158. err = currentSession.uploadQueue.Push(Packet{
  159. Payload: payload,
  160. Seq: seqInt,
  161. })
  162. if err != nil {
  163. errors.LogInfoInner(context.Background(), err, "failed to upload (PushPayload)")
  164. writer.WriteHeader(http.StatusInternalServerError)
  165. return
  166. }
  167. writer.WriteHeader(http.StatusOK)
  168. } else if request.Method == "GET" {
  169. responseFlusher, ok := writer.(http.Flusher)
  170. if !ok {
  171. panic("expected http.ResponseWriter to be an http.Flusher")
  172. }
  173. // after GET is done, the connection is finished. disable automatic
  174. // session reaping, and handle it in defer
  175. currentSession.isFullyConnected.Close()
  176. defer h.sessions.Delete(sessionId)
  177. // magic header instructs nginx + apache to not buffer response body
  178. writer.Header().Set("X-Accel-Buffering", "no")
  179. // A web-compliant header telling all middleboxes to disable caching.
  180. // Should be able to prevent overloading the cache, or stop CDNs from
  181. // teeing the response stream into their cache, causing slowdowns.
  182. writer.Header().Set("Cache-Control", "no-store")
  183. if !h.config.NoSSEHeader {
  184. // magic header to make the HTTP middle box consider this as SSE to disable buffer
  185. writer.Header().Set("Content-Type", "text/event-stream")
  186. }
  187. writer.WriteHeader(http.StatusOK)
  188. responseFlusher.Flush()
  189. downloadDone := done.New()
  190. conn := splitConn{
  191. writer: &httpResponseBodyWriter{
  192. responseWriter: writer,
  193. downloadDone: downloadDone,
  194. responseFlusher: responseFlusher,
  195. },
  196. reader: currentSession.uploadQueue,
  197. remoteAddr: remoteAddr,
  198. }
  199. h.ln.addConn(stat.Connection(&conn))
  200. // "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
  201. select {
  202. case <-request.Context().Done():
  203. case <-downloadDone.Wait():
  204. }
  205. conn.Close()
  206. } else {
  207. errors.LogInfo(context.Background(), "unsupported method: ", request.Method)
  208. writer.WriteHeader(http.StatusMethodNotAllowed)
  209. }
  210. }
  211. type httpResponseBodyWriter struct {
  212. sync.Mutex
  213. responseWriter http.ResponseWriter
  214. responseFlusher http.Flusher
  215. downloadDone *done.Instance
  216. }
  217. func (c *httpResponseBodyWriter) Write(b []byte) (int, error) {
  218. c.Lock()
  219. defer c.Unlock()
  220. if c.downloadDone.Done() {
  221. return 0, io.ErrClosedPipe
  222. }
  223. n, err := c.responseWriter.Write(b)
  224. if err == nil {
  225. c.responseFlusher.Flush()
  226. }
  227. return n, err
  228. }
  229. func (c *httpResponseBodyWriter) Close() error {
  230. c.Lock()
  231. defer c.Unlock()
  232. c.downloadDone.Close()
  233. return nil
  234. }
  235. type Listener struct {
  236. sync.Mutex
  237. server http.Server
  238. h3server *http3.Server
  239. listener net.Listener
  240. h3listener *quic.EarlyListener
  241. config *Config
  242. addConn internet.ConnHandler
  243. isH3 bool
  244. }
  245. func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
  246. l := &Listener{
  247. addConn: addConn,
  248. }
  249. shSettings := streamSettings.ProtocolSettings.(*Config)
  250. l.config = shSettings
  251. if l.config != nil {
  252. if streamSettings.SocketSettings == nil {
  253. streamSettings.SocketSettings = &internet.SocketConfig{}
  254. }
  255. }
  256. var listener net.Listener
  257. var err error
  258. var localAddr = gonet.TCPAddr{}
  259. handler := &requestHandler{
  260. config: shSettings,
  261. host: shSettings.Host,
  262. path: shSettings.GetNormalizedPath(),
  263. ln: l,
  264. sessionMu: &sync.Mutex{},
  265. sessions: sync.Map{},
  266. localAddr: localAddr,
  267. }
  268. tlsConfig := getTLSConfig(streamSettings)
  269. l.isH3 = len(tlsConfig.NextProtos) == 1 && tlsConfig.NextProtos[0] == "h3"
  270. if port == net.Port(0) { // unix
  271. listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
  272. Name: address.Domain(),
  273. Net: "unix",
  274. }, streamSettings.SocketSettings)
  275. if err != nil {
  276. return nil, errors.New("failed to listen unix domain socket(for SH) on ", address).Base(err)
  277. }
  278. errors.LogInfo(ctx, "listening unix domain socket(for SH) on ", address)
  279. } else if l.isH3 { // quic
  280. Conn, err := internet.ListenSystemPacket(context.Background(), &net.UDPAddr{
  281. IP: address.IP(),
  282. Port: int(port),
  283. }, streamSettings.SocketSettings)
  284. if err != nil {
  285. return nil, errors.New("failed to listen UDP(for SH3) on ", address, ":", port).Base(err)
  286. }
  287. h3listener, err := quic.ListenEarly(Conn, tlsConfig, nil)
  288. if err != nil {
  289. return nil, errors.New("failed to listen QUIC(for SH3) on ", address, ":", port).Base(err)
  290. }
  291. l.h3listener = h3listener
  292. errors.LogInfo(ctx, "listening QUIC(for SH3) on ", address, ":", port)
  293. l.h3server = &http3.Server{
  294. Handler: handler,
  295. }
  296. go func() {
  297. if err := l.h3server.ServeListener(l.h3listener); err != nil {
  298. errors.LogWarningInner(ctx, err, "failed to serve http3 for splithttp")
  299. }
  300. }()
  301. } else { // tcp
  302. localAddr = gonet.TCPAddr{
  303. IP: address.IP(),
  304. Port: int(port),
  305. }
  306. listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
  307. IP: address.IP(),
  308. Port: int(port),
  309. }, streamSettings.SocketSettings)
  310. if err != nil {
  311. return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err)
  312. }
  313. errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port)
  314. }
  315. // tcp/unix (h1/h2)
  316. if listener != nil {
  317. if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
  318. if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
  319. listener = tls.NewListener(listener, tlsConfig)
  320. }
  321. }
  322. if config := reality.ConfigFromStreamSettings(streamSettings); config != nil {
  323. listener = goreality.NewListener(listener, config.GetREALITYConfig())
  324. }
  325. // h2cHandler can handle both plaintext HTTP/1.1 and h2c
  326. h2cHandler := h2c.NewHandler(handler, &http2.Server{})
  327. l.listener = listener
  328. l.server = http.Server{
  329. Handler: h2cHandler,
  330. ReadHeaderTimeout: time.Second * 4,
  331. MaxHeaderBytes: 8192,
  332. }
  333. go func() {
  334. if err := l.server.Serve(l.listener); err != nil {
  335. errors.LogWarningInner(ctx, err, "failed to serve http for splithttp")
  336. }
  337. }()
  338. }
  339. return l, err
  340. }
  341. // Addr implements net.Listener.Addr().
  342. func (ln *Listener) Addr() net.Addr {
  343. if ln.h3listener != nil {
  344. return ln.h3listener.Addr()
  345. }
  346. if ln.listener != nil {
  347. return ln.listener.Addr()
  348. }
  349. return nil
  350. }
  351. // Close implements net.Listener.Close().
  352. func (ln *Listener) Close() error {
  353. if ln.h3server != nil {
  354. if err := ln.h3server.Close(); err != nil {
  355. return err
  356. }
  357. } else if ln.listener != nil {
  358. return ln.listener.Close()
  359. }
  360. return errors.New("listener does not have an HTTP/3 server or a net.listener")
  361. }
  362. func getTLSConfig(streamSettings *internet.MemoryStreamConfig) *tls.Config {
  363. config := v2tls.ConfigFromStreamSettings(streamSettings)
  364. if config == nil {
  365. return &tls.Config{}
  366. }
  367. return config.GetTLSConfig()
  368. }
  369. func init() {
  370. common.Must(internet.RegisterTransportListener(protocolName, ListenSH))
  371. }