http3_transport.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. //go:build with_quic
  2. package httpclient
  3. import (
  4. "context"
  5. stdTLS "crypto/tls"
  6. "errors"
  7. "net/http"
  8. "sync"
  9. "time"
  10. "github.com/sagernet/quic-go"
  11. "github.com/sagernet/quic-go/http3"
  12. "github.com/sagernet/sing-box/common/tls"
  13. "github.com/sagernet/sing-box/option"
  14. "github.com/sagernet/sing/common/bufio"
  15. E "github.com/sagernet/sing/common/exceptions"
  16. M "github.com/sagernet/sing/common/metadata"
  17. N "github.com/sagernet/sing/common/network"
  18. )
  19. type http3Transport struct {
  20. h3Transport *http3.Transport
  21. }
  22. type http3BrokenEntry struct {
  23. until time.Time
  24. backoff time.Duration
  25. }
  26. type http3FallbackTransport struct {
  27. h3Transport *http3.Transport
  28. h2Fallback innerTransport
  29. fallbackDelay time.Duration
  30. brokenAccess sync.Mutex
  31. broken map[string]http3BrokenEntry
  32. }
  33. func newHTTP3RoundTripper(
  34. rawDialer N.Dialer,
  35. baseTLSConfig tls.Config,
  36. options option.QUICOptions,
  37. ) *http3.Transport {
  38. var handshakeTimeout time.Duration
  39. if baseTLSConfig != nil {
  40. handshakeTimeout = baseTLSConfig.HandshakeTimeout()
  41. }
  42. quicConfig := &quic.Config{
  43. InitialStreamReceiveWindow: options.StreamReceiveWindow.Value(),
  44. MaxStreamReceiveWindow: options.StreamReceiveWindow.Value(),
  45. InitialConnectionReceiveWindow: options.ConnectionReceiveWindow.Value(),
  46. MaxConnectionReceiveWindow: options.ConnectionReceiveWindow.Value(),
  47. KeepAlivePeriod: time.Duration(options.KeepAlivePeriod),
  48. MaxIdleTimeout: time.Duration(options.IdleTimeout),
  49. DisablePathMTUDiscovery: options.DisablePathMTUDiscovery,
  50. }
  51. if options.InitialPacketSize > 0 {
  52. quicConfig.InitialPacketSize = uint16(options.InitialPacketSize)
  53. }
  54. if options.MaxConcurrentStreams > 0 {
  55. quicConfig.MaxIncomingStreams = int64(options.MaxConcurrentStreams)
  56. }
  57. if handshakeTimeout > 0 {
  58. quicConfig.HandshakeIdleTimeout = handshakeTimeout
  59. }
  60. h3Transport := &http3.Transport{
  61. TLSClientConfig: &stdTLS.Config{},
  62. QUICConfig: quicConfig,
  63. Dial: func(ctx context.Context, addr string, tlsConfig *stdTLS.Config, quicConfig *quic.Config) (*quic.Conn, error) {
  64. if handshakeTimeout > 0 && quicConfig.HandshakeIdleTimeout == 0 {
  65. quicConfig = quicConfig.Clone()
  66. quicConfig.HandshakeIdleTimeout = handshakeTimeout
  67. }
  68. if baseTLSConfig != nil {
  69. var err error
  70. tlsConfig, err = buildSTDTLSConfig(baseTLSConfig, M.ParseSocksaddr(addr), []string{http3.NextProtoH3})
  71. if err != nil {
  72. return nil, err
  73. }
  74. } else {
  75. tlsConfig = tlsConfig.Clone()
  76. tlsConfig.NextProtos = []string{http3.NextProtoH3}
  77. }
  78. conn, err := rawDialer.DialContext(ctx, N.NetworkUDP, M.ParseSocksaddr(addr))
  79. if err != nil {
  80. return nil, err
  81. }
  82. quicConn, err := quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsConfig, quicConfig)
  83. if err != nil {
  84. conn.Close()
  85. return nil, err
  86. }
  87. return quicConn, nil
  88. },
  89. }
  90. return h3Transport
  91. }
  92. func newHTTP3Transport(
  93. rawDialer N.Dialer,
  94. baseTLSConfig tls.Config,
  95. options option.QUICOptions,
  96. ) (innerTransport, error) {
  97. return &http3Transport{
  98. h3Transport: newHTTP3RoundTripper(rawDialer, baseTLSConfig, options),
  99. }, nil
  100. }
  101. func newHTTP3FallbackTransport(
  102. rawDialer N.Dialer,
  103. baseTLSConfig tls.Config,
  104. h2Fallback innerTransport,
  105. options option.QUICOptions,
  106. fallbackDelay time.Duration,
  107. ) (innerTransport, error) {
  108. return &http3FallbackTransport{
  109. h3Transport: newHTTP3RoundTripper(rawDialer, baseTLSConfig, options),
  110. h2Fallback: h2Fallback,
  111. fallbackDelay: fallbackDelay,
  112. broken: make(map[string]http3BrokenEntry),
  113. }, nil
  114. }
  115. func (t *http3Transport) RoundTrip(request *http.Request) (*http.Response, error) {
  116. return t.h3Transport.RoundTrip(request)
  117. }
  118. func (t *http3Transport) CloseIdleConnections() {
  119. t.h3Transport.CloseIdleConnections()
  120. }
  121. func (t *http3Transport) Close() error {
  122. t.CloseIdleConnections()
  123. return t.h3Transport.Close()
  124. }
  125. func (t *http3FallbackTransport) RoundTrip(request *http.Request) (*http.Response, error) {
  126. if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
  127. return t.h2Fallback.RoundTrip(request)
  128. }
  129. return t.roundTripHTTP3(request)
  130. }
  131. func (t *http3FallbackTransport) roundTripHTTP3(request *http.Request) (*http.Response, error) {
  132. authority := requestAuthority(request)
  133. if t.h3Broken(authority) {
  134. return t.h2FallbackRoundTrip(request)
  135. }
  136. response, err := t.h3Transport.RoundTripOpt(request, http3.RoundTripOpt{OnlyCachedConn: true})
  137. if err == nil {
  138. t.clearH3Broken(authority)
  139. return response, nil
  140. }
  141. if !errors.Is(err, http3.ErrNoCachedConn) {
  142. t.markH3Broken(authority)
  143. return t.h2FallbackRoundTrip(cloneRequestForRetry(request))
  144. }
  145. if !requestReplayable(request) {
  146. response, err = t.h3Transport.RoundTrip(request)
  147. if err == nil {
  148. t.clearH3Broken(authority)
  149. return response, nil
  150. }
  151. t.markH3Broken(authority)
  152. return nil, err
  153. }
  154. return t.roundTripHTTP3Race(request, authority)
  155. }
  156. func (t *http3FallbackTransport) roundTripHTTP3Race(request *http.Request, authority string) (*http.Response, error) {
  157. ctx, cancel := context.WithCancel(request.Context())
  158. defer cancel()
  159. type result struct {
  160. response *http.Response
  161. err error
  162. h3 bool
  163. }
  164. results := make(chan result, 2)
  165. startRoundTrip := func(request *http.Request, useH3 bool) {
  166. request = request.WithContext(ctx)
  167. var (
  168. response *http.Response
  169. err error
  170. )
  171. if useH3 {
  172. response, err = t.h3Transport.RoundTrip(request)
  173. } else {
  174. response, err = t.h2FallbackRoundTrip(request)
  175. }
  176. results <- result{response: response, err: err, h3: useH3}
  177. }
  178. goroutines := 1
  179. received := 0
  180. drainRemaining := func() {
  181. cancel()
  182. for range goroutines - received {
  183. go func() {
  184. loser := <-results
  185. if loser.response != nil && loser.response.Body != nil {
  186. loser.response.Body.Close()
  187. }
  188. }()
  189. }
  190. }
  191. go startRoundTrip(cloneRequestForRetry(request), true)
  192. timer := time.NewTimer(t.fallbackDelay)
  193. defer timer.Stop()
  194. var (
  195. h3Err error
  196. fallbackErr error
  197. )
  198. for {
  199. select {
  200. case <-timer.C:
  201. if goroutines == 1 {
  202. goroutines++
  203. go startRoundTrip(cloneRequestForRetry(request), false)
  204. }
  205. case raceResult := <-results:
  206. received++
  207. if raceResult.err == nil {
  208. if raceResult.h3 {
  209. t.clearH3Broken(authority)
  210. }
  211. drainRemaining()
  212. return raceResult.response, nil
  213. }
  214. if raceResult.h3 {
  215. t.markH3Broken(authority)
  216. h3Err = raceResult.err
  217. if goroutines == 1 {
  218. goroutines++
  219. if !timer.Stop() {
  220. select {
  221. case <-timer.C:
  222. default:
  223. }
  224. }
  225. go startRoundTrip(cloneRequestForRetry(request), false)
  226. }
  227. } else {
  228. fallbackErr = raceResult.err
  229. }
  230. if received < goroutines {
  231. continue
  232. }
  233. drainRemaining()
  234. switch {
  235. case h3Err != nil && fallbackErr != nil:
  236. return nil, E.Errors(h3Err, fallbackErr)
  237. case fallbackErr != nil:
  238. return nil, fallbackErr
  239. default:
  240. return nil, h3Err
  241. }
  242. }
  243. }
  244. }
  245. func (t *http3FallbackTransport) h2FallbackRoundTrip(request *http.Request) (*http.Response, error) {
  246. if fallback, isFallback := t.h2Fallback.(*http2FallbackTransport); isFallback {
  247. return fallback.roundTrip(request, true)
  248. }
  249. return t.h2Fallback.RoundTrip(request)
  250. }
  251. func (t *http3FallbackTransport) CloseIdleConnections() {
  252. t.h3Transport.CloseIdleConnections()
  253. t.h2Fallback.CloseIdleConnections()
  254. }
  255. func (t *http3FallbackTransport) Close() error {
  256. t.CloseIdleConnections()
  257. return t.h3Transport.Close()
  258. }
  259. func (t *http3FallbackTransport) h3Broken(authority string) bool {
  260. if authority == "" {
  261. return false
  262. }
  263. t.brokenAccess.Lock()
  264. defer t.brokenAccess.Unlock()
  265. entry, found := t.broken[authority]
  266. if !found {
  267. return false
  268. }
  269. if entry.until.IsZero() || !time.Now().Before(entry.until) {
  270. delete(t.broken, authority)
  271. return false
  272. }
  273. return true
  274. }
  275. func (t *http3FallbackTransport) clearH3Broken(authority string) {
  276. if authority == "" {
  277. return
  278. }
  279. t.brokenAccess.Lock()
  280. delete(t.broken, authority)
  281. t.brokenAccess.Unlock()
  282. }
  283. func (t *http3FallbackTransport) markH3Broken(authority string) {
  284. if authority == "" {
  285. return
  286. }
  287. t.brokenAccess.Lock()
  288. defer t.brokenAccess.Unlock()
  289. entry := t.broken[authority]
  290. if entry.backoff == 0 {
  291. entry.backoff = 5 * time.Minute
  292. } else {
  293. entry.backoff *= 2
  294. if entry.backoff > 48*time.Hour {
  295. entry.backoff = 48 * time.Hour
  296. }
  297. }
  298. entry.until = time.Now().Add(entry.backoff)
  299. t.broken[authority] = entry
  300. }