Explorar o código

Fix `local` DNS server on darwin

We mistakenly believed that `libresolv`'s `search` function worked correctly in NetworkExtension, but it seems only `getaddrinfo` does.

This commit changes the behavior of the `local` DNS server in NetworkExtension to prefer DHCP, falling back to `getaddrinfo` if DHCP servers are unavailable.

It's worth noting that `prefer_go` does not disable DHCP since it respects Dial Fields, but `getaddrinfo` does the opposite. The new behavior only applies to NetworkExtension, not to all scenarios (primarily command-line binaries) as it did previously.

In addition, this commit also improves the DHCP DNS server to use the same robust query logic as `local`.
世界 hai 2 meses
pai
achega
918d4ff3ca

+ 3 - 180
dns/transport/local/local.go

@@ -1,23 +1,18 @@
+//go:build !darwin
+
 package local
 
 import (
 	"context"
-	"errors"
-	"math/rand"
-	"syscall"
-	"time"
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
-	"github.com/sagernet/sing-box/dns/transport"
 	"github.com/sagernet/sing-box/dns/transport/hosts"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
-	"github.com/sagernet/sing/common/buf"
 	E "github.com/sagernet/sing/common/exceptions"
 	"github.com/sagernet/sing/common/logger"
-	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 
 	mDNS "github.com/miekg/dns"
@@ -40,9 +35,6 @@ type Transport struct {
 }
 
 func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) {
-	if C.IsDarwin && !options.PreferGo {
-		return NewResolvTransport(ctx, logger, tag)
-	}
 	transportDialer, err := dns.NewLocalDialer(ctx, options)
 	if err != nil {
 		return nil, err
@@ -97,174 +89,5 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg,
 			return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
 		}
 	}
