| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380 | package dnsimport (	"context"	"io"	"sync"	"time"	"github.com/xtls/xray-core/common"	"github.com/xtls/xray-core/common/buf"	"github.com/xtls/xray-core/common/errors"	"github.com/xtls/xray-core/common/net"	dns_proto "github.com/xtls/xray-core/common/protocol/dns"	"github.com/xtls/xray-core/common/session"	"github.com/xtls/xray-core/common/signal"	"github.com/xtls/xray-core/common/task"	"github.com/xtls/xray-core/core"	"github.com/xtls/xray-core/features/dns"	"github.com/xtls/xray-core/features/policy"	"github.com/xtls/xray-core/transport"	"github.com/xtls/xray-core/transport/internet"	"github.com/xtls/xray-core/transport/internet/stat"	"golang.org/x/net/dns/dnsmessage")func init() {	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {		h := new(Handler)		if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error {			core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) {				h.fdns = fdns			})			return h.Init(config.(*Config), dnsClient, policyManager)		}); err != nil {			return nil, err		}		return h, nil	}))}type ownLinkVerifier interface {	IsOwnLink(ctx context.Context) bool}type Handler struct {	client          dns.Client	fdns            dns.FakeDNSEngine	ownLinkVerifier ownLinkVerifier	server          net.Destination	timeout         time.Duration	nonIPQuery      string	blockTypes      []int32}func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error {	h.client = dnsClient	h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle	if v, ok := dnsClient.(ownLinkVerifier); ok {		h.ownLinkVerifier = v	}	if config.Server != nil {		h.server = config.Server.AsDestination()	}	h.nonIPQuery = config.Non_IPQuery	h.blockTypes = config.BlockTypes	return nil}func (h *Handler) isOwnLink(ctx context.Context) bool {	return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx)}func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) {	var parser dnsmessage.Parser	header, err := parser.Start(b)	if err != nil {		errors.LogInfoInner(context.Background(), err, "parser start")		return	}	id = header.ID	q, err := parser.Question()	if err != nil {		errors.LogInfoInner(context.Background(), err, "question")		return	}	domain = q.Name.String()	qType = q.Type	if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA {		return	}	r = true	return}// Process implements proxy.Outbound.func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error {	outbounds := session.OutboundsFromContext(ctx)	ob := outbounds[len(outbounds)-1]	if !ob.Target.IsValid() {		return errors.New("invalid outbound")	}	ob.Name = "dns"	srcNetwork := ob.Target.Network	dest := ob.Target	if h.server.Network != net.Network_Unknown {		dest.Network = h.server.Network	}	if h.server.Address != nil {		dest.Address = h.server.Address	}	if h.server.Port != 0 {		dest.Port = h.server.Port	}	errors.LogInfo(ctx, "handling DNS traffic to ", dest)	conn := &outboundConn{		dialer: func() (stat.Connection, error) {			return d.Dial(ctx, dest)		},		connReady: make(chan struct{}, 1),	}	var reader dns_proto.MessageReader	var writer dns_proto.MessageWriter	if srcNetwork == net.Network_TCP {		reader = dns_proto.NewTCPReader(link.Reader)		writer = &dns_proto.TCPWriter{			Writer: link.Writer,		}	} else {		reader = &dns_proto.UDPReader{			Reader: link.Reader,		}		writer = &dns_proto.UDPWriter{			Writer: link.Writer,		}	}	var connReader dns_proto.MessageReader	var connWriter dns_proto.MessageWriter	if dest.Network == net.Network_TCP {		connReader = dns_proto.NewTCPReader(buf.NewReader(conn))		connWriter = &dns_proto.TCPWriter{			Writer: buf.NewWriter(conn),		}	} else {		connReader = &dns_proto.UDPReader{			Reader: buf.NewPacketReader(conn),		}		connWriter = &dns_proto.UDPWriter{			Writer: buf.NewWriter(conn),		}	}	if session.TimeoutOnlyFromContext(ctx) {		ctx, _ = context.WithCancel(context.Background())	}	ctx, cancel := context.WithCancel(ctx)	timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)	request := func() error {		defer conn.Close()		for {			b, err := reader.ReadMessage()			if err == io.EOF {				return nil			}			if err != nil {				return err			}			timer.Update()			if !h.isOwnLink(ctx) {				isIPQuery, domain, id, qType := parseIPQuery(b.Bytes())				if len(h.blockTypes) > 0 {					for _, blocktype := range h.blockTypes {						if blocktype == int32(qType) {							errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain)							return nil						}					}				}				if isIPQuery {					go h.handleIPQuery(id, qType, domain, writer)				}				if isIPQuery || h.nonIPQuery == "drop" {					b.Release()					continue				}			}			if err := connWriter.WriteMessage(b); err != nil {				return err			}		}	}	response := func() error {		for {			b, err := connReader.ReadMessage()			if err == io.EOF {				return nil			}			if err != nil {				return err			}			timer.Update()			if err := writer.WriteMessage(b); err != nil {				return err			}		}	}	if err := task.Run(ctx, request, response); err != nil {		return errors.New("connection ends").Base(err)	}	return nil}func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {	var ips []net.IP	var err error	var ttl uint32 = 600	switch qType {	case dnsmessage.TypeA:		ips, err = h.client.LookupIP(domain, dns.IPOption{			IPv4Enable: true,			IPv6Enable: false,			FakeEnable: true,		})	case dnsmessage.TypeAAAA:		ips, err = h.client.LookupIP(domain, dns.IPOption{			IPv4Enable: false,			IPv6Enable: true,			FakeEnable: true,		})	}	rcode := dns.RCodeFromError(err)	if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) {		errors.LogInfoInner(context.Background(), err, "ip query")		return	}	if fkr0, ok := h.fdns.(dns.FakeDNSEngineRev0); ok && len(ips) > 0 && fkr0.IsIPInIPPool(net.IPAddress(ips[0])) {		ttl = 1	}	switch qType {	case dnsmessage.TypeA:		for i, ip := range ips {			ips[i] = ip.To4()		}	case dnsmessage.TypeAAAA:		for i, ip := range ips {			ips[i] = ip.To16()		}	}	b := buf.New()	rawBytes := b.Extend(buf.Size)	builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{		ID:                 id,		RCode:              dnsmessage.RCode(rcode),		RecursionAvailable: true,		RecursionDesired:   true,		Response:           true,		Authoritative:      true,	})	builder.EnableCompression()	common.Must(builder.StartQuestions())	common.Must(builder.Question(dnsmessage.Question{		Name:  dnsmessage.MustNewName(domain),		Class: dnsmessage.ClassINET,		Type:  qType,	}))	common.Must(builder.StartAnswers())	rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl}	for _, ip := range ips {		if len(ip) == net.IPv4len {			var r dnsmessage.AResource			copy(r.A[:], ip)			common.Must(builder.AResource(rHeader, r))		} else {			var r dnsmessage.AAAAResource			copy(r.AAAA[:], ip)			common.Must(builder.AAAAResource(rHeader, r))		}	}	msgBytes, err := builder.Finish()	if err != nil {		errors.LogInfoInner(context.Background(), err, "pack message")		b.Release()		return	}	b.Resize(0, int32(len(msgBytes)))	if err := writer.WriteMessage(b); err != nil {		errors.LogInfoInner(context.Background(), err, "write IP answer")	}}type outboundConn struct {	access sync.Mutex	dialer func() (stat.Connection, error)	conn      net.Conn	connReady chan struct{}}func (c *outboundConn) dial() error {	conn, err := c.dialer()	if err != nil {		return err	}	c.conn = conn	c.connReady <- struct{}{}	return nil}func (c *outboundConn) Write(b []byte) (int, error) {	c.access.Lock()	if c.conn == nil {		if err := c.dial(); err != nil {			c.access.Unlock()			errors.LogWarningInner(context.Background(), err, "failed to dial outbound connection")			return len(b), nil		}	}	c.access.Unlock()	return c.conn.Write(b)}func (c *outboundConn) Read(b []byte) (int, error) {	var conn net.Conn	c.access.Lock()	conn = c.conn	c.access.Unlock()	if conn == nil {		_, open := <-c.connReady		if !open {			return 0, io.EOF		}		conn = c.conn	}	return conn.Read(b)}func (c *outboundConn) Close() error {	c.access.Lock()	close(c.connReady)	if c.conn != nil {		c.conn.Close()	}	c.access.Unlock()	return nil}
 |