Ver Fonte

Fix read DNS message

世界 há 3 anos atrás
pai
commit
c5e38203eb
2 ficheiros alterados com 47 adições e 40 exclusões
  1. 1 1
      common/sniff/dns.go
  2. 46 39
      outbound/dns.go

+ 1 - 1
common/sniff/dns.go

@@ -22,7 +22,7 @@ func StreamDomainNameQuery(readCtx context.Context, reader io.Reader) (*adapter.
 	if err != nil {
 		return nil, err
 	}
-	if length > 512 {
+	if length == 0 {
 		return nil, os.ErrInvalid
 	}
 	_buffer := buf.StackNewSize(int(length))

+ 46 - 39
outbound/dns.go

@@ -3,13 +3,13 @@ package outbound
 import (
 	"context"
 	"encoding/binary"
-	"io"
 	"net"
 	"os"
 
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/common/canceler"
 	C "github.com/sagernet/sing-box/constant"
+	"github.com/sagernet/sing-dns"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
 	M "github.com/sagernet/sing/common/metadata"
@@ -47,53 +47,60 @@ func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.Pa
 func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
 	defer conn.Close()
 	ctx = adapter.WithContext(ctx, &metadata)
-	_buffer := buf.StackNewSize(1024)
-	defer common.KeepAlive(_buffer)
-	buffer := common.Dup(_buffer)
-	defer buffer.Release()
 	for {
-		var queryLength uint16
-		err := binary.Read(conn, binary.BigEndian, &queryLength)
+		err := d.handleConnection(ctx, conn, metadata)
 		if err != nil {
 			return err
 		}
-		if queryLength > 1024 {
-			return io.ErrShortBuffer
-		}
-		buffer.FullReset()
-		_, err = buffer.ReadFullFrom(conn, int(queryLength))
+	}
+}
+
+func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
+	var queryLength uint16
+	err := binary.Read(conn, binary.BigEndian, &queryLength)
+	if err != nil {
+		return err
+	}
+	if queryLength == 0 {
+		return dns.RCodeFormatError
+	}
+	_buffer := buf.StackNewSize(int(queryLength))
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
+	_, err = buffer.ReadFullFrom(conn, int(queryLength))
+	if err != nil {
+		return err
+	}
+	var message dnsmessage.Message
+	err = message.Unpack(buffer.Bytes())
+	if err != nil {
+		return err
+	}
+	if len(message.Questions) > 0 {
+		question := message.Questions[0]
+		metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
+	}
+	go func() error {
+		response, err := d.router.Exchange(ctx, &message)
 		if err != nil {
 			return err
 		}
-		var message dnsmessage.Message
-		err = message.Unpack(buffer.Bytes())
+		_responseBuffer := buf.StackNewPacket()
+		defer common.KeepAlive(_responseBuffer)
+		responseBuffer := common.Dup(_responseBuffer)
+		defer responseBuffer.Release()
+		responseBuffer.Resize(2, 0)
+		n, err := response.AppendPack(responseBuffer.Index(0))
 		if err != nil {
 			return err
 		}
-		if len(message.Questions) > 0 {
-			question := message.Questions[0]
-			metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
-		}
-		go func() error {
-			response, err := d.router.Exchange(ctx, &message)
-			if err != nil {
-				return err
-			}
-			_responseBuffer := buf.StackNewPacket()
-			defer common.KeepAlive(_responseBuffer)
-			responseBuffer := common.Dup(_responseBuffer)
-			defer responseBuffer.Release()
-			responseBuffer.Resize(2, 0)
-			n, err := response.AppendPack(responseBuffer.Index(0))
-			if err != nil {
-				return err
-			}
-			responseBuffer.Truncate(len(n))
-			binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
-			_, err = conn.Write(responseBuffer.Bytes())
-			return err
-		}()
-	}
+		responseBuffer.Truncate(len(n))
+		binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
+		_, err = conn.Write(responseBuffer.Bytes())
+		return err
+	}()
+	return nil
 }
 
 func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
@@ -103,7 +110,7 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
 	var group task.Group
 	group.Append0(func(ctx context.Context) error {
 		defer cancel()
-		_buffer := buf.StackNewSize(1024)
+		_buffer := buf.StackNewSize(dns.FixedPacketSize)
 		defer common.KeepAlive(_buffer)
 		buffer := common.Dup(_buffer)
 		defer buffer.Release()