-	systemConfig := getSystemDNSConfig(t.ctx)
-	if systemConfig.singleRequest || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) {
-		return t.exchangeSingleRequest(ctx, systemConfig, message, domain)
-	} else {
-		return t.exchangeParallel(ctx, systemConfig, message, domain)
-	}
-}
-
-func (t *Transport) exchangeSingleRequest(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
-	var lastErr error
-	for _, fqdn := range systemConfig.nameList(domain) {
-		response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
-		if err != nil {
-			lastErr = err
-			continue
-		}
-		return response, nil
-	}
-	return nil, lastErr
-}
-
-func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
-	returned := make(chan struct{})
-	defer close(returned)
-	type queryResult struct {
-		response *mDNS.Msg
-		err      error
-	}
-	results := make(chan queryResult)
-	startRacer := func(ctx context.Context, fqdn string) {
-		response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
-		if err == nil {
-			if response.Rcode != mDNS.RcodeSuccess {
-				err = dns.RcodeError(response.Rcode)
-			} else if len(dns.MessageToAddresses(response)) == 0 {
-				err = dns.RcodeSuccess
-			}
-		}
-		select {
-		case results <- queryResult{response, err}:
-		case <-returned:
-		}
-	}
-	queryCtx, queryCancel := context.WithCancel(ctx)
-	defer queryCancel()
-	var nameCount int
-	for _, fqdn := range systemConfig.nameList(domain) {
-		nameCount++
-		go startRacer(queryCtx, fqdn)
-	}
-	var errors []error
-	for {
-		select {
-		case <-ctx.Done():
-			return nil, ctx.Err()
-		case result := <-results:
-			if result.err == nil {
-				return result.response, nil
-			}
-			errors = append(errors, result.err)
-			if len(errors) == nameCount {
-				return nil, E.Errors(errors...)
-			}
-		}
-	}
-}
-
-func (t *Transport) tryOneName(ctx context.Context, config *dnsConfig, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) {
-	serverOffset := config.serverOffset()
-	sLen := uint32(len(config.servers))
-	var lastErr error
-	for i := 0; i < config.attempts; i++ {
-		for j := uint32(0); j < sLen; j++ {
-			server := config.servers[(serverOffset+j)%sLen]
-			question := message.Question[0]
-			question.Name = fqdn
-			response, err := t.exchangeOne(ctx, M.ParseSocksaddr(server), question, config.timeout, config.useTCP, config.trustAD)
-			if err != nil {
-				lastErr = err
-				continue
-			}
-			return response, nil
-		}
-	}
-	return nil, E.Cause(lastErr, fqdn)
-}
-
-func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
-	if server.Port == 0 {
-		server.Port = 53
-	}
-	request := &mDNS.Msg{
-		MsgHdr: mDNS.MsgHdr{
-			Id:                uint16(rand.Uint32()),
-			RecursionDesired:  true,
-			AuthenticatedData: ad,
-		},
-		Question: []mDNS.Question{question},
-		Compress: true,
-	}
-	request.SetEdns0(buf.UDPBufferSize, false)
-	if !useTCP {
-		return t.exchangeUDP(ctx, server, request, timeout)
-	} else {
-		return t.exchangeTCP(ctx, server, request, timeout)
-	}
-}
-
-func (t *Transport) exchangeUDP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) {
-	conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, server)
-	if err != nil {
-		return nil, err
-	}
-	defer conn.Close()
-	if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
-		newDeadline := time.Now().Add(timeout)
-		if deadline.After(newDeadline) {
-			deadline = newDeadline
-		}
-		conn.SetDeadline(deadline)
-	}
-	buffer := buf.Get(buf.UDPBufferSize)
-	defer buf.Put(buffer)
-	rawMessage, err := request.PackBuffer(buffer)
-	if err != nil {
-		return nil, E.Cause(err, "pack request")
-	}
-	_, err = conn.Write(rawMessage)
-	if err != nil {
-		if errors.Is(err, syscall.EMSGSIZE) {
-			return t.exchangeTCP(ctx, server, request, timeout)
-		}
-		return nil, E.Cause(err, "write request")
-	}
-	n, err := conn.Read(buffer)
-	if err != nil {
-		if errors.Is(err, syscall.EMSGSIZE) {
-			return t.exchangeTCP(ctx, server, request, timeout)
-		}
-		return nil, E.Cause(err, "read response")
-	}
-	var response mDNS.Msg
-	err = response.Unpack(buffer[:n])
-	if err != nil {
-		return nil, E.Cause(err, "unpack response")
-	}
-	if response.Truncated {
-		return t.exchangeTCP(ctx, server, request, timeout)
-	}
-	return &response, nil
-}
-
-func (t *Transport) exchangeTCP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) {
-	conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, server)
-	if err != nil {
-		return nil, err
-	}
-	defer conn.Close()
-	if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
-		newDeadline := time.Now().Add(timeout)
-		if deadline.After(newDeadline) {
-			deadline = newDeadline
-		}
-		conn.SetDeadline(deadline)
-	}
-	err = transport.WriteMessage(conn, 0, request)
-	if err != nil {
-		return nil, err
-	}
-	return transport.ReadMessage(conn)
+	return t.exchange(ctx, message, domain)
 }

+ 135 - 0
dns/transport/local/local_darwin.go

