浏览代码

Fix TCP exchange for local/dhcp DNS servers

世界 1 月之前
父节点
当前提交
5d1d1a1456
共有 2 个文件被更改,包括 124 次插入83 次删除
  1. 59 43
      dns/transport/dhcp/dhcp_shared.go
  2. 65 40
      dns/transport/local/local.go

+ 59 - 43
dns/transport/dhcp/dhcp_shared.go

@@ -2,12 +2,13 @@ package dhcp
 
 import (
 	"context"
+	"errors"
 	"math/rand"
 	"strings"
-	"time"
+	"syscall"
 
-	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
+	"github.com/sagernet/sing-box/dns/transport"
 	"github.com/sagernet/sing/common/buf"
 	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
@@ -83,7 +84,7 @@ func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn
 			server := servers[j]
 			question := message.Question[0]
 			question.Name = fqdn
-			response, err := t.exchangeOne(ctx, server, question, C.DNSTimeout, false, true)
+			response, err := t.exchangeOne(ctx, server, question)
 			if err != nil {
 				lastErr = err
 				continue
@@ -94,62 +95,77 @@ func (t *Transport) tryOneName(ctx context.Context, servers []M.Socksaddr, fqdn
 	return nil, E.Cause(lastErr, fqdn)
 }
 
-func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question, timeout time.Duration, useTCP, ad bool) (*mDNS.Msg, error) {
+func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, question mDNS.Question) (*mDNS.Msg, error) {
 	if server.Port == 0 {
 		server.Port = 53
 	}
-	var networks []string
-	if useTCP {
-		networks = []string{N.NetworkTCP}
-	} else {
-		networks = []string{N.NetworkUDP, N.NetworkTCP}
-	}
 	request := &mDNS.Msg{
 		MsgHdr: mDNS.MsgHdr{
 			Id:                uint16(rand.Uint32()),
 			RecursionDesired:  true,
-			AuthenticatedData: ad,
+			AuthenticatedData: true,
 		},
 		Question: []mDNS.Question{question},
 		Compress: true,
 	}
 	request.SetEdns0(buf.UDPBufferSize, false)
-	buffer := buf.Get(buf.UDPBufferSize)
+	return t.exchangeUDP(ctx, server, request)
+}
+
+func (t *Transport) exchangeUDP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg) (*mDNS.Msg, error) {
+	conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, server)
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
+		conn.SetDeadline(deadline)
+	}
+	buffer := buf.Get(1 + request.Len())
 	defer buf.Put(buffer)
-	for _, network := range networks {
-		ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
-		defer cancel()
-		conn, err := t.dialer.DialContext(ctx, network, server)
-		if err != nil {
-			return nil, err
-		}
-		defer conn.Close()
-		if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
-			conn.SetDeadline(deadline)
-		}
-		rawMessage, err := request.PackBuffer(buffer)
-		if err != nil {
-			return nil, E.Cause(err, "pack request")
-		}
-		_, err = conn.Write(rawMessage)
-		if err != nil {
-			return nil, E.Cause(err, "write request")
-		}
-		n, err := conn.Read(buffer)
-		if err != nil {
-			return nil, E.Cause(err, "read response")
-		}
-		var response mDNS.Msg
-		err = response.Unpack(buffer[:n])
-		if err != nil {
-			return nil, E.Cause(err, "unpack response")
+	rawMessage, err := request.PackBuffer(buffer)
+	if err != nil {
+		return nil, E.Cause(err, "pack request")
+	}
+	_, err = conn.Write(rawMessage)
+	if err != nil {
+		if errors.Is(err, syscall.EMSGSIZE) {
+			return t.exchangeTCP(ctx, server, request)
 		}
-		if response.Truncated && network == N.NetworkUDP {
-			continue
+		return nil, E.Cause(err, "write request")
+	}
+	n, err := conn.Read(buffer)
+	if err != nil {
+		if errors.Is(err, syscall.EMSGSIZE) {
+			return t.exchangeTCP(ctx, server, request)
 		}
-		return &response, nil
+		return nil, E.Cause(err, "read response")
+	}
+	var response mDNS.Msg
+	err = response.Unpack(buffer[:n])
+	if err != nil {
+		return nil, E.Cause(err, "unpack response")
+	}
+	if response.Truncated {
+		return t.exchangeTCP(ctx, server, request)
+	}
+	return &response, nil
+}
+
+func (t *Transport) exchangeTCP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg) (*mDNS.Msg, error) {
+	conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, server)
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
+		conn.SetDeadline(deadline)
+	}
+	err = transport.WriteMessage(conn, 0, request)
+	if err != nil {
+		return nil, err
 	}
