Pārlūkot izejas kodu

Fix DNS transports

世界 1 dienu atpakaļ
vecāks
revīzija
3bbb23ea3b

+ 1 - 0
adapter/dns.go

@@ -68,6 +68,7 @@ type DNSTransport interface {
 	Type() string
 	Tag() string
 	Dependencies() []string
+	Reset()
 	Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
 }
 

+ 1 - 1
dns/router.go

@@ -444,6 +444,6 @@ func (r *Router) LookupReverseMapping(ip netip.Addr) (string, bool) {
 func (r *Router) ResetNetwork() {
 	r.ClearCache()
 	for _, transport := range r.transport.Transports() {
-		transport.Close()
+		transport.Reset()
 	}
 }

+ 145 - 0
dns/transport/base.go

@@ -0,0 +1,145 @@
+package transport
+
+import (
+	"context"
+	"os"
+	"sync"
+
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/dns"
+	E "github.com/sagernet/sing/common/exceptions"
+	"github.com/sagernet/sing/common/logger"
+)
+
+type TransportState int
+
+const (
+	StateNew TransportState = iota
+	StateStarted
+	StateClosing
+	StateClosed
+)
+
+var (
+	ErrTransportClosed = os.ErrClosed
+	ErrConnectionReset = E.New("connection reset")
+)
+
+type BaseTransport struct {
+	dns.TransportAdapter
+	Logger logger.ContextLogger
+
+	mutex           sync.Mutex
+	state           TransportState
+	inFlight        int32
+	queriesComplete chan struct{}
+	closeCtx        context.Context
+	closeCancel     context.CancelFunc
+}
+
+func NewBaseTransport(adapter dns.TransportAdapter, logger logger.ContextLogger) *BaseTransport {
+	ctx, cancel := context.WithCancel(context.Background())
+	return &BaseTransport{
+		TransportAdapter: adapter,
+		Logger:           logger,
+		state:            StateNew,
+		closeCtx:         ctx,
+		closeCancel:      cancel,
+	}
+}
+
+func (t *BaseTransport) State() TransportState {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	return t.state
+}
+
+func (t *BaseTransport) SetStarted() error {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	switch t.state {
+	case StateNew:
+		t.state = StateStarted
+		return nil
+	case StateStarted:
+		return nil
+	default:
+		return ErrTransportClosed
+	}
+}
+
+func (t *BaseTransport) BeginQuery() bool {
+	t.mutex.Lock()
+	defer t.mutex.Unlock()
+	if t.state != StateStarted {
+		return false
+	}
+	t.inFlight++
+	return true
+}
+
+func (t *BaseTransport) EndQuery() {
+	t.mutex.Lock()
+	if t.inFlight > 0 {
+		t.inFlight--
+	}
+	if t.inFlight == 0 && t.queriesComplete != nil {
+		close(t.queriesComplete)
+		t.queriesComplete = nil
+	}
+	t.mutex.Unlock()
+}
+
+func (t *BaseTransport) CloseContext() context.Context {
+	return t.closeCtx
+}
+
+func (t *BaseTransport) Shutdown(ctx context.Context) error {
+	t.mutex.Lock()
+
+	if t.state >= StateClosing {
+		t.mutex.Unlock()
+		return nil
+	}
+
+	if t.state == StateNew {
+		t.state = StateClosed
+		t.mutex.Unlock()
+		t.closeCancel()
+		return nil
+	}
+
+	t.state = StateClosing
+
+	if t.inFlight == 0 {
+		t.state = StateClosed
+		t.mutex.Unlock()
+		t.closeCancel()
+		return nil
+	}
+
+	t.queriesComplete = make(chan struct{})
+	queriesComplete := t.queriesComplete
+	t.mutex.Unlock()
+
+	t.closeCancel()
+
+	select {
+	case <-queriesComplete:
+		t.mutex.Lock()
+		t.state = StateClosed
+		t.mutex.Unlock()
+		return nil
+	case <-ctx.Done():
+		t.mutex.Lock()
+		t.state = StateClosed
+		t.mutex.Unlock()
+		return ctx.Err()
+	}
+}
+
+func (t *BaseTransport) Close() error {
+	ctx, cancel := context.WithTimeout(context.Background(), C.TCPTimeout)
+	defer cancel()
+	return t.Shutdown(ctx)
+}

+ 205 - 0
dns/transport/connector.go

@@ -0,0 +1,205 @@
+package transport
+
+import (
+	"context"
+	"net"
+	"sync"
+)
+
+type ConnectorCallbacks[T any] struct {
+	IsClosed func(connection T) bool
+	Close    func(connection T)
+	Reset    func(connection T)
+}
+
+type Connector[T any] struct {
+	dial      func(ctx context.Context) (T, error)
+	callbacks ConnectorCallbacks[T]
+
+	access        sync.Mutex
+	connection    T
+	hasConnection bool
+	connecting    chan struct{}
+
+	closeCtx context.Context
+	closed   bool
+}
+
+func NewConnector[T any](closeCtx context.Context, dial func(context.Context) (T, error), callbacks ConnectorCallbacks[T]) *Connector[T] {
+	return &Connector[T]{
+		dial:      dial,
+		callbacks: callbacks,
+		closeCtx:  closeCtx,
+	}
+}
+
+func NewSingleflightConnector(closeCtx context.Context, dial func(context.Context) (*Connection, error)) *Connector[*Connection] {
+	return NewConnector(closeCtx, dial, ConnectorCallbacks[*Connection]{
+		IsClosed: func(connection *Connection) bool {
+			return connection.IsClosed()
+		},
+		Close: func(connection *Connection) {
+			connection.CloseWithError(ErrTransportClosed)
+		},
+		Reset: func(connection *Connection) {
+			connection.CloseWithError(ErrConnectionReset)
+		},
+	})
+}
+
+func (c *Connector[T]) Get(ctx context.Context) (T, error) {
+	var zero T
+	for {
+		c.access.Lock()
+
+		if c.closed {
+			c.access.Unlock()
+			return zero, ErrTransportClosed
+		}
+
+		if c.hasConnection && !c.callbacks.IsClosed(c.connection) {
+			connection := c.connection
+			c.access.Unlock()
+			return connection, nil
+		}
+
+		c.hasConnection = false
+
+		if c.connecting != nil {
+			connecting := c.connecting
+			c.access.Unlock()
+
+			select {
+			case <-connecting:
+				continue
+			case <-ctx.Done():
+				return zero, ctx.Err()
+			case <-c.closeCtx.Done():
+				return zero, ErrTransportClosed
+			}
+		}
+
+		c.connecting = make(chan struct{})
+		c.access.Unlock()
+
+		connection, err := c.dialWithCancellation(ctx)
+
+		c.access.Lock()
+		close(c.connecting)
+		c.connecting = nil
+
+		if err != nil {
+			c.access.Unlock()
+			return zero, err
+		}
+
+		if c.closed {
+			c.callbacks.Close(connection)
+			c.access.Unlock()
+			return zero, ErrTransportClosed
+		}
+
+		c.connection = connection
+		c.hasConnection = true
+		result := c.connection
+		c.access.Unlock()
+
+		return result, nil
+	}
+}
+
+func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, error) {
+	dialCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	go func() {
+		select {
+		case <-c.closeCtx.Done():
+			cancel()
+		case <-dialCtx.Done():
+		}
+	}()
+
+	return c.dial(dialCtx)
+}
+
+func (c *Connector[T]) Close() error {
+	c.access.Lock()
+	defer c.access.Unlock()
+
+	if c.closed {
+		return nil
+	}
+	c.closed = true
+
+	if c.hasConnection {
+		c.callbacks.Close(c.connection)
+		c.hasConnection = false
+	}
+
+	return nil
+}
+
+func (c *Connector[T]) Reset() {
+	c.access.Lock()
+	defer c.access.Unlock()
+
+	if c.hasConnection {
+		c.callbacks.Reset(c.connection)
+		c.hasConnection = false
+	}
+}
+
+type Connection struct {
+	net.Conn
+
+	closeOnce  sync.Once
+	done       chan struct{}
+	closeError error
+}
+
+func WrapConnection(conn net.Conn) *Connection {
+	return &Connection{
+		Conn: conn,
+		done: make(chan struct{}),
+	}
+}
+
+func (c *Connection) Done() <-chan struct{} {
+	return c.done
+}
+
+func (c *Connection) IsClosed() bool {
+	select {
+	case <-c.done:
+		return true
+	default:
+		return false
+	}
+}
+
+func (c *Connection) CloseError() error {
+	select {
+	case <-c.done:
+		if c.closeError != nil {
+			return c.closeError
+		}
+		return ErrTransportClosed
+	default:
+		return nil
+	}
+}
+
+func (c *Connection) Close() error {
+	return c.CloseWithError(ErrTransportClosed)
+}
+
+func (c *Connection) CloseWithError(err error) error {
+	var returnError error
+	c.closeOnce.Do(func() {
+		c.closeError = err
+		returnError = c.Conn.Close()
+		close(c.done)
+	})
+	return returnError
+}

