managed_transport.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. package httpclient
  2. import (
  3. "io"
  4. "net/http"
  5. "sync"
  6. "sync/atomic"
  7. "github.com/sagernet/sing-box/adapter"
  8. E "github.com/sagernet/sing/common/exceptions"
  9. N "github.com/sagernet/sing/common/network"
  10. )
  11. type innerTransport interface {
  12. http.RoundTripper
  13. CloseIdleConnections()
  14. Close() error
  15. }
  16. var _ adapter.HTTPTransport = (*ManagedTransport)(nil)
  17. type ManagedTransport struct {
  18. epoch atomic.Pointer[transportEpoch]
  19. rebuildAccess sync.Mutex
  20. factory func() (innerTransport, error)
  21. cheapRebuild bool
  22. dialer N.Dialer
  23. headers http.Header
  24. host string
  25. tag string
  26. }
  27. type transportEpoch struct {
  28. transport innerTransport
  29. active atomic.Int64
  30. marked atomic.Bool
  31. closeOnce sync.Once
  32. }
  33. type managedResponseBody struct {
  34. body io.ReadCloser
  35. release func()
  36. once sync.Once
  37. }
  38. func (e *transportEpoch) tryClose() {
  39. e.closeOnce.Do(func() {
  40. e.transport.Close()
  41. })
  42. }
  43. func (b *managedResponseBody) Read(p []byte) (int, error) {
  44. return b.body.Read(p)
  45. }
  46. func (b *managedResponseBody) Close() error {
  47. err := b.body.Close()
  48. b.once.Do(b.release)
  49. return err
  50. }
  51. func (t *ManagedTransport) getEpoch() (*transportEpoch, error) {
  52. epoch := t.epoch.Load()
  53. if epoch != nil {
  54. return epoch, nil
  55. }
  56. t.rebuildAccess.Lock()
  57. defer t.rebuildAccess.Unlock()
  58. epoch = t.epoch.Load()
  59. if epoch != nil {
  60. return epoch, nil
  61. }
  62. inner, err := t.factory()
  63. if err != nil {
  64. return nil, err
  65. }
  66. epoch = &transportEpoch{transport: inner}
  67. t.epoch.Store(epoch)
  68. return epoch, nil
  69. }
  70. func (t *ManagedTransport) acquireEpoch() (*transportEpoch, error) {
  71. for {
  72. epoch, err := t.getEpoch()
  73. if err != nil {
  74. return nil, err
  75. }
  76. epoch.active.Add(1)
  77. if epoch == t.epoch.Load() {
  78. return epoch, nil
  79. }
  80. t.releaseEpoch(epoch)
  81. }
  82. }
  83. func (t *ManagedTransport) releaseEpoch(epoch *transportEpoch) {
  84. if epoch.active.Add(-1) == 0 && epoch.marked.Load() {
  85. epoch.tryClose()
  86. }
  87. }
  88. func (t *ManagedTransport) retireEpoch(epoch *transportEpoch) {
  89. if epoch == nil {
  90. return
  91. }
  92. epoch.marked.Store(true)
  93. if epoch.active.Load() == 0 {
  94. epoch.tryClose()
  95. }
  96. }
  97. func (t *ManagedTransport) RoundTrip(request *http.Request) (*http.Response, error) {
  98. epoch, err := t.acquireEpoch()
  99. if err != nil {
  100. return nil, E.Cause(err, "rebuild http transport")
  101. }
  102. if t.tag != "" {
  103. if transportTag, loaded := transportTagFromContext(request.Context()); loaded && transportTag == t.tag {
  104. t.releaseEpoch(epoch)
  105. return nil, E.New("HTTP request loopback in transport[", t.tag, "]")
  106. }
  107. request = request.Clone(contextWithTransportTag(request.Context(), t.tag))
  108. } else if len(t.headers) > 0 || t.host != "" {
  109. request = request.Clone(request.Context())
  110. }
  111. applyHeaders(request, t.headers, t.host)
  112. response, roundTripErr := epoch.transport.RoundTrip(request)
  113. if roundTripErr != nil || response == nil || response.Body == nil {
  114. t.releaseEpoch(epoch)
  115. return response, roundTripErr
  116. }
  117. response.Body = &managedResponseBody{
  118. body: response.Body,
  119. release: func() { t.releaseEpoch(epoch) },
  120. }
  121. return response, roundTripErr
  122. }
  123. func (t *ManagedTransport) CloseIdleConnections() {
  124. oldEpoch := t.epoch.Swap(nil)
  125. if oldEpoch == nil {
  126. return
  127. }
  128. oldEpoch.transport.CloseIdleConnections()
  129. t.retireEpoch(oldEpoch)
  130. }
  131. func (t *ManagedTransport) Reset() {
  132. oldEpoch := t.epoch.Swap(nil)
  133. if t.cheapRebuild {
  134. t.rebuildAccess.Lock()
  135. if t.epoch.Load() == nil {
  136. inner, err := t.factory()
  137. if err == nil {
  138. t.epoch.Store(&transportEpoch{transport: inner})
  139. }
  140. }
  141. t.rebuildAccess.Unlock()
  142. }
  143. t.retireEpoch(oldEpoch)
  144. }
  145. func (t *ManagedTransport) close() error {
  146. epoch := t.epoch.Swap(nil)
  147. if epoch != nil {
  148. return epoch.transport.Close()
  149. }
  150. return nil
  151. }
  152. var _ adapter.HTTPTransport = (*sharedRef)(nil)
  153. type sharedRef struct {
  154. managed *ManagedTransport
  155. shared *sharedState
  156. idle atomic.Bool
  157. }
  158. type sharedState struct {
  159. activeRefs atomic.Int32
  160. }
  161. func newSharedRef(managed *ManagedTransport, shared *sharedState) *sharedRef {
  162. shared.activeRefs.Add(1)
  163. return &sharedRef{
  164. managed: managed,
  165. shared: shared,
  166. }
  167. }
  168. func (r *sharedRef) RoundTrip(request *http.Request) (*http.Response, error) {
  169. if r.idle.CompareAndSwap(true, false) {
  170. r.shared.activeRefs.Add(1)
  171. }
  172. return r.managed.RoundTrip(request)
  173. }
  174. func (r *sharedRef) CloseIdleConnections() {
  175. if r.idle.CompareAndSwap(false, true) {
  176. if r.shared.activeRefs.Add(-1) == 0 {
  177. r.managed.CloseIdleConnections()
  178. }
  179. }
  180. }
  181. func (r *sharedRef) Reset() {
  182. r.managed.Reset()
  183. }