瀏覽代碼

Add fallback local DNS server for iOS

世界 8 月之前
父節點
當前提交
9a5f69f435

+ 1 - 1
common/tls/ech_client.go

@@ -221,7 +221,7 @@ func fetchECHClientConfig(ctx context.Context) func(_ context.Context, serverNam
 			return nil, err
 		}
 		if response.Rcode != mDNS.RcodeSuccess {
-			return nil, dns.RCodeError(response.Rcode)
+			return nil, dns.RcodeError(response.Rcode)
 		}
 		for _, rr := range response.Answer {
 			switch resource := rr.(type) {

+ 74 - 4
dns/client.go

@@ -17,7 +17,7 @@ import (
 	"github.com/sagernet/sing/contrab/freelru"
 	"github.com/sagernet/sing/contrab/maphash"
 
-	"github.com/miekg/dns"
+	dns "github.com/miekg/dns"
 )
 
 var (
@@ -484,7 +484,7 @@ func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransp
 
 func MessageToAddresses(response *dns.Msg) ([]netip.Addr, error) {
 	if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
-		return nil, RCodeError(response.Rcode)
+		return nil, RcodeError(response.Rcode)
 	}
 	addresses := make([]netip.Addr, 0, len(response.Answer))
 	for _, rawAnswer := range response.Answer {
@@ -508,10 +508,10 @@ func wrapError(err error) error {
 	switch dnsErr := err.(type) {
 	case *net.DNSError:
 		if dnsErr.IsNotFound {
-			return RCodeNameError
+			return RcodeNameError
 		}
 	case *net.AddrError:
-		return RCodeNameError
+		return RcodeNameError
 	}
 	return err
 }
@@ -561,3 +561,73 @@ func FixedResponse(id uint16, question dns.Question, addresses []netip.Addr, tim
 	}
 	return &response
 }
+
+func FixedResponseCNAME(id uint16, question dns.Question, record string, timeToLive uint32) *dns.Msg {
+	response := dns.Msg{
+		MsgHdr: dns.MsgHdr{
+			Id:       id,
+			Rcode:    dns.RcodeSuccess,
+			Response: true,
+		},
+		Question: []dns.Question{question},
+		Answer: []dns.RR{
+			&dns.CNAME{
+				Hdr: dns.RR_Header{
+					Name:   question.Name,
+					Rrtype: dns.TypeCNAME,
+					Class:  dns.ClassINET,
+					Ttl:    timeToLive,
+				},
+				Target: record,
+			},
+		},
+	}
+	return &response
+}
+
+func FixedResponseTXT(id uint16, question dns.Question, records []string, timeToLive uint32) *dns.Msg {
+	response := dns.Msg{
+		MsgHdr: dns.MsgHdr{
+			Id:       id,
+			Rcode:    dns.RcodeSuccess,
+			Response: true,
+		},
+		Question: []dns.Question{question},
+		Answer: []dns.RR{
+			&dns.TXT{
+				Hdr: dns.RR_Header{
+					Name:   question.Name,
+					Rrtype: dns.TypeA,
+					Class:  dns.ClassINET,
+					Ttl:    timeToLive,
+				},
+				Txt: records,
+			},
+		},
+	}
+	return &response
+}
+
+func FixedResponseMX(id uint16, question dns.Question, records []*net.MX, timeToLive uint32) *dns.Msg {
+	response := dns.Msg{
+		MsgHdr: dns.MsgHdr{
+			Id:       id,
+			Rcode:    dns.RcodeSuccess,
+			Response: true,
+		},
+		Question: []dns.Question{question},
+	}
+	for _, record := range records {
+		response.Answer = append(response.Answer, &dns.MX{
+			Hdr: dns.RR_Header{
+				Name:   question.Name,
+				Rrtype: dns.TypeA,
+				Class:  dns.ClassINET,
+				Ttl:    timeToLive,
+			},
+			Preference: record.Pref,
+			Mx:         record.Host,
+		})
+	}
+	return &response
+}

+ 9 - 25
dns/rcode.go

@@ -1,33 +1,17 @@
 package dns
 
-import F "github.com/sagernet/sing/common/format"
+import (
+	mDNS "github.com/miekg/dns"
+)
 
 const (
-	RCodeSuccess        RCodeError = 0 // NoError
-	RCodeFormatError    RCodeError = 1 // FormErr
-	RCodeServerFailure  RCodeError = 2 // ServFail
-	RCodeNameError      RCodeError = 3 // NXDomain
-	RCodeNotImplemented RCodeError = 4 // NotImp
-	RCodeRefused        RCodeError = 5 // Refused
+	RcodeFormatError RcodeError = mDNS.RcodeFormatError
+	RcodeNameError   RcodeError = mDNS.RcodeNameError
+	RcodeRefused     RcodeError = mDNS.RcodeRefused
 )
 
-type RCodeError uint16
+type RcodeError int
 
-func (e RCodeError) Error() string {
-	switch e {
-	case RCodeSuccess:
-		return "success"
-	case RCodeFormatError:
-		return "format error"
-	case RCodeServerFailure:
-		return "server failure"
-	case RCodeNameError:
-		return "name error"
-	case RCodeNotImplemented:
-		return "not implemented"
-	case RCodeRefused:
-		return "refused"
-	default:
-		return F.ToString("unknown error: ", uint16(e))
-	}
+func (e RcodeError) Error() string {
+	return mDNS.RcodeToString[int(e)]
 }

+ 2 - 2
dns/router.go

@@ -336,13 +336,13 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ
 			}
 		} else if len(responseAddrs) == 0 {
 			r.logger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
-			err = RCodeNameError
+			err = RcodeNameError
 		}
 	}
 	responseAddrs, cached = r.client.LookupCache(domain, options.Strategy)
 	if cached {
 		if len(responseAddrs) == 0 {
-			return nil, RCodeNameError
+			return nil, RcodeNameError
 		}
 		return responseAddrs, nil
 	}

+ 0 - 4
dns/transport/local/local.go

@@ -19,10 +19,6 @@ import (
 	mDNS "github.com/miekg/dns"
 )
 
-func RegisterTransport(registry *dns.TransportRegistry) {
-	dns.RegisterTransport[option.LocalDNSServerOptions](registry, C.DNSTypeLocal, NewTransport)
-}
-
 var _ adapter.DNSTransport = (*Transport)(nil)
 
 type Transport struct {

+ 201 - 0
dns/transport/local/local_fallback.go

@@ -0,0 +1,201 @@
+package local
+
+import (
+	"context"
+	"errors"
+	"net"
+
+	"github.com/sagernet/sing-box/adapter"
+	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-box/dns"
+	"github.com/sagernet/sing-box/experimental/libbox/platform"
+	"github.com/sagernet/sing-box/log"
+	"github.com/sagernet/sing-box/option"
+	E "github.com/sagernet/sing/common/exceptions"
+	"github.com/sagernet/sing/service"
+
+	mDNS "github.com/miekg/dns"
+)
+
+func RegisterTransport(registry *dns.TransportRegistry) {
+	dns.RegisterTransport[option.LocalDNSServerOptions](registry, C.DNSTypeLocal, NewFallbackTransport)
+}
+
+type FallbackTransport struct {
+	adapter.DNSTransport
+	ctx      context.Context
+	fallback bool
+	resolver net.Resolver
+}
+
+func NewFallbackTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) {
+	transport, err := NewTransport(ctx, logger, tag, options)
+	if err != nil {
+		return nil, err
+	}
+	return &FallbackTransport{
+		DNSTransport: transport,
+		ctx:          ctx,
+	}, nil
+}
+
+func (f *FallbackTransport) Start(stage adapter.StartStage) error {
+	if stage != adapter.StartStateStart {
+		return nil
+	}
+	platformInterface := service.FromContext[platform.Interface](f.ctx)
+	if platformInterface == nil {
+		return nil
+	}
+	inboundManager := service.FromContext[adapter.InboundManager](f.ctx)
+	for _, inbound := range inboundManager.Inbounds() {
+		if inbound.Type() == C.TypeTun {
+			// platform tun hijacks DNS, so we can only use cgo resolver here
+			f.fallback = true
+			break
+		}
+	}
+	return nil
+}
+
+func (f *FallbackTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
+	if f.fallback {
+		return f.DNSTransport.Exchange(ctx, message)
+	}
+	question := message.Question[0]
+	domain := dns.FqdnToDomain(question.Name)
+	if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
+		var network string
+		if question.Qtype == mDNS.TypeA {
+			network = "ip4"
+		} else {
+			network = "ip6"
+		}
+		addresses, err := f.resolver.LookupNetIP(ctx, network, domain)
+		if err != nil {
+			var dnsError *net.DNSError
+			if errors.As(err, &dnsError) && dnsError.IsNotFound {
+				return nil, dns.RcodeRefused
+			}
+			return nil, err
+		}
+		return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil
+	} else if question.Qtype == mDNS.TypeNS {
+		records, err := f.resolver.LookupNS(ctx, domain)
+		if err != nil {
+			var dnsError *net.DNSError
+			if errors.As(err, &dnsError) && dnsError.IsNotFound {
+				return nil, dns.RcodeRefused
+			}
+			return nil, err
+		}
+		response := &mDNS.Msg{
+			MsgHdr: mDNS.MsgHdr{
+				Id:       message.Id,
+				Rcode:    mDNS.RcodeSuccess,
+				Response: true,
+			},
+			Question: []mDNS.Question{question},
+		}
+		for _, record := range records {
+			response.Answer = append(response.Answer, &mDNS.NS{
+				Hdr: mDNS.RR_Header{
+					Name:   question.Name,
+					Rrtype: mDNS.TypeNS,
+					Class:  mDNS.ClassINET,
+					Ttl:    C.DefaultDNSTTL,
+				},
+				Ns: record.Host,
+			})
+		}
+		return response, nil
+	} else if question.Qtype == mDNS.TypeCNAME {
+		cname, err := f.resolver.LookupCNAME(ctx, domain)
+		if err != nil {
+			var dnsError *net.DNSError
+			if errors.As(err, &dnsError) && dnsError.IsNotFound {
+				return nil, dns.RcodeRefused
+			}
+			return nil, err
+		}
+		return &mDNS.Msg{
+			MsgHdr: mDNS.MsgHdr{
+				Id:       message.Id,
+				Rcode:    mDNS.RcodeSuccess,
+				Response: true,
+			},
+			Question: []mDNS.Question{question},
+			Answer: []mDNS.RR{
+				&mDNS.CNAME{
+					Hdr: mDNS.RR_Header{
+						Name:   question.Name,
+						Rrtype: mDNS.TypeCNAME,
+						Class:  mDNS.ClassINET,
+						Ttl:    C.DefaultDNSTTL,
+					},
+					Target: cname,
+				},
+			},
+		}, nil
+	} else if question.Qtype == mDNS.TypeTXT {
+		records, err := f.resolver.LookupTXT(ctx, domain)
+		if err != nil {
+			var dnsError *net.DNSError
+			if errors.As(err, &dnsError) && dnsError.IsNotFound {
+				return nil, dns.RcodeRefused
+			}
+			return nil, err
+		}
+		return &mDNS.Msg{
+			MsgHdr: mDNS.MsgHdr{
+				Id:       message.Id,
+				Rcode:    mDNS.RcodeSuccess,
+				Response: true,
+			},
+			Question: []mDNS.Question{question},
+			Answer: []mDNS.RR{
+				&mDNS.TXT{
+					Hdr: mDNS.RR_Header{
+						Name:   question.Name,
+						Rrtype: mDNS.TypeCNAME,
+						Class:  mDNS.ClassINET,
+						Ttl:    C.DefaultDNSTTL,
+					},
+					Txt: records,
+				},
+			},
+		}, nil
+	} else if question.Qtype == mDNS.TypeMX {
+		records, err := f.resolver.LookupMX(ctx, domain)
+		if err != nil {
+			var dnsError *net.DNSError
+			if errors.As(err, &dnsError) && dnsError.IsNotFound {
+				return nil, dns.RcodeRefused
+			}
+			return nil, err
+		}
+		response := &mDNS.Msg{
+			MsgHdr: mDNS.MsgHdr{
+				Id:       message.Id,
+				Rcode:    mDNS.RcodeSuccess,
+				Response: true,
+			},
+			Question: []mDNS.Question{question},
+		}
+		for _, record := range records {
+			response.Answer = append(response.Answer, &mDNS.MX{
+				Hdr: mDNS.RR_Header{
+					Name:   question.Name,
+					Rrtype: mDNS.TypeA,
+					Class:  mDNS.ClassINET,
+					Ttl:    C.DefaultDNSTTL,
+				},
+				Preference: record.Pref,
+				Mx:         record.Host,
+			})
+		}
+		return response, nil
+	} else {
+		return nil, E.New("only A, AAAA, NS, CNAME, TXT, MX queries are supported on current platform when using TUN, please switch to a fixed DNS server.")
+	}
+}

