| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423 | package dnsimport (	"bytes"	"context"	"encoding/binary"	"net/url"	"sync"	"time"	"github.com/quic-go/quic-go"	"github.com/xtls/xray-core/common"	"github.com/xtls/xray-core/common/buf"	"github.com/xtls/xray-core/common/errors"	"github.com/xtls/xray-core/common/log"	"github.com/xtls/xray-core/common/net"	"github.com/xtls/xray-core/common/protocol/dns"	"github.com/xtls/xray-core/common/session"	"github.com/xtls/xray-core/common/signal/pubsub"	"github.com/xtls/xray-core/common/task"	dns_feature "github.com/xtls/xray-core/features/dns"	"github.com/xtls/xray-core/transport/internet/tls"	"golang.org/x/net/dns/dnsmessage"	"golang.org/x/net/http2")// NextProtoDQ - During connection establishment, DNS/QUIC support is indicated// by selecting the ALPN token "dq" in the crypto handshake.const NextProtoDQ = "doq"const handshakeTimeout = time.Second * 8// QUICNameServer implemented DNS over QUICtype QUICNameServer struct {	sync.RWMutex	ips           map[string]*record	pub           *pubsub.Service	cleanup       *task.Periodic	name          string	destination   *net.Destination	connection    quic.Connection	queryStrategy QueryStrategy}// NewQUICNameServer creates DNS-over-QUIC client object for local resolvingfunc NewQUICNameServer(url *url.URL, queryStrategy QueryStrategy) (*QUICNameServer, error) {	errors.LogInfo(context.Background(), "DNS: created Local DNS-over-QUIC client for ", url.String())	var err error	port := net.Port(853)	if url.Port() != "" {		port, err = net.PortFromString(url.Port())		if err != nil {			return nil, err		}	}	dest := net.UDPDestination(net.ParseAddress(url.Hostname()), port)	s := &QUICNameServer{		ips:           make(map[string]*record),		pub:           pubsub.NewService(),		name:          url.String(),		destination:   &dest,		queryStrategy: queryStrategy,	}	s.cleanup = &task.Periodic{		Interval: time.Minute,		Execute:  s.Cleanup,	}	return s, nil}// Name returns client namefunc (s *QUICNameServer) Name() string {	return s.name}// Cleanup clears expired items from cachefunc (s *QUICNameServer) Cleanup() error {	now := time.Now()	s.Lock()	defer s.Unlock()	if len(s.ips) == 0 {		return errors.New("nothing to do. stopping...")	}	for domain, record := range s.ips {		if record.A != nil && record.A.Expire.Before(now) {			record.A = nil		}		if record.AAAA != nil && record.AAAA.Expire.Before(now) {			record.AAAA = nil		}		if record.A == nil && record.AAAA == nil {			errors.LogDebug(context.Background(), s.name, " cleanup ", domain)			delete(s.ips, domain)		} else {			s.ips[domain] = record		}	}	if len(s.ips) == 0 {		s.ips = make(map[string]*record)	}	return nil}func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {	elapsed := time.Since(req.start)	s.Lock()	rec, found := s.ips[req.domain]	if !found {		rec = &record{}	}	updated := false	switch req.reqType {	case dnsmessage.TypeA:		if isNewer(rec.A, ipRec) {			rec.A = ipRec			updated = true		}	case dnsmessage.TypeAAAA:		addr := make([]net.Address, 0)		for _, ip := range ipRec.IP {			if len(ip.IP()) == net.IPv6len {				addr = append(addr, ip)			}		}		ipRec.IP = addr		if isNewer(rec.AAAA, ipRec) {			rec.AAAA = ipRec			updated = true		}	}	errors.LogInfo(context.Background(), s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed)	if updated {		s.ips[req.domain] = rec	}	switch req.reqType {	case dnsmessage.TypeA:		s.pub.Publish(req.domain+"4", nil)	case dnsmessage.TypeAAAA:		s.pub.Publish(req.domain+"6", nil)	}	s.Unlock()	common.Must(s.cleanup.Start())}func (s *QUICNameServer) newReqID() uint16 {	return 0}func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {	errors.LogInfo(ctx, s.name, " querying: ", domain)	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))	var deadline time.Time	if d, ok := ctx.Deadline(); ok {		deadline = d	} else {		deadline = time.Now().Add(time.Second * 5)	}	for _, req := range reqs {		go func(r *dnsRequest) {			// generate new context for each req, using same context			// may cause reqs all aborted if any one encounter an error			dnsCtx := ctx			// reserve internal dns server requested Inbound			if inbound := session.InboundFromContext(ctx); inbound != nil {				dnsCtx = session.ContextWithInbound(dnsCtx, inbound)			}			dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{				Protocol:       "quic",				SkipDNSResolve: true,			})			var cancel context.CancelFunc			dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline)			defer cancel()			b, err := dns.PackMessage(r.msg)			if err != nil {				errors.LogErrorInner(ctx, err, "failed to pack dns query")				return			}			dnsReqBuf := buf.New()			binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))			dnsReqBuf.Write(b.Bytes())			b.Release()			conn, err := s.openStream(dnsCtx)			if err != nil {				errors.LogErrorInner(ctx, err, "failed to open quic connection")				return			}			_, err = conn.Write(dnsReqBuf.Bytes())			if err != nil {				errors.LogErrorInner(ctx, err, "failed to send query")				return			}			_ = conn.Close()			respBuf := buf.New()			defer respBuf.Release()			n, err := respBuf.ReadFullFrom(conn, 2)			if err != nil && n == 0 {				errors.LogErrorInner(ctx, err, "failed to read response length")				return			}			var length int16			err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length)			if err != nil {				errors.LogErrorInner(ctx, err, "failed to parse response length")				return			}			respBuf.Clear()			n, err = respBuf.ReadFullFrom(conn, int32(length))			if err != nil && n == 0 {				errors.LogErrorInner(ctx, err, "failed to read response length")				return			}			rec, err := parseResponse(respBuf.Bytes())			if err != nil {				errors.LogErrorInner(ctx, err, "failed to handle response")				return			}			s.updateIP(r, rec)		}(req)	}}func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) {	s.RLock()	record, found := s.ips[domain]	s.RUnlock()	if !found {		return nil, errRecordNotFound	}	var err4 error	var err6 error	var ips []net.Address	var ip6 []net.Address	if option.IPv4Enable {		ips, err4 = record.A.getIPs()	}	if option.IPv6Enable {		ip6, err6 = record.AAAA.getIPs()		ips = append(ips, ip6...)	}	if len(ips) > 0 {		return toNetIP(ips)	}	if err4 != nil {		return nil, err4	}	if err6 != nil {		return nil, err6	}	if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {		return nil, dns_feature.ErrEmptyResponse	}	return nil, errRecordNotFound}// QueryIP is called from dns.Server->queryIPTimeoutfunc (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) {	fqdn := Fqdn(domain)	option = ResolveIpOptionOverride(s.queryStrategy, option)	if !option.IPv4Enable && !option.IPv6Enable {		return nil, dns_feature.ErrEmptyResponse	}	if disableCache {		errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.name)	} else {		ips, err := s.findIPsForDomain(fqdn, option)		if err == nil || err == dns_feature.ErrEmptyResponse {			errors.LogDebugInner(ctx, err, s.name, " cache HIT ", domain, " -> ", ips)			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err})			return ips, err		}	}	// ipv4 and ipv6 belong to different subscription groups	var sub4, sub6 *pubsub.Subscriber	if option.IPv4Enable {		sub4 = s.pub.Subscribe(fqdn + "4")		defer sub4.Close()	}	if option.IPv6Enable {		sub6 = s.pub.Subscribe(fqdn + "6")		defer sub6.Close()	}	done := make(chan interface{})	go func() {		if sub4 != nil {			select {			case <-sub4.Wait():			case <-ctx.Done():			}		}		if sub6 != nil {			select {			case <-sub6.Wait():			case <-ctx.Done():			}		}		close(done)	}()	s.sendQuery(ctx, fqdn, clientIP, option)	start := time.Now()	for {		ips, err := s.findIPsForDomain(fqdn, option)		if err != errRecordNotFound {			log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})			return ips, err		}		select {		case <-ctx.Done():			return nil, ctx.Err()		case <-done:		}	}}func isActive(s quic.Connection) bool {	select {	case <-s.Context().Done():		return false	default:		return true	}}func (s *QUICNameServer) getConnection() (quic.Connection, error) {	var conn quic.Connection	s.RLock()	conn = s.connection	if conn != nil && isActive(conn) {		s.RUnlock()		return conn, nil	}	if conn != nil {		// we're recreating the connection, let's create a new one		_ = conn.CloseWithError(0, "")	}	s.RUnlock()	s.Lock()	defer s.Unlock()	var err error	conn, err = s.openConnection()	if err != nil {		// This does not look too nice, but QUIC (or maybe quic-go)		// doesn't seem stable enough.		// Maybe retransmissions aren't fully implemented in quic-go?		// Anyways, the simple solution is to make a second try when		// it fails to open the QUIC connection.		conn, err = s.openConnection()		if err != nil {			return nil, err		}	}	s.connection = conn	return conn, nil}func (s *QUICNameServer) openConnection() (quic.Connection, error) {	tlsConfig := tls.Config{}	quicConfig := &quic.Config{		HandshakeIdleTimeout: handshakeTimeout,	}	tlsConfig.ServerName = s.destination.Address.String()	conn, err := quic.DialAddr(context.Background(), s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto("http/1.1", http2.NextProtoTLS, NextProtoDQ)), quicConfig)	log.Record(&log.AccessMessage{		From:   "DNS",		To:     s.destination,		Status: log.AccessAccepted,		Detour: "local",	})	if err != nil {		return nil, err	}	return conn, nil}func (s *QUICNameServer) openStream(ctx context.Context) (quic.Stream, error) {	conn, err := s.getConnection()	if err != nil {		return nil, err	}	// open a new stream	return conn.OpenStreamSync(ctx)}
 |