http.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. package networkquality
  2. import (
  3. "context"
  4. "fmt"
  5. "net"
  6. "net/http"
  7. "strings"
  8. C "github.com/sagernet/sing-box/constant"
  9. sBufio "github.com/sagernet/sing/common/bufio"
  10. E "github.com/sagernet/sing/common/exceptions"
  11. M "github.com/sagernet/sing/common/metadata"
  12. N "github.com/sagernet/sing/common/network"
  13. )
  14. func FormatBitrate(bps int64) string {
  15. switch {
  16. case bps >= 1_000_000_000:
  17. return fmt.Sprintf("%.1f Gbps", float64(bps)/1_000_000_000)
  18. case bps >= 1_000_000:
  19. return fmt.Sprintf("%.1f Mbps", float64(bps)/1_000_000)
  20. case bps >= 1_000:
  21. return fmt.Sprintf("%.1f Kbps", float64(bps)/1_000)
  22. default:
  23. return fmt.Sprintf("%d bps", bps)
  24. }
  25. }
  26. func NewHTTPClient(dialer N.Dialer) *http.Client {
  27. transport := &http.Transport{
  28. ForceAttemptHTTP2: true,
  29. TLSHandshakeTimeout: C.TCPTimeout,
  30. }
  31. if dialer != nil {
  32. transport.DialContext = func(ctx context.Context, network string, addr string) (net.Conn, error) {
  33. return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
  34. }
  35. }
  36. return &http.Client{Transport: transport}
  37. }
  38. func baseTransportFromClient(client *http.Client) (*http.Transport, error) {
  39. if client == nil {
  40. return nil, E.New("http client is nil")
  41. }
  42. if client.Transport == nil {
  43. return http.DefaultTransport.(*http.Transport).Clone(), nil
  44. }
  45. transport, ok := client.Transport.(*http.Transport)
  46. if !ok {
  47. return nil, E.New("http client transport must be *http.Transport")
  48. }
  49. return transport.Clone(), nil
  50. }
  51. func newMeasurementClient(
  52. baseClient *http.Client,
  53. connectEndpoint string,
  54. singleConnection bool,
  55. disableKeepAlives bool,
  56. readCounters []N.CountFunc,
  57. writeCounters []N.CountFunc,
  58. ) (*http.Client, error) {
  59. transport, err := baseTransportFromClient(baseClient)
  60. if err != nil {
  61. return nil, err
  62. }
  63. transport.DisableCompression = true
  64. transport.DisableKeepAlives = disableKeepAlives
  65. if singleConnection {
  66. transport.MaxConnsPerHost = 1
  67. transport.MaxIdleConnsPerHost = 1
  68. transport.MaxIdleConns = 1
  69. }
  70. baseDialContext := transport.DialContext
  71. if baseDialContext == nil {
  72. dialer := &net.Dialer{}
  73. baseDialContext = dialer.DialContext
  74. }
  75. transport.DialContext = func(ctx context.Context, network string, addr string) (net.Conn, error) {
  76. dialAddr := addr
  77. if connectEndpoint != "" {
  78. dialAddr = rewriteDialAddress(addr, connectEndpoint)
  79. }
  80. conn, dialErr := baseDialContext(ctx, network, dialAddr)
  81. if dialErr != nil {
  82. return nil, dialErr
  83. }
  84. if len(readCounters) > 0 || len(writeCounters) > 0 {
  85. return sBufio.NewCounterConn(conn, readCounters, writeCounters), nil
  86. }
  87. return conn, nil
  88. }
  89. return &http.Client{
  90. Transport: transport,
  91. CheckRedirect: baseClient.CheckRedirect,
  92. Jar: baseClient.Jar,
  93. Timeout: baseClient.Timeout,
  94. }, nil
  95. }
  96. type MeasurementClientFactory func(
  97. connectEndpoint string,
  98. singleConnection bool,
  99. disableKeepAlives bool,
  100. readCounters []N.CountFunc,
  101. writeCounters []N.CountFunc,
  102. ) (*http.Client, error)
  103. func defaultMeasurementClientFactory(baseClient *http.Client) MeasurementClientFactory {
  104. return func(connectEndpoint string, singleConnection, disableKeepAlives bool, readCounters, writeCounters []N.CountFunc) (*http.Client, error) {
  105. return newMeasurementClient(baseClient, connectEndpoint, singleConnection, disableKeepAlives, readCounters, writeCounters)
  106. }
  107. }
  108. func NewOptionalHTTP3Factory(dialer N.Dialer, useHTTP3 bool) (MeasurementClientFactory, error) {
  109. if !useHTTP3 {
  110. return nil, nil
  111. }
  112. return NewHTTP3MeasurementClientFactory(dialer)
  113. }
  114. func rewriteDialAddress(addr string, connectEndpoint string) string {
  115. connectEndpoint = strings.TrimSpace(connectEndpoint)
  116. host, port, err := net.SplitHostPort(addr)
  117. if err != nil {
  118. return addr
  119. }
  120. endpointHost, endpointPort, err := net.SplitHostPort(connectEndpoint)
  121. if err == nil {
  122. host = endpointHost
  123. if endpointPort != "" {
  124. port = endpointPort
  125. }
  126. } else if connectEndpoint != "" {
  127. host = connectEndpoint
  128. }
  129. return net.JoinHostPort(host, port)
  130. }