| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- package httpclient
- import (
- "io"
- "net/http"
- "sync"
- "sync/atomic"
- "github.com/sagernet/sing-box/adapter"
- E "github.com/sagernet/sing/common/exceptions"
- N "github.com/sagernet/sing/common/network"
- )
// innerTransport is the transport a ManagedTransport delegates requests to.
// Beyond http.RoundTripper it must be able to drop its idle connections and
// to shut down completely; Close is called when the owning epoch is torn down.
type innerTransport interface {
	Close() error
	CloseIdleConnections()
	http.RoundTripper
}
- var _ adapter.HTTPTransport = (*ManagedTransport)(nil)
- type ManagedTransport struct {
- epoch atomic.Pointer[transportEpoch]
- rebuildAccess sync.Mutex
- factory func() (innerTransport, error)
- cheapRebuild bool
- dialer N.Dialer
- headers http.Header
- host string
- tag string
- }
- type transportEpoch struct {
- transport innerTransport
- active atomic.Int64
- marked atomic.Bool
- closeOnce sync.Once
- }
// managedResponseBody wraps a response body so that the epoch that produced
// the response stays alive until the caller closes the body.
type managedResponseBody struct {
	// body is the original response body being wrapped.
	body io.ReadCloser

	// once ensures release runs only on the first Close.
	once sync.Once

	// release drops this body's hold on its epoch.
	release func()
}
- func (e *transportEpoch) tryClose() {
- e.closeOnce.Do(func() {
- e.transport.Close()
- })
- }
- func (b *managedResponseBody) Read(p []byte) (int, error) {
- return b.body.Read(p)
- }
- func (b *managedResponseBody) Close() error {
- err := b.body.Close()
- b.once.Do(b.release)
- return err
- }
- func (t *ManagedTransport) getEpoch() (*transportEpoch, error) {
- epoch := t.epoch.Load()
- if epoch != nil {
- return epoch, nil
- }
- t.rebuildAccess.Lock()
- defer t.rebuildAccess.Unlock()
- epoch = t.epoch.Load()
- if epoch != nil {
- return epoch, nil
- }
- inner, err := t.factory()
- if err != nil {
- return nil, err
- }
- epoch = &transportEpoch{transport: inner}
- t.epoch.Store(epoch)
- return epoch, nil
- }
- func (t *ManagedTransport) acquireEpoch() (*transportEpoch, error) {
- for {
- epoch, err := t.getEpoch()
- if err != nil {
- return nil, err
- }
- epoch.active.Add(1)
- if epoch == t.epoch.Load() {
- return epoch, nil
- }
- t.releaseEpoch(epoch)
- }
- }
- func (t *ManagedTransport) releaseEpoch(epoch *transportEpoch) {
- if epoch.active.Add(-1) == 0 && epoch.marked.Load() {
- epoch.tryClose()
- }
- }
- func (t *ManagedTransport) retireEpoch(epoch *transportEpoch) {
- if epoch == nil {
- return
- }
- epoch.marked.Store(true)
- if epoch.active.Load() == 0 {
- epoch.tryClose()
- }
- }
- func (t *ManagedTransport) RoundTrip(request *http.Request) (*http.Response, error) {
- epoch, err := t.acquireEpoch()
- if err != nil {
- return nil, E.Cause(err, "rebuild http transport")
- }
- if t.tag != "" {
- if transportTag, loaded := transportTagFromContext(request.Context()); loaded && transportTag == t.tag {
- t.releaseEpoch(epoch)
- return nil, E.New("HTTP request loopback in transport[", t.tag, "]")
- }
- request = request.Clone(contextWithTransportTag(request.Context(), t.tag))
- } else if len(t.headers) > 0 || t.host != "" {
- request = request.Clone(request.Context())
- }
- applyHeaders(request, t.headers, t.host)
- response, roundTripErr := epoch.transport.RoundTrip(request)
- if roundTripErr != nil || response == nil || response.Body == nil {
- t.releaseEpoch(epoch)
- return response, roundTripErr
- }
- response.Body = &managedResponseBody{
- body: response.Body,
- release: func() { t.releaseEpoch(epoch) },
- }
- return response, roundTripErr
- }
- func (t *ManagedTransport) CloseIdleConnections() {
- oldEpoch := t.epoch.Swap(nil)
- if oldEpoch == nil {
- return
- }
- oldEpoch.transport.CloseIdleConnections()
- t.retireEpoch(oldEpoch)
- }
- func (t *ManagedTransport) Reset() {
- oldEpoch := t.epoch.Swap(nil)
- if t.cheapRebuild {
- t.rebuildAccess.Lock()
- if t.epoch.Load() == nil {
- inner, err := t.factory()
- if err == nil {
- t.epoch.Store(&transportEpoch{transport: inner})
- }
- }
- t.rebuildAccess.Unlock()
- }
- t.retireEpoch(oldEpoch)
- }
- func (t *ManagedTransport) close() error {
- epoch := t.epoch.Swap(nil)
- if epoch != nil {
- return epoch.transport.Close()
- }
- return nil
- }
- var _ adapter.HTTPTransport = (*sharedRef)(nil)
- type sharedRef struct {
- managed *ManagedTransport
- shared *sharedState
- idle atomic.Bool
- }
- type sharedState struct {
- activeRefs atomic.Int32
- }
- func newSharedRef(managed *ManagedTransport, shared *sharedState) *sharedRef {
- shared.activeRefs.Add(1)
- return &sharedRef{
- managed: managed,
- shared: shared,
- }
- }
- func (r *sharedRef) RoundTrip(request *http.Request) (*http.Response, error) {
- if r.idle.CompareAndSwap(true, false) {
- r.shared.activeRefs.Add(1)
- }
- return r.managed.RoundTrip(request)
- }
- func (r *sharedRef) CloseIdleConnections() {
- if r.idle.CompareAndSwap(false, true) {
- if r.shared.activeRefs.Add(-1) == 0 {
- r.managed.CloseIdleConnections()
- }
- }
- }
- func (r *sharedRef) Reset() {
- r.managed.Reset()
- }
|