@@ -0,0 +1,135 @@
+//go:build darwin
+
+package local
+
+import (
+	"context"
+	"errors"
+	"net"
+
+	"github.com/sagernet/sing-box/adapter"
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/dns"
+	"github.com/sagernet/sing-box/dns/transport/hosts"
+	"github.com/sagernet/sing-box/log"
+	"github.com/sagernet/sing-box/option"
+	"github.com/sagernet/sing/common"
+	E "github.com/sagernet/sing/common/exceptions"
+	"github.com/sagernet/sing/common/logger"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/service"
+
+	mDNS "github.com/miekg/dns"
+)
+
+func RegisterTransport(registry *dns.TransportRegistry) {
+	dns.RegisterTransport[option.LocalDNSServerOptions](registry, C.DNSTypeLocal, NewTransport)
+}
+
+var _ adapter.DNSTransport = (*Transport)(nil)
+
+type Transport struct {
+	dns.TransportAdapter
+	ctx           context.Context
+	logger        logger.ContextLogger
+	hosts         *hosts.File
+	dialer        N.Dialer
+	preferGo      bool
+	fallback      bool
+	dhcpTransport dhcpTransport
+	resolver      net.Resolver
+}
+
+type dhcpTransport interface {
+	adapter.DNSTransport
+	Fetch() ([]M.Socksaddr, error)
+	Exchange0(ctx context.Context, message *mDNS.Msg, servers []M.Socksaddr) (*mDNS.Msg, error)
+}
+
+func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) {
+	transportDialer, err := dns.NewLocalDialer(ctx, options)
+	if err != nil {
+		return nil, err
+	}
+	transportAdapter := dns.NewTransportAdapterWithLocalOptions(C.DNSTypeLocal, tag, options)
+	return &Transport{
+		TransportAdapter: transportAdapter,
+		ctx:              ctx,
+		logger:           logger,
+		hosts:            hosts.NewFile(hosts.DefaultPath),
+		dialer:           transportDialer,
+		preferGo:         options.PreferGo,
+	}, nil
+}
+
+func (t *Transport) Start(stage adapter.StartStage) error {
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	inboundManager := service.FromContext[adapter.InboundManager](t.ctx)
+	for _, inbound := range inboundManager.Inbounds() {
+		if inbound.Type() == C.TypeTun {
+			t.fallback = true
+			break
+		}
+	}
+	if t.fallback {
+		t.dhcpTransport = newDHCPTransport(t.TransportAdapter, log.ContextWithOverrideLevel(t.ctx, log.LevelDebug), t.dialer, t.logger)
+		if t.dhcpTransport != nil {
+			err := t.dhcpTransport.Start(stage)
+			if err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func (t *Transport) Close() error {
+	return common.Close(
+		t.dhcpTransport,
+	)
+}
+
+func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
+	question := message.Question[0]
+	domain := dns.FqdnToDomain(question.Name)
+	if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
+		addresses := t.hosts.Lookup(domain)
+		if len(addresses) > 0 {
+			return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
+		}
+	}
+	if !t.fallback {
+		return t.exchange(ctx, message, domain)
+	}
+	if t.dhcpTransport != nil {
+		dhcpTransports, _ := t.dhcpTransport.Fetch()
+		if len(dhcpTransports) > 0 {
+			return t.dhcpTransport.Exchange0(ctx, message, dhcpTransports)
+		}
+	}
+	if t.preferGo {
+		// Assuming the user knows what they are doing, we still execute the query which will fail.
+		return t.exchange(ctx, message, domain)
+	}
+	if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
+		var network string
+		if question.Qtype == mDNS.TypeA {
+			network = "ip4"
+		} else {
+			network = "ip6"
+		}
+		addresses, err := t.resolver.LookupNetIP(ctx, network, domain)
+		if err != nil {
+			var dnsError *net.DNSError
+			if errors.As(err, &dnsError) && dnsError.IsNotFound {
+				return nil, dns.RcodeRefused
+			}
+			return nil, err
+		}
+		return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
+	}
+	return nil, E.New("only A and AAAA queries are supported on Apple platforms when using TUN and DHCP unavailable.")
+}

+ 16 - 0
dns/transport/local/local_darwin_dhcp.go

@@ -0,0 +1,16 @@
+//go:build darwin && with_dhcp
+
+package local
+
+import (
+	"context"
+
+	"github.com/sagernet/sing-box/dns"
+	"github.com/sagernet/sing-box/dns/transport/dhcp"
+	"github.com/sagernet/sing-box/log"
+	N "github.com/sagernet/sing/common/network"
+)
+
+func newDHCPTransport(transportAdapter dns.TransportAdapter, ctx context.Context, dialer N.Dialer, logger log.ContextLogger) dhcpTransport {
+	return dhcp.NewRawTransport(transportAdapter, ctx, dialer, logger)
+}

+ 15 - 0
dns/transport/local/local_darwin_nodhcp.go