-	panic("unexpected")
+	return transport.ReadMessage(conn)
 }
 
 func (t *Transport) nameList(name string) []string {

+ 65 - 40
dns/transport/local/local.go

@@ -10,6 +10,7 @@ import (
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-box/dns"
+	"github.com/sagernet/sing-box/dns/transport"
 	"github.com/sagernet/sing-box/dns/transport/hosts"
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing-box/option"
@@ -151,12 +152,6 @@ func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, questio
 	if server.Port == 0 {
 		server.Port = 53
 	}
-	var networks []string
-	if useTCP {
-		networks = []string{N.NetworkTCP}
-	} else {
-		networks = []string{N.NetworkUDP, N.NetworkTCP}
-	}
 	request := &mDNS.Msg{
 		MsgHdr: mDNS.MsgHdr{
 			Id:                uint16(rand.Uint32()),
@@ -167,43 +162,73 @@ func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, questio
 		Compress: true,
 	}
 	request.SetEdns0(buf.UDPBufferSize, false)
-	buffer := buf.Get(buf.UDPBufferSize)
-	defer buf.Put(buffer)
-	for _, network := range networks {
-		ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
-		defer cancel()
-		conn, err := t.dialer.DialContext(ctx, network, server)
-		if err != nil {
-			return nil, err
-		}
-		defer conn.Close()
-		if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
-			conn.SetDeadline(deadline)
-		}
-		rawMessage, err := request.PackBuffer(buffer)
-		if err != nil {
-			return nil, E.Cause(err, "pack request")
-		}
-		_, err = conn.Write(rawMessage)
-		if err != nil {
-			if errors.Is(err, syscall.EMSGSIZE) && network == N.NetworkUDP {
-				continue
-			}
-			return nil, E.Cause(err, "write request")
+	if !useTCP {
+		return t.exchangeUDP(ctx, server, request, timeout)
+	} else {
+		return t.exchangeTCP(ctx, server, request, timeout)
+	}
+}
+
+func (t *Transport) exchangeUDP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) {
+	conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, server)
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
+		newDeadline := time.Now().Add(timeout)
+		if deadline.After(newDeadline) {
+			deadline = newDeadline
 		}
-		n, err := conn.Read(buffer)
-		if err != nil {
-			return nil, E.Cause(err, "read response")
+		conn.SetDeadline(deadline)
+	}
+	buffer := buf.Get(1 + request.Len())
+	defer buf.Put(buffer)
+	rawMessage, err := request.PackBuffer(buffer)
+	if err != nil {
+		return nil, E.Cause(err, "pack request")
+	}
+	_, err = conn.Write(rawMessage)
+	if err != nil {
+		if errors.Is(err, syscall.EMSGSIZE) {
+			return t.exchangeTCP(ctx, server, request, timeout)
 		}
-		var response mDNS.Msg
-		err = response.Unpack(buffer[:n])
-		if err != nil {
-			return nil, E.Cause(err, "unpack response")
+		return nil, E.Cause(err, "write request")
+	}
+	n, err := conn.Read(buffer)
+	if err != nil {
+		if errors.Is(err, syscall.EMSGSIZE) {
+			return t.exchangeTCP(ctx, server, request, timeout)
 		}
-		if response.Truncated && network == N.NetworkUDP {
-			continue
+		return nil, E.Cause(err, "read response")
+	}
+	var response mDNS.Msg
+	err = response.Unpack(buffer[:n])
+	if err != nil {
+		return nil, E.Cause(err, "unpack response")
+	}
+	if response.Truncated {
+		return t.exchangeTCP(ctx, server, request, timeout)
+	}
+	return &response, nil
+}
+
+func (t *Transport) exchangeTCP(ctx context.Context, server M.Socksaddr, request *mDNS.Msg, timeout time.Duration) (*mDNS.Msg, error) {
+	conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, server)
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	if deadline, loaded := ctx.Deadline(); loaded && !deadline.IsZero() {
+		newDeadline := time.Now().Add(timeout)
+		if deadline.After(newDeadline) {
+			deadline = newDeadline
 		}
-		return &response, nil
+		conn.SetDeadline(deadline)
+	}
+	err = transport.WriteMessage(conn, 0, request)
+	if err != nil {
+		return nil, err
 	}
-	panic("unexpected")
+	return transport.ReadMessage(conn)
 }