http2_fallback_transport.go 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package httpclient
  2. import (
  3. "context"
  4. stdTLS "crypto/tls"
  5. "errors"
  6. "net"
  7. "net/http"
  8. "sync"
  9. "github.com/sagernet/sing-box/common/tls"
  10. "github.com/sagernet/sing-box/option"
  11. E "github.com/sagernet/sing/common/exceptions"
  12. M "github.com/sagernet/sing/common/metadata"
  13. N "github.com/sagernet/sing/common/network"
  14. "golang.org/x/net/http2"
  15. )
  16. var errHTTP2Fallback = E.New("fallback to HTTP/1.1")
  17. type http2FallbackTransport struct {
  18. h2Transport *http2.Transport
  19. h1Transport *http1Transport
  20. fallbackAccess sync.RWMutex
  21. fallbackAuthority map[string]struct{}
  22. }
  23. func newHTTP2FallbackTransport(rawDialer N.Dialer, baseTLSConfig tls.Config, options option.HTTP2Options) (*http2FallbackTransport, error) {
  24. h1 := newHTTP1Transport(rawDialer, baseTLSConfig)
  25. h2Transport, err := ConfigureHTTP2Transport(options)
  26. if err != nil {
  27. return nil, err
  28. }
  29. h2Transport.DialTLSContext = func(ctx context.Context, network, addr string, _ *stdTLS.Config) (net.Conn, error) {
  30. return dialTLS(ctx, rawDialer, baseTLSConfig, M.ParseSocksaddr(addr), []string{http2.NextProtoTLS, "http/1.1"}, http2.NextProtoTLS)
  31. }
  32. return &http2FallbackTransport{
  33. h2Transport: h2Transport,
  34. h1Transport: h1,
  35. fallbackAuthority: make(map[string]struct{}),
  36. }, nil
  37. }
  38. func (t *http2FallbackTransport) isH2Fallback(authority string) bool {
  39. if authority == "" {
  40. return false
  41. }
  42. t.fallbackAccess.RLock()
  43. _, found := t.fallbackAuthority[authority]
  44. t.fallbackAccess.RUnlock()
  45. return found
  46. }
  47. func (t *http2FallbackTransport) markH2Fallback(authority string) {
  48. if authority == "" {
  49. return
  50. }
  51. t.fallbackAccess.Lock()
  52. t.fallbackAuthority[authority] = struct{}{}
  53. t.fallbackAccess.Unlock()
  54. }
  55. func (t *http2FallbackTransport) RoundTrip(request *http.Request) (*http.Response, error) {
  56. return t.roundTrip(request, true)
  57. }
  58. func (t *http2FallbackTransport) roundTrip(request *http.Request, allowHTTP1Fallback bool) (*http.Response, error) {
  59. if request.URL.Scheme != "https" || requestRequiresHTTP1(request) {
  60. return t.h1Transport.RoundTrip(request)
  61. }
  62. authority := requestAuthority(request)
  63. if t.isH2Fallback(authority) {
  64. if !allowHTTP1Fallback {
  65. return nil, errHTTP2Fallback
  66. }
  67. return t.h1Transport.RoundTrip(request)
  68. }
  69. response, err := t.h2Transport.RoundTrip(request)
  70. if err == nil {
  71. return response, nil
  72. }
  73. if !errors.Is(err, errHTTP2Fallback) || !allowHTTP1Fallback {
  74. return nil, err
  75. }
  76. t.markH2Fallback(authority)
  77. return t.h1Transport.RoundTrip(cloneRequestForRetry(request))
  78. }
  79. func (t *http2FallbackTransport) CloseIdleConnections() {
  80. t.h1Transport.CloseIdleConnections()
  81. t.h2Transport.CloseIdleConnections()
  82. }
  83. func (t *http2FallbackTransport) Close() error {
  84. t.CloseIdleConnections()
  85. return nil
  86. }