+ 0 - 3
dns/transport/local/resolv_windows.go

@@ -69,9 +69,6 @@ func dnsReadConfig(_ string) *dnsConfig {
 	return conf
 }
 
-//go:linkname defaultNS net.defaultNS
-var defaultNS []string
-
 func adapterAddresses() ([]*windows.IpAdapterAddresses, error) {
 	var b []byte
 	l := uint32(15000) // recommended initial size

+ 1 - 1
dns/transport/predefined.go

@@ -79,5 +79,5 @@ func (t *PredefinedTransport) Exchange(ctx context.Context, message *mDNS.Msg) (
 			}
 		}
 	}
-	return nil, dns.RCodeNameError
+	return nil, dns.RcodeNameError
 }

+ 1 - 1
experimental/libbox/dns.go

@@ -134,7 +134,7 @@ func (c *ExchangeContext) RawSuccess(result []byte) {
 }
 
 func (c *ExchangeContext) ErrorCode(code int32) {
-	c.error = dns.RCodeError(code)
+	c.error = dns.RcodeError(code)
 }
 
 func (c *ExchangeContext) ErrnoCode(code int32) {

+ 1 - 1
protocol/dns/handle.go

@@ -26,7 +26,7 @@ func HandleStreamDNSRequest(ctx context.Context, router adapter.DNSRouter, conn
 		return err
 	}
 	if queryLength == 0 {
-		return dns.RCodeFormatError
+		return dns.RcodeFormatError
 	}
 	buffer := buf.NewSize(int(queryLength))
 	defer buffer.Release()

+ 1 - 1
protocol/tailscale/dns_transport.go

@@ -287,7 +287,7 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M
 			return nil, E.New("missing default resolvers")
 		}
 	}
-	return nil, dns.RCodeNameError
+	return nil, dns.RcodeNameError
 }
 
 type DNSDialer struct {