@@ -0,0 +1,15 @@
+//go:build darwin && !with_dhcp
+
+package local
+
+import (
+	"context"
+
+	"github.com/sagernet/sing-box/dns"
+	"github.com/sagernet/sing-box/log"
+	N "github.com/sagernet/sing/common/network"
+)
+
+func newDHCPTransport(transportAdapter dns.TransportAdapter, ctx context.Context, dialer N.Dialer, logger log.ContextLogger) dhcpTransport {
+	return nil
+}

+ 0 - 46
dns/transport/local/local_resolv.go

@@ -1,46 +0,0 @@
-//go:build darwin
-
-package local
-
-import (
-	"context"
-
-	"github.com/sagernet/sing-box/adapter"
-	C "github.com/sagernet/sing-box/constant"
-	"github.com/sagernet/sing-box/dns"
-	"github.com/sagernet/sing-box/log"
-	"github.com/sagernet/sing/common/logger"
-
-	mDNS "github.com/miekg/dns"
-)
-
-var _ adapter.DNSTransport = (*ResolvTransport)(nil)
-
-type ResolvTransport struct {
-	dns.TransportAdapter
-	ctx    context.Context
-	logger logger.ContextLogger
-}
-
-func NewResolvTransport(ctx context.Context, logger log.ContextLogger, tag string) (adapter.DNSTransport, error) {
-	return &ResolvTransport{
-		TransportAdapter: dns.NewTransportAdapter(C.DNSTypeLocal, tag, nil),
-		ctx:              ctx,
-		logger:           logger,
-	}, nil
-}
-
-func (t *ResolvTransport) Start(stage adapter.StartStage) error {
-	return nil
-}
-
-func (t *ResolvTransport) Close() error {
-	return nil
-}
-
-func (t *ResolvTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
-	question := message.Question[0]
-	return doBlockingWithCtx(ctx, func() (*mDNS.Msg, error) {
-		return cgoResSearch(question.Name, int(question.Qtype), int(question.Qclass))
-	})
-}

+ 0 - 170
dns/transport/local/local_resolv_linkname.go

