| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 | 
							- package dhcp
 
import (
	"context"
	"encoding/binary"
	"io"
	"math/rand"
	"strings"
	"time"

	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing-box/dns"
	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"

	mDNS "github.com/miekg/dns"
)
 
- func (t *Transport) exchangeSingleRequest(ctx context.Context, servers []M.Socksaddr, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
 
- 	var lastErr error
 
- 	for _, fqdn := range t.nameList(domain) {
 
- 		response, err := t.tryOneName(ctx, servers, fqdn, message)
 
- 		if err != nil {
 
- 			lastErr = err
 
- 			continue
 
- 		}
 
- 		return response, nil
 
- 	}
 
- 	return nil, lastErr
 
- }
 
- func (t *Transport) exchangeParallel(ctx context.Context, servers []M.Socksaddr, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
 
- 	returned := make(chan struct{})
 
- 	defer close(returned)
 
- 	type queryResult struct {
 
- 		response *mDNS.Msg
 
- 		err      error
 
- 	}
 
- 	results := make(chan queryResult)
 
- 	startRacer := func(ctx context.Context, fqdn string) {
 
- 		response, err := t.tryOneName(ctx, servers, fqdn, message)
 
- 		if err == nil {
 
- 			if response.Rcode != mDNS.RcodeSuccess {
 
- 				err = dns.RcodeError(response.Rcode)
 
- 			} else if len(dns.MessageToAddresses(response)) == 0 {
 
- 				err = E.New(fqdn, ": empty result")
 
- 			}
 
- 		}
 
- 		select {
 
- 		case results <- queryResult{response, err}:
 
- 		case <-returned:
 
- 		}
 
- 	}
 
- 	queryCtx, queryCancel := context.WithCancel(ctx)
 
- 	defer queryCancel()
 
- 	var nameCount int
 
- 	for _, fqdn := range t.nameList(domain) {
 
- 		nameCount++
 
- 		go startRacer(queryCtx, fqdn)
 
- 	}
 
- 	var errors []error
 
- 	for {
 
- 		select {
 
- 		case <-ctx.Done():
 
- 			return nil, ctx.Err()
 
- 		case result := <-results:
 
- 			if result.err == nil {
 
- 				return result.response, nil
 
- 			}
 
- 			errors = append(errors, result.err)
 
- 			if len(errors) == nameCount {
 
- 				return nil, E.Errors(errors...)
 
- 			}
 
- 		}
 
- 	}
 
- }
 
- func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn string, message *mDNS.Msg) (*mDNS.Msg, error) {
 
- 	sLen := len(servers)
 
- 	var lastErr error
 
- 	for i := 0; i < t.attempts; i++ {
 
- 		for j := 0; j < sLen; j++ {
 
- 			server := servers[j]
 
- 			question := message.Question[0]
 
- 			question.Name = fqdn
 
- 			response, err := t.exchangeOne(ctx, server, question, C.DNSTimeout, false, true)
 
- 			if err != nil {
 
- 				lastErr = err
 
- 				continue
 
- 			}
 
- 			return response, nil
 
- 		}
 
- 	}
 
- 	return nil, E.Cause(lastErr, fqdn)
 
- }
 
- func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
 
- 	if server.Port == 0 {
 
- 		server.Port = 53
 
- 	}
 
- 	var networks []string
 
- 	if useTCP {
 
- 		networks = []string{N.NetworkTCP}
 
- 	} else {
 
- 		networks = []string{N.NetworkUDP, N.NetworkTCP}
 
- 	}
 
- 	request := &mDNS.Msg{
 
- 		MsgHdr: mDNS.MsgHdr{
 
- 			Id:                uint16(rand.Uint32()),
 
- 			RecursionDesired:  true,
 
- 			AuthenticatedData: ad,
 
- 		},
 
- 		Question: []mDNS.Question{question},
 
- 		Compress: true,
 
- 	}
 
- 	request.SetEdns0(buf.UDPBufferSize, false)
 
- 	buffer := buf.Get(buf.UDPBufferSize)
 
- 	defer buf.Put(buffer)
 
- 	for _, network := range networks {
 
- 		ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
 
- 		defer cancel()
 
- 		conn, err := t.dialer.DialContext(ctx, network, server)
 
- 		if err != nil {
 
- 			return nil, err
 
- 		}
 
- 		defer conn.Close()
 
- 		if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
 
- 			conn.SetDeadline(deadline)
 
- 		}
 
- 		rawMessage, err := request.PackBuffer(buffer)
 
- 		if err != nil {
 
- 			return nil, E.Cause(err, "pack request")
 
- 		}
 
- 		_, err = conn.Write(rawMessage)
 
- 		if err != nil {
 
- 			return nil, E.Cause(err, "write request")
 
- 		}
 
- 		n, err := conn.Read(buffer)
 
- 		if err != nil {
 
- 			return nil, E.Cause(err, "read response")
 
- 		}
 
- 		var response mDNS.Msg
 
- 		err = response.Unpack(buffer[:n])
 
- 		if err != nil {
 
- 			return nil, E.Cause(err, "unpack response")
 
- 		}
 
- 		if response.Truncated && network == N.NetworkUDP {
 
- 			continue
 
- 		}
 
- 		return &response, nil
 
- 	}
 
- 	panic("unexpected")
 
- }
 
- func (t *Transport) nameList(name string) []string {
 
- 	l := len(name)
 
- 	rooted := l > 0 && name[l-1] == '.'
 
- 	if l > 254 || l == 254 && !rooted {
 
- 		return nil
 
- 	}
 
- 	if rooted {
 
- 		if avoidDNS(name) {
 
- 			return nil
 
- 		}
 
- 		return []string{name}
 
- 	}
 
- 	hasNdots := strings.Count(name, ".") >= t.ndots
 
- 	name += "."
 
- 	// l++
 
- 	names := make([]string, 0, 1+len(t.search))
 
- 	if hasNdots && !avoidDNS(name) {
 
- 		names = append(names, name)
 
- 	}
 
- 	for _, suffix := range t.search {
 
- 		fqdn := name + suffix
 
- 		if !avoidDNS(fqdn) && len(fqdn) <= 254 {
 
- 			names = append(names, fqdn)
 
- 		}
 
- 	}
 
- 	if !hasNdots && !avoidDNS(name) {
 
- 		names = append(names, name)
 
- 	}
 
- 	return names
 
- }
 
// avoidDNS reports whether name must not be sent to a DNS server.
//
// That is the case for the empty name and for any name whose last label is
// "onion", with or without a trailing dot. The suffix is compared without
// regard to letter case, because DNS names are case-insensitive and
// "foo.ONION" designates the same name as "foo.onion".
func avoidDNS(name string) bool {
	if name == "" {
		return true
	}
	name = strings.TrimSuffix(name, ".")
	const suffix = ".onion"
	if len(name) < len(suffix) {
		return false
	}
	return strings.EqualFold(name[len(name)-len(suffix):], suffix)
}
 
 
  |