+ 7 - 0
dns/transport/dhcp/dhcp.go

@@ -108,6 +108,13 @@ func (t *Transport) Close() error {
 	return nil
 }
 
+func (t *Transport) Reset() {
+	t.transportLock.Lock()
+	t.updatedAt = time.Time{}
+	t.servers = nil
+	t.transportLock.Unlock()
+}
+
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
 	servers, err := t.fetch()
 	if err != nil {

+ 4 - 0
dns/transport/fakeip/memory.go

@@ -82,8 +82,12 @@ func (s *MemoryStorage) FakeIPLoadDomain(domain string, isIPv6 bool) (netip.Addr
 }
 
 func (s *MemoryStorage) FakeIPReset() error {
+	s.addressAccess.Lock()
+	s.domainAccess.Lock()
 	s.addressCache = make(map[netip.Addr]string)
 	s.domainCache4 = make(map[string]netip.Addr)
 	s.domainCache6 = make(map[string]netip.Addr)
+	s.domainAccess.Unlock()
+	s.addressAccess.Unlock()
 	return nil
 }

+ 28 - 10
dns/transport/fakeip/store.go

@@ -3,6 +3,7 @@ package fakeip
 import (
 	"context"
 	"net/netip"
+	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
 	E "github.com/sagernet/sing/common/exceptions"
@@ -13,13 +14,15 @@ import (
 var _ adapter.FakeIPStore = (*Store)(nil)
 
 type Store struct {
-	ctx          context.Context
-	logger       logger.Logger
-	inet4Range   netip.Prefix
-	inet6Range   netip.Prefix
-	storage      adapter.FakeIPStorage
-	inet4Current netip.Addr
-	inet6Current netip.Addr
+	ctx        context.Context
+	logger     logger.Logger
+	inet4Range netip.Prefix
+	inet6Range netip.Prefix
+	storage    adapter.FakeIPStorage
+
+	addressAccess sync.Mutex
+	inet4Current  netip.Addr
+	inet6Current  netip.Addr
 }
 
 func NewStore(ctx context.Context, logger logger.Logger, inet4Range netip.Prefix, inet6Range netip.Prefix) *Store {
@@ -65,18 +68,30 @@ func (s *Store) Close() error {
 	if s.storage == nil {
 		return nil
 	}
-	return s.storage.FakeIPSaveMetadata(&adapter.FakeIPMetadata{
+	s.addressAccess.Lock()
+	metadata := &adapter.FakeIPMetadata{
 		Inet4Range:   s.inet4Range,
 		Inet6Range:   s.inet6Range,
 		Inet4Current: s.inet4Current,
 		Inet6Current: s.inet6Current,
-	})
+	}
+	s.addressAccess.Unlock()
+	return s.storage.FakeIPSaveMetadata(metadata)
 }
 
 func (s *Store) Create(domain string, isIPv6 bool) (netip.Addr, error) {
 	if address, loaded := s.storage.FakeIPLoadDomain(domain, isIPv6); loaded {
 		return address, nil
 	}
+
+	s.addressAccess.Lock()
+	defer s.addressAccess.Unlock()
+
+	// Double-check after acquiring lock
+	if address, loaded := s.storage.FakeIPLoadDomain(domain, isIPv6); loaded {
+		return address, nil
+	}
+
 	var address netip.Addr
 	if !isIPv6 {
 		if !s.inet4Current.IsValid() {
@@ -99,7 +114,10 @@ func (s *Store) Create(domain string, isIPv6 bool) (netip.Addr, error) {
 		s.inet6Current = nextAddress
 		address = nextAddress
 	}
-	s.storage.FakeIPStoreAsync(address, domain, s.logger)
+	err := s.storage.FakeIPStore(address, domain)
+	if err != nil {
+		s.logger.Warn("save FakeIP cache: ", err)
+	}
 	s.storage.FakeIPSaveMetadataAsync(&adapter.FakeIPMetadata{
 		Inet4Range:   s.inet4Range,
 		Inet6Range:   s.inet6Range,

+ 3 - 0
dns/transport/hosts/hosts.go

@@ -59,6 +59,9 @@ func (t *Transport) Close() error {
 	return nil
 }
 
+func (t *Transport) Reset() {
+}
+
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
 	question := message.Question[0]
 	domain := mDNS.CanonicalName(question.Name)

+ 12 - 2
dns/transport/https.go

@@ -145,6 +145,13 @@ func (t *HTTPSTransport) Close() error {
 	return nil
 }
 
+func (t *HTTPSTransport) Reset() {
+	t.transportAccess.Lock()
+	defer t.transportAccess.Unlock()
+	t.transport.CloseIdleConnections()
+	t.transport = t.transport.Clone()
+}
+
 func (t *HTTPSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
 	startAt := time.Now()
 	response, err := t.exchange(ctx, message)
@@ -182,7 +189,10 @@ func (t *HTTPSTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS
 	request.Header = t.headers.Clone()
 	request.Header.Set("Content-Type", MimeType)
 	request.Header.Set("Accept", MimeType)
-	response, err := t.transport.RoundTrip(request)
+	t.transportAccess.Lock()
+	currentTransport := t.transport
+	t.transportAccess.Unlock()
+	response, err := currentTransport.RoundTrip(request)
 	requestBuffer.Release()
 	if err != nil {
 		return nil, err
@@ -194,12 +204,12 @@ func (t *HTTPSTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS
 	var responseMessage mDNS.Msg
 	if response.ContentLength > 0 {
 		responseBuffer := buf.NewSize(int(response.ContentLength))
+		defer responseBuffer.Release()
 		_, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength))
 		if err != nil {
 			return nil, err
 		}
 		err = responseMessage.Unpack(responseBuffer.Bytes())
-		responseBuffer.Release()
 	} else {
 		rawMessage, err = io.ReadAll(response.Body)
 		if err != nil {

+ 3 - 0
dns/transport/local/local.go

@@ -76,6 +76,9 @@ func (t *Transport) Close() error {
 	return nil
 }
 
+func (t *Transport) Reset() {
+}
+
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
 	if t.resolved != nil {
 		resolverObject := t.resolved.Object()

+ 6 - 0
dns/transport/local/local_darwin.go

@@ -92,6 +92,12 @@ func (t *Transport) Close() error {
 	)
 }
 
+func (t *Transport) Reset() {
+	if t.dhcpTransport != nil {
+		t.dhcpTransport.Reset()
+	}
+}
+
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
 	question := message.Question[0]
 	if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {

+ 52 - 19
dns/transport/quic/http3.go

@@ -8,10 +8,12 @@ import (
 	"net/http"
 	"net/url"
 	"strconv"
+	"sync"
 
 	"github.com/sagernet/quic-go"
 	"github.com/sagernet/quic-go/http3"
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
 	"github.com/sagernet/sing-box/common/tls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
@@ -23,6 +25,7 @@ import (
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
 	"github.com/sagernet/sing/common/logger"
+	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	sHTTP "github.com/sagernet/sing/protocol/http"
 
@@ -37,11 +40,14 @@ func RegisterHTTP3Transport(registry *dns.TransportRegistry) {
 
 type HTTP3Transport struct {
 	dns.TransportAdapter
-	logger      logger.ContextLogger
-	dialer      N.Dialer
-	destination *url.URL
-	headers     http.Header
-	transport   *http3.Transport
+	logger          logger.ContextLogger
+	dialer          N.Dialer
+	destination     *url.URL
+	headers         http.Header
+	serverAddr      M.Socksaddr
+	tlsConfig       *tls.STDConfig
+	transportAccess sync.Mutex
+	transport       *http3.Transport
 }
 
 func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) {
@@ -95,33 +101,57 @@ func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options
 	if !serverAddr.IsValid() {
 		return nil, E.New("invalid server address: ", serverAddr)
 	}
-	return &HTTP3Transport{
+	t := &HTTP3Transport{
 		TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTP3, tag, options.RemoteDNSServerOptions),
 		logger:           logger,
 		dialer:           transportDialer,
 		destination:      &destinationURL,
 		headers:          headers,
-		transport: &http3.Transport{
-			Dial: func(ctx context.Context, addr string, tlsCfg *tls.STDConfig, cfg *quic.Config) (*quic.Conn, error) {
-				conn, dialErr := transportDialer.DialContext(ctx, N.NetworkUDP, serverAddr)
-				if dialErr != nil {
-					return nil, dialErr
-				}
-				return quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsCfg, cfg)
-			},
-			TLSClientConfig: stdConfig,
+		serverAddr:       serverAddr,
+		tlsConfig:        stdConfig,
+	}
+	t.transport = t.newTransport()
+	return t, nil
+}
+
+func (t *HTTP3Transport) newTransport() *http3.Transport {
+	return &http3.Transport{
+		Dial: func(ctx context.Context, addr string, tlsCfg *tls.STDConfig, cfg *quic.Config) (*quic.Conn, error) {
+			conn, dialErr := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
+			if dialErr != nil {
+				return nil, dialErr
+			}
+			quicConn, dialErr := quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsCfg, cfg)
+			if dialErr != nil {
+				conn.Close()
+				return nil, dialErr
+			}
+			return quicConn, nil
 		},
-	}, nil
+		TLSClientConfig: t.tlsConfig,
+	}
 }
 
 func (t *HTTP3Transport) Start(stage adapter.StartStage) error {
-	return nil
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	return dialer.InitializeDetour(t.dialer)
 }
 
 func (t *HTTP3Transport) Close() error {
+	t.transportAccess.Lock()
+	defer t.transportAccess.Unlock()
 	return t.transport.Close()
 }
 
+func (t *HTTP3Transport) Reset() {
+	t.transportAccess.Lock()
+	defer t.transportAccess.Unlock()
+	t.transport.Close()
+	t.transport = t.newTransport()
+}
+
 func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
 	exMessage := *message
 	exMessage.Id = 0
@@ -140,7 +170,10 @@ func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS
 	request.Header = t.headers.Clone()
 	request.Header.Set("Content-Type", transport.MimeType)
 	request.Header.Set("Accept", transport.MimeType)
-	response, err := t.transport.RoundTrip(request)
+	t.transportAccess.Lock()
+	currentTransport := t.transport
+	t.transportAccess.Unlock()
+	response, err := currentTransport.RoundTrip(request)
 	requestBuffer.Release()
 	if err != nil {
 		return nil, err
@@ -152,12 +185,12 @@ func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS
 	var responseMessage mDNS.Msg
 	if response.ContentLength > 0 {
 		responseBuffer := buf.NewSize(int(response.ContentLength))
+		defer responseBuffer.Release()
 		_, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength))
 		if err != nil {
 			return nil, err
 		}
 		err = responseMessage.Unpack(responseBuffer.Bytes())
-		responseBuffer.Release()
 	} else {
 		rawMessage, err = io.ReadAll(response.Body)
 		if err != nil {

+ 82 - 56
dns/transport/quic/quic.go

@@ -3,10 +3,11 @@ package quic
 import (
 	"context"
 	"errors"
-	"sync"
+	"os"
 
 	"github.com/sagernet/quic-go"
 	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/common/dialer"
 	"github.com/sagernet/sing-box/common/tls"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
@@ -17,7 +18,6 @@ import (
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
-	"github.com/sagernet/sing/common/logger"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 
@@ -31,14 +31,14 @@ func RegisterTransport(registry *dns.TransportRegistry) {
 }
 
 type Transport struct {
-	dns.TransportAdapter
+	*transport.BaseTransport
+
 	ctx        context.Context
-	logger     logger.ContextLogger
 	dialer     N.Dialer
 	serverAddr M.Socksaddr
 	tlsConfig  tls.Config
-	access     sync.Mutex
-	connection *quic.Conn
+
+	connector *transport.Connector[*quic.Conn]
 }
 
 func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
@@ -62,38 +62,84 @@ func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options
 	if !serverAddr.IsValid() {
 		return nil, E.New("invalid server address: ", serverAddr)
 	}
-	return &Transport{
-		TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions),
-		ctx:              ctx,
-		logger:           logger,
-		dialer:           transportDialer,
-		serverAddr:       serverAddr,
-		tlsConfig:        tlsConfig,
-	}, nil
+
+	t := &Transport{
+		BaseTransport: transport.NewBaseTransport(
+			dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions),
+			logger,
+		),
+		ctx:        ctx,
+		dialer:     transportDialer,
+		serverAddr: serverAddr,
+		tlsConfig:  tlsConfig,
+	}
+
+	t.connector = transport.NewConnector(t.CloseContext(), t.dial, transport.ConnectorCallbacks[*quic.Conn]{
+		IsClosed: func(connection *quic.Conn) bool {
+			return common.Done(connection.Context())
+		},
+		Close: func(connection *quic.Conn) {
+			connection.CloseWithError(0, "")
+		},
+		Reset: func(connection *quic.Conn) {
+			connection.CloseWithError(0, "")
+		},
+	})
+
+	return t, nil
+}
+
+func (t *Transport) dial(ctx context.Context) (*quic.Conn, error) {
+	conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
+	if err != nil {
+		return nil, E.Cause(err, "dial UDP connection")
+	}
+	earlyConnection, err := sQUIC.DialEarly(
+		ctx,
+		bufio.NewUnbindPacketConn(conn),
+		t.serverAddr.UDPAddr(),
+		t.tlsConfig,
+		nil,
+	)
+	if err != nil {
+		conn.Close()
+		return nil, E.Cause(err, "establish QUIC connection")
+	}
+	return earlyConnection, nil
 }
 
 func (t *Transport) Start(stage adapter.StartStage) error {
-	return nil
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	err := t.SetStarted()
+	if err != nil {
+		return err
+	}
+	return dialer.InitializeDetour(t.dialer)
 }
 
 func (t *Transport) Close() error {
-	t.access.Lock()
-	defer t.access.Unlock()
-	connection := t.connection
-	if connection != nil {
-		connection.CloseWithError(0, "")
-	}
-	return nil
+	return E.Errors(t.BaseTransport.Close(), t.connector.Close())
+}
+
+func (t *Transport) Reset() {
+	t.connector.Reset()
 }
 
 func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
+	if !t.BeginQuery() {
+		return nil, transport.ErrTransportClosed
+	}
+	defer t.EndQuery()
+
 	var (
 		conn     *quic.Conn
 		err      error
 		response *mDNS.Msg
 	)
 	for i := 0; i < 2; i++ {
-		conn, err = t.openConnection()
+		conn, err = t.connector.Get(ctx)
 		if err != nil {
 			return nil, err
 		}
@@ -103,58 +149,38 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg,
 		} else if !isQUICRetryError(err) {
 			return nil, err
 		} else {
-			conn.CloseWithError(quic.ApplicationErrorCode(0), "")
+			t.connector.Reset()
 			continue
 		}
 	}
 	return nil, err
 }
 
-func (t *Transport) openConnection() (*quic.Conn, error) {
-	connection := t.connection
-	if connection != nil && !common.Done(connection.Context()) {
-		return connection, nil
-	}
-	t.access.Lock()
-	defer t.access.Unlock()
-	connection = t.connection
-	if connection != nil && !common.Done(connection.Context()) {
-		return connection, nil
-	}
-	conn, err := t.dialer.DialContext(t.ctx, N.NetworkUDP, t.serverAddr)
-	if err != nil {
-		return nil, err
-	}
-	earlyConnection, err := sQUIC.DialEarly(
-		t.ctx,
-		bufio.NewUnbindPacketConn(conn),
-		t.serverAddr.UDPAddr(),
-		t.tlsConfig,
-		nil,
-	)
-	if err != nil {
-		return nil, err
-	}
-	t.connection = earlyConnection
-	return earlyConnection, nil
-}
-
 func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, conn *quic.Conn) (*mDNS.Msg, error) {
 	stream, err := conn.OpenStreamSync(ctx)
 	if err != nil {
-		return nil, err
+		return nil, E.Cause(err, "open stream")
 	}
+	defer stream.CancelRead(0)
 	err = transport.WriteMessage(stream, 0, message)
 	if err != nil {
 		stream.Close()
-		return nil, err
+		return nil, E.Cause(err, "write request")
 	}
 	stream.Close()
-	return transport.ReadMessage(stream)
+	response, err := transport.ReadMessage(stream)
+	if err != nil {
+		return nil, E.Cause(err, "read response")
+	}
+	return response, nil
 }
 
 // https://github.com/AdguardTeam/dnsproxy/blob/fd1868577652c639cce3da00e12ca548f421baf1/upstream/upstream_quic.go#L394
 func isQUICRetryError(err error) (ok bool) {
+	if errors.Is(err, os.ErrClosed) {
+		return true
+	}
+
 	var qAppErr *quic.ApplicationError
 	if errors.As(err, &qAppErr) && qAppErr.ErrorCode == 0 {
 		return true

+ 10 - 3
dns/transport/tcp.go

@@ -62,17 +62,24 @@ func (t *TCPTransport) Close() error {
 	return nil
 }
 
+func (t *TCPTransport) Reset() {
+}
+
 func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
 	conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
 	if err != nil {
-		return nil, err
+		return nil, E.Cause(err, "dial TCP connection")
 	}
 	defer conn.Close()
 	err = WriteMessage(conn, 0, message)
 	if err != nil {
-		return nil, err
+		return nil, E.Cause(err, "write request")
+	}
+	response, err := ReadMessage(conn)
+	if err != nil {
+		return nil, E.Cause(err, "read response")
 	}
-	return ReadMessage(conn)
+	return response, nil
 }
 
 func ReadMessage(reader io.Reader) (*mDNS.Msg, error) {

+ 40 - 12
dns/transport/tls.go

@@ -3,6 +3,7 @@ package transport
 import (
 	"context"
 	"sync"
+	"time"
 
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/dialer"
@@ -28,8 +29,8 @@ func RegisterTLS(registry *dns.TransportRegistry) {
 }
 
 type TLSTransport struct {
-	dns.TransportAdapter
-	logger      logger.ContextLogger
+	*BaseTransport
+
 	dialer      tls.Dialer
 	serverAddr  M.Socksaddr
 	tlsConfig   tls.Config
@@ -65,11 +66,10 @@ func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options o
 
 func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport {
 	return &TLSTransport{
-		TransportAdapter: adapter,
-		logger:           logger,
-		dialer:           tls.NewDialer(dialer, tlsConfig),
-		serverAddr:       serverAddr,
-		tlsConfig:        tlsConfig,
+		BaseTransport: NewBaseTransport(adapter, logger),
+		dialer:        tls.NewDialer(dialer, tlsConfig),
+		serverAddr:    serverAddr,
+		tlsConfig:     tlsConfig,
 	}
 }
 
@@ -77,37 +77,59 @@ func (t *TLSTransport) Start(stage adapter.StartStage) error {
 	if stage != adapter.StartStateStart {
 		return nil
 	}
+	err := t.SetStarted()
+	if err != nil {
+		return err
+	}
 	return dialer.InitializeDetour(t.dialer)
 }
 
 func (t *TLSTransport) Close() error {
+	t.access.Lock()
+	for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
+		connection.Value.Close()
+	}
+	t.connections.Init()
+	t.access.Unlock()
+	return t.BaseTransport.Close()
+}
+
+func (t *TLSTransport) Reset() {
 	t.access.Lock()
 	defer t.access.Unlock()
 	for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
 		connection.Value.Close()
 	}
 	t.connections.Init()
-	return nil
 }
 
 func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
+	if !t.BeginQuery() {
+		return nil, ErrTransportClosed
+	}
+	defer t.EndQuery()
+
 	t.access.Lock()
 	conn := t.connections.PopFront()
 	t.access.Unlock()
 	if conn != nil {
-		response, err := t.exchange(message, conn)
+		response, err := t.exchange(ctx, message, conn)
 		if err == nil {
 			return response, nil
 		}
+		t.Logger.DebugContext(ctx, "discarded pooled connection: ", err)
 	}
 	tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr)
 	if err != nil {
-		return nil, err
+		return nil, E.Cause(err, "dial TLS connection")
 	}
-	return t.exchange(message, &tlsDNSConn{Conn: tlsConn})
+	return t.exchange(ctx, message, &tlsDNSConn{Conn: tlsConn})
 }
 
-func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
+func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
+	if deadline, ok := ctx.Deadline(); ok {
+		conn.SetDeadline(deadline)
+	}
 	conn.queryId++
 	err := WriteMessage(conn, conn.queryId, message)
 	if err != nil {
@@ -120,6 +142,12 @@ func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg,
 		return nil, E.Cause(err, "read response")
 	}
 	t.access.Lock()
+	if t.State() >= StateClosing {
+		t.access.Unlock()
+		conn.Close()
+		return response, nil
+	}
+	conn.SetDeadline(time.Time{})
 	t.connections.PushBack(conn)
 	t.access.Unlock()
 	return response, nil

+ 142 - 117
dns/transport/udp.go

@@ -2,9 +2,8 @@ package transport
 
 import (
 	"context"
-	"net"
-	"os"
 	"sync"
+	"sync/atomic"
 
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/dialer"
@@ -28,15 +27,23 @@ func RegisterUDP(registry *dns.TransportRegistry) {
 }
 
 type UDPTransport struct {
-	dns.TransportAdapter
-	logger       logger.ContextLogger
-	dialer       N.Dialer
-	serverAddr   M.Socksaddr
-	udpSize      int
-	tcpTransport *TCPTransport
-	access       sync.Mutex
-	conn         *dnsConnection
-	done         chan struct{}
+	*BaseTransport
+
+	dialer     N.Dialer
+	serverAddr M.Socksaddr
+	udpSize    atomic.Int32
+
+	connector *Connector[*Connection]
+
+	callbackAccess sync.RWMutex
+	queryId        uint16
+	callbacks      map[uint16]*udpCallback
+}
+
+type udpCallback struct {
+	access   sync.Mutex
+	response *mDNS.Msg
+	done     chan struct{}
 }
 
 func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
@@ -54,180 +61,198 @@ func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options o
 	return NewUDPRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeUDP, tag, options), transportDialer, serverAddr), nil
 }
 
-func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr) *UDPTransport {
-	return &UDPTransport{
-		TransportAdapter: adapter,
-		logger:           logger,
-		dialer:           dialer,
-		serverAddr:       serverAddr,
-		udpSize:          2048,
-		tcpTransport: &TCPTransport{
-			dialer:     dialer,
-			serverAddr: serverAddr,
-		},
-		done: make(chan struct{}),
+// NewUDPRaw builds a UDPTransport that reaches serverAddr through
+// dialerInstance. The receive buffer size starts at 2048, the callback table
+// starts empty, and the connector is created from the transport's
+// CloseContext with t.dial as its dial function.
+func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialerInstance N.Dialer, serverAddr M.Socksaddr) *UDPTransport {
+	t := &UDPTransport{
+		BaseTransport: NewBaseTransport(adapter, logger),
+		dialer:        dialerInstance,
+		serverAddr:    serverAddr,
+		callbacks:     make(map[uint16]*udpCallback),
+	}
+	t.udpSize.Store(2048)
+	// The connector is wired up after construction because it needs the
+	// transport's own close context and dial method.
+	t.connector = NewSingleflightConnector(t.CloseContext(), t.dial)
+	return t
+}
+
+// dial opens a UDP connection to serverAddr, wraps it with WrapConnection and
+// starts a recvLoop goroutine for it. A dial failure is returned wrapped as
+// "dial UDP connection". The goroutine exits when a read on the connection
+// fails (see recvLoop).
+func (t *UDPTransport) dial(ctx context.Context) (*Connection, error) {
+	rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
+	if err != nil {
+		return nil, E.Cause(err, "dial UDP connection")
 	}
+	conn := WrapConnection(rawConn)
+	go t.recvLoop(conn)
+	return conn, nil
 }
 
+// Start does nothing for any stage other than adapter.StartStateStart. At that
+// stage it marks the base transport as started via SetStarted, returning its
+// error if any, and then initializes the dialer's detour.
 func (t *UDPTransport) Start(stage adapter.StartStage) error {
 	if stage != adapter.StartStateStart {
 		return nil
 	}
+	err := t.SetStarted()
+	if err != nil {
+		return err
+	}
 	return dialer.InitializeDetour(t.dialer)
 }
 
+// Close closes the base transport first and then the connector, and returns
+// the result of passing both errors to E.Errors.
+// NOTE(review): presumably E.Errors yields nil when both are nil; confirm.
 func (t *UDPTransport) Close() error {
-	t.access.Lock()
-	defer t.access.Unlock()
-	close(t.done)
-	t.done = make(chan struct{})
-	return nil
+	return E.Errors(t.BaseTransport.Close(), t.connector.Close())
+}
+
+// Reset delegates to the connector's Reset and does not call
+// BaseTransport.Close.
+// NOTE(review): presumably Connector.Reset drops the current connection so
+// the next Get redials; confirm against the Connector implementation.
+func (t *UDPTransport) Reset() {
+	t.connector.Reset()
+}
+
+// nextAvailableQueryId advances t.queryId to the next ID that has no entry in
+// t.callbacks and returns it. The search wraps around the uint16 space; if it
+// arrives back at the starting ID without finding a free one, it returns an
+// error. The caller must hold callbackAccess for writing.
+func (t *UDPTransport) nextAvailableQueryId() (uint16, error) {
+	origin := t.queryId
+	for candidate := origin + 1; ; candidate++ {
+		t.queryId = candidate
+		_, taken := t.callbacks[candidate]
+		switch {
+		case !taken:
+			return candidate, nil
+		case candidate == origin:
+			// Every ID, including the starting one, is in use.
+			return 0, E.New("no available query ID")
+		}
+	}
 }
 
+// Exchange sends message to the server over UDP and returns the response.
+// The call is bracketed by BeginQuery/EndQuery; when BeginQuery refuses, it
+// returns ErrTransportClosed without sending anything. If the UDP response
+// has the Truncated flag set, the query is retried over TCP and that result
+// is returned instead.
 func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
+	if !t.BeginQuery() {
+		return nil, ErrTransportClosed
+	}
+	defer t.EndQuery()
+
 	response, err := t.exchange(ctx, message)
 	if err != nil {
 		return nil, err
 	}
 	if response.Truncated {
-		t.logger.InfoContext(ctx, "response truncated, retrying with TCP")
-		return t.tcpTransport.Exchange(ctx, message)
+		t.Logger.InfoContext(ctx, "response truncated, retrying with TCP")
+		return t.exchangeTCP(ctx, message)
+	}
+	return response, nil
+}
+
+// exchangeTCP performs one query over a freshly dialed TCP connection to
+// serverAddr and closes the connection before returning. It is the fallback
+// used by Exchange when the UDP response was truncated.
+//
+// The context deadline, if any, is applied to the connection so that the
+// write and read cannot outlive ctx; DialContext only bounds the dial itself.
+// Dial, write and read failures are returned wrapped with the step that
+// failed.
+func (t *UDPTransport) exchangeTCP(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
+	conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
+	if err != nil {
+		return nil, E.Cause(err, "dial TCP connection")
+	}
+	defer conn.Close()
+	if deadline, ok := ctx.Deadline(); ok {
+		// Best effort, matching the TLS transport: a connection that cannot
+		// take a deadline should not fail the query, so the error is ignored.
+		_ = conn.SetDeadline(deadline)
+	}
+	err = WriteMessage(conn, message.Id, message)
+	if err != nil {
+		return nil, E.Cause(err, "write request")
+	}
+	response, err := ReadMessage(conn)
+	if err != nil {
+		return nil, E.Cause(err, "read response")
 	}
 	return response, nil
 }
 
+// exchange performs a single UDP round trip: it obtains a connection from the
+// connector, registers a callback under a freshly allocated query ID, writes
+// the query with that ID, and waits for recvLoop to deliver the response.
+// The response is returned with the caller's original message ID restored.
+// It returns early with an error if the connection is closed, the transport's
+// close context is done, or ctx is done.
 func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
-	t.access.Lock()
 	if edns0Opt := message.IsEdns0(); edns0Opt != nil {
-		if udpSize := int(edns0Opt.UDPSize()); udpSize > t.udpSize {
-			t.udpSize = udpSize
-			close(t.done)
-			t.done = make(chan struct{})
+		// Raise udpSize to the advertised EDNS0 size if it is larger. The CAS
+		// loop retries when another goroutine changed the value first, and
+		// only the goroutine whose swap succeeds resets the connector.
+		udpSize := int32(edns0Opt.UDPSize())
+		for {
+			current := t.udpSize.Load()
+			if udpSize <= current {
+				break
+			}
+			if t.udpSize.CompareAndSwap(current, udpSize) {
+				t.connector.Reset()
+				break
+			}
 		}
 	}
-	t.access.Unlock()
-	conn, err := t.open(ctx)
+
+	conn, err := t.connector.Get(ctx)
 	if err != nil {
 		return nil, err
 	}
-	buffer := buf.NewSize(1 + message.Len())
-	defer buffer.Release()
-	exMessage := *message
-	exMessage.Compress = true
-	messageId := message.Id
-	callback := &dnsCallback{
+
+	callback := &udpCallback{
 		done: make(chan struct{}),
 	}
-	conn.access.Lock()
-	conn.queryId++
-	exMessage.Id = conn.queryId
-	conn.callbacks[exMessage.Id] = callback
-	conn.access.Unlock()
+
+	// Register the callback before writing so that recvLoop can find it as
+	// soon as a response arrives.
+	t.callbackAccess.Lock()
+	queryId, err := t.nextAvailableQueryId()
+	if err != nil {
+		t.callbackAccess.Unlock()
+		return nil, err
+	}
+	t.callbacks[queryId] = callback
+	t.callbackAccess.Unlock()
+
+	// Free the query ID on every return path.
 	defer func() {
-		conn.access.Lock()
-		delete(conn.callbacks, exMessage.Id)
-		conn.access.Unlock()
+		t.callbackAccess.Lock()
+		delete(t.callbacks, queryId)
+		t.callbackAccess.Unlock()
 	}()
+
+	buffer := buf.NewSize(1 + message.Len())
+	defer buffer.Release()
+
+	// Send a shallow copy carrying the allocated ID; the caller's message is
+	// left untouched and its ID is put back on the response below.
+	exMessage := *message
+	exMessage.Compress = true
+	originalId := message.Id
+	exMessage.Id = queryId
+
 	rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes())
 	if err != nil {
 		return nil, err
 	}
+
 	_, err = conn.Write(rawMessage)
 	if err != nil {
-		conn.Close(err)
-		return nil, err
+		// A failed write closes the connection with that error.
+		conn.CloseWithError(err)
+		return nil, E.Cause(err, "write request")
 	}
+
 	select {
 	case <-callback.done:
-		callback.message.Id = messageId
-		return callback.message, nil
-	case <-conn.done:
-		return nil, conn.err
-	case <-t.done:
-		return nil, os.ErrClosed
+		// recvLoop stored the response before closing done.
+		callback.response.Id = originalId
+		return callback.response, nil
+	case <-conn.Done():
+		return nil, conn.CloseError()
+	case <-t.CloseContext().Done():
+		return nil, ErrTransportClosed
 	case <-ctx.Done():
-		conn.Close(ctx.Err())
 		return nil, ctx.Err()
 	}
 }
 
-func (t *UDPTransport) open(ctx context.Context) (*dnsConnection, error) {
-	t.access.Lock()
-	defer t.access.Unlock()
-	if t.conn != nil {
-		select {
-		case <-t.conn.done:
-		default:
-			return t.conn, nil
-		}
-	}
-	conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
-	if err != nil {
-		return nil, err
-	}
-	dnsConn := &dnsConnection{
-		Conn:      conn,
-		done:      make(chan struct{}),
-		callbacks: make(map[uint16]*dnsCallback),
-	}
-	go t.recvLoop(dnsConn)
-	t.conn = dnsConn
-	return dnsConn, nil
-}
-
-func (t *UDPTransport) recvLoop(conn *dnsConnection) {
+// recvLoop reads from conn until a read fails, then closes conn with that
+// error and returns. Each read that unpacks as a DNS message is handed to the
+// callback registered under its ID; messages that fail to unpack or have no
+// registered callback are dropped and the loop continues.
+func (t *UDPTransport) recvLoop(conn *Connection) {
 	for {
-		buffer := buf.NewSize(t.udpSize)
+		// The size is reloaded on every iteration, so a raised udpSize takes
+		// effect on the next read.
+		buffer := buf.NewSize(int(t.udpSize.Load()))
 		_, err := buffer.ReadOnceFrom(conn)
 		if err != nil {
 			buffer.Release()
-			conn.Close(err)
+			conn.CloseWithError(err)
 			return
 		}
+
 		var message mDNS.Msg
 		err = message.Unpack(buffer.Bytes())
 		buffer.Release()
 		if err != nil {
-			conn.Close(err)
-			return
+			// A malformed message is logged and skipped; the connection
+			// stays open.
+			t.Logger.Debug("discarded malformed UDP response: ", err)
+			continue
 		}
-		conn.access.RLock()
-		callback, loaded := conn.callbacks[message.Id]
-		conn.access.RUnlock()
+
+		t.callbackAccess.RLock()
+		callback, loaded := t.callbacks[message.Id]
+		t.callbackAccess.RUnlock()
+
 		if !loaded {
 			continue
 		}
+
+		// Deliver at most once: if done is already closed, this message is
+		// dropped.
 		callback.access.Lock()
 		select {
 		case <-callback.done:
 		default:
-			callback.message = &message
+			callback.response = &message
 			close(callback.done)
 		}
 		callback.access.Unlock()
 	}
 }
-
-type dnsConnection struct {
-	net.Conn
-	access    sync.RWMutex
-	done      chan struct{}
-	closeOnce sync.Once
-	err       error
-	queryId   uint16
-	callbacks map[uint16]*dnsCallback
-}
-
-func (c *dnsConnection) Close(err error) {
-	c.closeOnce.Do(func() {
-		c.err = err
-		close(c.done)
-	})
-	c.Conn.Close()
-}
-
-type dnsCallback struct {
-	access  sync.Mutex
-	message *mDNS.Msg
-	done    chan struct{}
-}

+ 3 - 0
experimental/libbox/dns.go

@@ -46,6 +46,9 @@ func (p *platformTransport) Close() error {
 	return nil
 }
 
+// Reset is a no-op for the platform transport; the body is intentionally
+// empty.
+func (p *platformTransport) Reset() {
+}
+
 func (p *platformTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
 	response := &ExchangeContext{
 		context: ctx,

+ 10 - 0
service/resolved/transport.go

@@ -110,6 +110,16 @@ func (t *Transport) Close() error {
 	return nil
 }
 
+// Reset calls Reset on every server transport of every entry in linkServers,
+// holding linkAccess for reading for the duration of the walk.
+func (t *Transport) Reset() {
+	t.linkAccess.RLock()
+	defer t.linkAccess.RUnlock()
+	for _, link := range t.linkServers {
+		for _, upstream := range link.Servers {
+			upstream.Reset()
+		}
+	}
+}
+
 func (t *Transport) updateTransports(link *TransportLink) error {
 	t.linkAccess.Lock()
 	defer t.linkAccess.Unlock()