@@ -1,170 +0,0 @@
-// Copyright 2022 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build darwin
-
-package local
-
-import (
-	"context"
-	"errors"
-	"runtime"
-	"syscall"
-	"unsafe"
-	_ "unsafe"
-
-	E "github.com/sagernet/sing/common/exceptions"
-
-	mDNS "github.com/miekg/dns"
-)
-
-type (
-	_C_char               = byte
-	_C_int                = int32
-	_C_uchar              = byte
-	_C_ushort             = uint16
-	_C_uint               = uint32
-	_C_ulong              = uint64
-	_C_struct___res_state = ResState
-	_C_struct_sockaddr    = syscall.RawSockaddr
-)
-
-func _C_free(p unsafe.Pointer) { runtime.KeepAlive(p) }
-
-func _C_malloc(n uintptr) unsafe.Pointer {
-	if n <= 0 {
-		n = 1
-	}
-	return unsafe.Pointer(&make([]byte, n)[0])
-}
-
-const (
-	MAXNS     = 3
-	MAXDNSRCH = 6
-)
-
-type ResState struct {
-	Retrans    _C_int
-	Retry      _C_int
-	Options    _C_ulong
-	Nscount    _C_int
-	Nsaddrlist [MAXNS]_C_struct_sockaddr
-	Id         _C_ushort
-	Dnsrch     [MAXDNSRCH + 1]*_C_char
-	Defname    [256]_C_char
-	Pfcode     _C_ulong
-	Ndots      _C_uint
-	Nsort      _C_uint
-	stub       [128]byte
-}
-
-//go:linkname ResNinit internal/syscall/unix.ResNinit
-func ResNinit(state *_C_struct___res_state) error
-
-//go:linkname ResNsearch internal/syscall/unix.ResNsearch
-func ResNsearch(state *_C_struct___res_state, dname *byte, class, typ int, ans *byte, anslen int) (int, error)
-
-//go:linkname ResNclose internal/syscall/unix.ResNclose
-func ResNclose(state *_C_struct___res_state)
-
-//go:linkname GoString internal/syscall/unix.GoString
-func GoString(p *byte) string
-
-// doBlockingWithCtx executes a blocking function in a separate goroutine when the provided
-// context is cancellable. It is intended for use with calls that don't support context
-// cancellation (cgo, syscalls). blocking func may still be running after this function finishes.
-// For the duration of the execution of the blocking function, the thread is 'acquired' using [acquireThread],
-// blocking might not be executed when the context gets canceled early.
-func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) (T, error) {
-	if err := acquireThread(ctx); err != nil {
-		var zero T
-		return zero, err
-	}
-
-	if ctx.Done() == nil {
-		defer releaseThread()
-		return blocking()
-	}
-
-	type result struct {
-		res T
-		err error
-	}
-
-	res := make(chan result, 1)
-	go func() {
-		defer releaseThread()
-		var r result
-		r.res, r.err = blocking()
-		res <- r
-	}()
-
-	select {
-	case r := <-res:
-		return r.res, r.err
-	case <-ctx.Done():
-		var zero T
-		return zero, ctx.Err()
-	}
-}
-
-//go:linkname acquireThread net.acquireThread
-func acquireThread(ctx context.Context) error
-
-//go:linkname releaseThread net.releaseThread
-func releaseThread()
-
-func cgoResSearch(hostname string, rtype, class int) (*mDNS.Msg, error) {
-	resStateSize := unsafe.Sizeof(_C_struct___res_state{})
-	var state *_C_struct___res_state
-	if resStateSize > 0 {
-		mem := _C_malloc(resStateSize)
-		defer _C_free(mem)
-		memSlice := unsafe.Slice((*byte)(mem), resStateSize)
-		clear(memSlice)
-		state = (*_C_struct___res_state)(unsafe.Pointer(&memSlice[0]))
-	}
-	if err := ResNinit(state); err != nil {
-		return nil, errors.New("res_ninit failure: " + err.Error())
-	}
-	defer ResNclose(state)
-
-	bufSize := maxDNSPacketSize
-	buf := (*_C_uchar)(_C_malloc(uintptr(bufSize)))
-	defer _C_free(unsafe.Pointer(buf))
-
-	s, err := syscall.BytePtrFromString(hostname)
-	if err != nil {
-		return nil, err
-	}
-
-	var size int
-	for {
-		size, _ = ResNsearch(state, s, class, rtype, buf, bufSize)
-		if size <= bufSize || size > 0xffff {
-			break
-		}
-
-		// Allocate a bigger buffer to fit the entire msg.
-		_C_free(unsafe.Pointer(buf))
-		bufSize = size
-		buf = (*_C_uchar)(_C_malloc(uintptr(bufSize)))
-	}
-
-	var msg mDNS.Msg
-	if size == -1 {
-		// macOS's libresolv seems to directly return -1 for responses that are not success responses but are exchanged.
-		// However, we still need the response, so we fall back to parsing the entire buffer.
-		err = msg.Unpack(unsafe.Slice(buf, bufSize))
-		if err != nil {
-			return nil, E.New("res_nsearch failure")
-		}
-	} else {
-		err = msg.Unpack(unsafe.Slice(buf, size))
-		if err != nil {
-			return nil, err
-		}
-	}
-	return &msg, nil
-}

+ 0 - 15
dns/transport/local/local_resolv_stub.go

@@ -1,15 +0,0 @@
-//go:build !darwin
-
-package local
-
-import (
-	"context"
-	"os"
-
-	"github.com/sagernet/sing-box/adapter"
-	"github.com/sagernet/sing-box/log"
-)
-
-func NewResolvTransport(ctx context.Context, logger log.ContextLogger, tag string) (adapter.DNSTransport, error) {
-	return nil, os.ErrInvalid
-}

+ 191 - 0
dns/transport/local/local_shared.go

