浏览代码

Fix hijack_dns

世界 3 年之前
父节点
当前提交
a104d18277
共有 1 个文件被更改,包括 33 次插入24 次删除
  1. 33 24
      inbound/dns.go

+ 33 - 24
inbound/dns.go

@@ -16,6 +16,7 @@ import (
 )
 
 func NewDNSConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn net.Conn, metadata adapter.InboundContext) error {
+	ctx = adapter.WithContext(ctx, &metadata)
 	_buffer := buf.StackNewSize(1024)
 	defer common.KeepAlive(_buffer)
 	buffer := common.Dup(_buffer)
@@ -44,32 +45,38 @@ func NewDNSConnection(ctx context.Context, router adapter.Router, logger log.Log
 			metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
 			logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
 		}
-		response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
-		if err != nil {
-			return err
-		}
-		buffer.FullReset()
-		responseBuffer, err := response.AppendPack(buffer.Index(0))
-		if err != nil {
-			return err
-		}
-		err = binary.Write(conn, binary.BigEndian, uint16(len(responseBuffer)))
-		if err != nil {
-			return err
-		}
-		_, err = conn.Write(responseBuffer)
-		if err != nil {
+		go func() error {
+			response, err := router.Exchange(ctx, &message)
+			if err != nil {
+				return err
+			}
+			_responseBuffer := buf.StackNewSize(1024)
+			defer common.KeepAlive(_responseBuffer)
+			responseBuffer := common.Dup(_responseBuffer)
+			defer responseBuffer.Release()
+			responseBuffer.Resize(2, 0)
+			n, err := response.AppendPack(responseBuffer.Index(0))
+			if err != nil {
+				return err
+			}
+			responseBuffer.Truncate(len(n))
+			binary.BigEndian.PutUint16(responseBuffer.ExtendHeader(2), uint16(len(n)))
+			_, err = conn.Write(responseBuffer.Bytes())
 			return err
-		}
+		}()
 	}
 }
 
 func NewDNSPacketConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn N.PacketConn, metadata adapter.InboundContext) error {
+	ctx = adapter.WithContext(ctx, &metadata)
+	_buffer := buf.StackNewSize(1024)
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
 	for {
-		buffer := buf.StackNewSize(1024)
+		buffer.FullReset()
 		destination, err := conn.ReadPacket(buffer)
 		if err != nil {
-			buffer.Release()
 			return err
 		}
 		var message dnsmessage.Message
@@ -83,18 +90,20 @@ func NewDNSPacketConnection(ctx context.Context, router adapter.Router, logger l
 			logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
 		}
 		go func() error {
-			defer buffer.Release()
-			response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
+			response, err := router.Exchange(ctx, &message)
 			if err != nil {
 				return err
 			}
-			buffer.FullReset()
-			responseBuffer, err := response.AppendPack(buffer.Index(0))
+			_responseBuffer := buf.StackNewSize(1024)
+			defer common.KeepAlive(_responseBuffer)
+			responseBuffer := common.Dup(_responseBuffer)
+			defer responseBuffer.Release()
+			n, err := response.AppendPack(responseBuffer.Index(0))
 			if err != nil {
 				return err
 			}
-			buffer.Truncate(len(responseBuffer))
-			err = conn.WritePacket(buffer, destination)
+			responseBuffer.Truncate(len(n))
+			err = conn.WritePacket(responseBuffer, destination)
 			return err
 		}()
 	}