@@ -0,0 +1,191 @@
+package local
+
+import (
+	"context"
+	"errors"
+	"math/rand"
+	"syscall"
+	"time"
+
+	"github.com/sagernet/sing-box/dns"
+	"github.com/sagernet/sing-box/dns/transport"
+	"github.com/sagernet/sing/common/buf"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+
+	mDNS "github.com/miekg/dns"
+)
+
+func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
+	systemConfig := getSystemDNSConfig(t.ctx)
+	if systemConfig.singleRequest || !(message.Question[0].Qtype == mDNS.TypeA || message.Question[0].Qtype == mDNS.TypeAAAA) {
+		return t.exchangeSingleRequest(ctx, systemConfig, message, domain)
+	} else {
+		return t.exchangeParallel(ctx, systemConfig, message, domain)
+	}
+}
+
+func (t *Transport) exchangeSingleRequest(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
+	var lastErr error
+	for _, fqdn := range systemConfig.nameList(domain) {
+		response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
+		if err != nil {
+			lastErr = err
+			continue
+		}
+		return response, nil
+	}
+	return nil, lastErr
+}
+
+func (t *Transport) exchangeParallel(ctx context.Context, systemConfig *dnsConfig, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
+	returned := make(chan struct{})
+	defer close(returned)
+	type queryResult struct {
+		response *mDNS.Msg
+		err      error
+	}
+	results := make(chan queryResult)
+	startRacer := func(ctx context.Context, fqdn string) {
+		response, err := t.tryOneName(ctx, systemConfig, fqdn, message)
+		if err == nil {
+			if response.Rcode != mDNS.RcodeSuccess {
+				err = dns.RcodeError(response.Rcode)
+			} else if len(dns.MessageToAddresses(response)) == 0 {
+				err = E.New(fqdn, ": empty result")
+			}
+		}
+		select {
+		case results <- queryResult{response, err}:
+		case <-returned:
+		}
+	}
+	queryCtx, queryCancel := context.WithCancel(ctx)
+	defer queryCancel()
+	var nameCount int
+	for _, fqdn := range systemConfig.nameList(domain) {
+		nameCount++
+		go startRacer(queryCtx, fqdn)
+	}
+	var errors []error
+	for {
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case result := <-results:
+			if result.err == nil {
+				return result.response, nil
+			}
+			errors = append(errors, result.err)
+			if len(errors) == nameCount {
+				return nil, E.Errors(errors...)
+			}
+		}
+	}
+}
+
+func (t *Transport) tryOneName(ctx context.Context, config *dnsConfig, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) {
+	serverOffset := config.serverOffset()
+	sLen := uint32(len(config.servers))
+	var lastErr error
+	for i := 0; i < config.attempts; i++ {
+		for j := uint32(0); j < sLen; j++ {
+			server := config.servers[(serverOffset+j)%sLen]
+			question := message.Question[0]
+			question.Name = fqdn
+			response, err := t.exchangeOne(ctx, M.ParseSocksaddr(server), question, config.timeout, config.useTCP, config.trustAD)
+			if err != nil {
+				lastErr = err
+				continue
+			}
+			return response, nil
+		}
+	}
+	return nil, E.Cause(lastErr, fqdn)
+}
+
+func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
+	if server.Port == 0 {
+		server.Port = 53
+	}
+	request := &mDNS.Msg{
+		MsgHdr: mDNS.MsgHdr{
+			Id:                uint16(rand.Uint32()),
+			RecursionDesired:  true,
+			AuthenticatedData: ad,
+		},
+		Question: []mDNS.Question{question},
+		Compress: true,
+	}
+	request.SetEdns0(buf.UDPBufferSize, false)
+	if !useTCP {
+		return t.exchangeUDP(ctx, server, request, timeout)
+	} else {
+		return t.exchangeTCP(ctx, server, request, timeout)
+	}
+}
+
+func (t *Transport) exchangeUDP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) {
+	conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, server)
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
+		newDeadline := time.Now().Add(timeout)
+		if deadline.After(newDeadline) {
+			deadline = newDeadline
+		}
+		conn.SetDeadline(deadline)
+	}
+	buffer := buf.Get(buf.UDPBufferSize)
+	defer buf.Put(buffer)
+	rawMessage, err := request.PackBuffer(buffer)
+	if err != nil {
+		return nil, E.Cause(err, "pack request")
+	}
+	_, err = conn.Write(rawMessage)
+	if err != nil {
+		if errors.Is(err, syscall.EMSGSIZE) {
+			return t.exchangeTCP(ctx, server, request, timeout)
+		}
+		return nil, E.Cause(err, "write request")
+	}
+	n, err := conn.Read(buffer)
+	if err != nil {
+		if errors.Is(err, syscall.EMSGSIZE) {
+			return t.exchangeTCP(ctx, server, request, timeout)
+		}
+		return nil, E.Cause(err, "read response")
+	}
+	var response mDNS.Msg
+	err = response.Unpack(buffer[:n])
+	if err != nil {
+		return nil, E.Cause(err, "unpack response")
+	}
+	if response.Truncated {
+		return t.exchangeTCP(ctx, server, request, timeout)
+	}
+	return &response, nil
+}
+
+func (t *Transport) exchangeTCP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) {
+	conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, server)
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
+		newDeadline := time.Now().Add(timeout)
+		if deadline.After(newDeadline) {
+			deadline = newDeadline
+		}
+		conn.SetDeadline(deadline)
+	}
+	err = transport.WriteMessage(conn, 0, request)
+	if err != nil {
+		return nil, err
+	}
+	return transport.ReadMessage(conn)
+}

+ 0 - 72
dns/transport/local/resolv_darwin.go

@@ -1,72 +0,0 @@
-package local
-
-import (
-	"context"
-	"net/netip"
-	"syscall"
-	"time"
-	"unsafe"
-
-	E "github.com/sagernet/sing/common/exceptions"
-
-	"github.com/miekg/dns"
-)
-
-func dnsReadConfig(_ context.Context, _ string) *dnsConfig {
-	resStateSize := unsafe.Sizeof(_C_struct___res_state{})
-	var state *_C_struct___res_state
-	if resStateSize > 0 {
-		mem := _C_malloc(resStateSize)
-		defer _C_free(mem)
-		memSlice := unsafe.Slice((*byte)(mem), resStateSize)
-		clear(memSlice)
-		state = (*_C_struct___res_state)(unsafe.Pointer(&memSlice[0]))
-	}
-	if err := ResNinit(state); err != nil {
-		return &dnsConfig{
-			servers:  defaultNS,
-			search:   dnsDefaultSearch(),
-			ndots:    1,
-			timeout:  5 * time.Second,
-			attempts: 2,
-			err:      E.Cause(err, "libresolv initialization failed"),
-		}
-	}
-	defer ResNclose(state)
-	conf := &dnsConfig{
-		ndots:    1,
-		timeout:  5 * time.Second,
-		attempts: int(state.Retry),
-	}
-	for i := 0; i < int(state.Nscount); i++ {
-		addr := parseRawSockaddr(&state.Nsaddrlist[i])
-		if addr.IsValid() {
-			conf.servers = append(conf.servers, addr.String())
-		}
-	}
-	for i := 0; ; i++ {
-		search := state.Dnsrch[i]
-		if search == nil {
-			break
-		}
-		name := dns.Fqdn(GoString(search))
-		if name == "" {
-			continue
-		}
-		conf.search = append(conf.search, name)
-	}
-	return conf
-}
-
-func parseRawSockaddr(rawSockaddr *syscall.RawSockaddr) netip.Addr {
-	switch rawSockaddr.Family {
-	case syscall.AF_INET:
-		sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(rawSockaddr))
-		return netip.AddrFrom4(sa.Addr)
-	case syscall.AF_INET6:
-		sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(rawSockaddr))
-		return netip.AddrFrom16(sa.Addr)
-	default:
-		return netip.Addr{}
-	}
-}

+ 1 - 1
dns/transport/local/resolv_unix.go

@@ -1,4 +1,4 @@
-//go:build !windows && !darwin
+//go:build !windows
 
 package local