Browse Source

Improve DNS truncate behavior

世界 1 year ago
parent
commit
917514e09f
1 changed files with 5 additions and 35 deletions
  1. 5 35
      outbound/dns.go

+ 5 - 35
outbound/dns.go

@@ -46,8 +46,8 @@ func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.Pa
 }
 }
 
 
 func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
 func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
+	metadata.Destination = M.Socksaddr{}
 	defer conn.Close()
 	defer conn.Close()
-	ctx = adapter.WithContext(ctx, &metadata)
 	for {
 	for {
 		err := d.handleConnection(ctx, conn, metadata)
 		err := d.handleConnection(ctx, conn, metadata)
 		if err != nil {
 		if err != nil {
@@ -98,6 +98,7 @@ func (d *DNS) handleConnection(ctx context.Context, conn net.Conn, metadata adap
 }
 }
 
 
 func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
 func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
+	metadata.Destination = M.Socksaddr{}
 	var reader N.PacketReader = conn
 	var reader N.PacketReader = conn
 	var counters []N.CountFunc
 	var counters []N.CountFunc
 	var cachedPackets []*N.PacketBuffer
 	var cachedPackets []*N.PacketBuffer
@@ -111,14 +112,11 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
 			}
 			}
 		}
 		}
 		if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
 		if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
-			readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
-				MTU: dns.FixedPacketSize,
-			})
+			readWaiter.InitializeReadWaiter(N.ReadWaitOptions{})
 			return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata)
 			return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata)
 		}
 		}
 		break
 		break
 	}
 	}
-	ctx = adapter.WithContext(ctx, &metadata)
 	fastClose, cancel := common.ContextWithCancelCause(ctx)
 	fastClose, cancel := common.ContextWithCancelCause(ctx)
 	timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
 	timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
 	var group task.Group
 	var group task.Group
@@ -167,15 +165,11 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
 					return err
 					return err
 				}
 				}
 				timeout.Update()
 				timeout.Update()
-				responseBuffer := buf.NewPacket()
-				responseBuffer.Resize(1024, 0)
-				n, err := response.PackBuffer(responseBuffer.FreeBytes())
+				responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
 				if err != nil {
 				if err != nil {
 					cancel(err)
 					cancel(err)
-					responseBuffer.Release()
 					return err
 					return err
 				}
 				}
-				responseBuffer.Truncate(len(n))
 				err = conn.WritePacket(responseBuffer, destination)
 				err = conn.WritePacket(responseBuffer, destination)
 				if err != nil {
 				if err != nil {
 					cancel(err)
 					cancel(err)
@@ -241,16 +235,11 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa
 					return err
 					return err
 				}
 				}
 				timeout.Update()
 				timeout.Update()
-				response = truncateDNSMessage(response, 512) // TODO: add an option to custom UDP buffer size
-				responseBuffer := buf.NewSize(dns.FixedPacketSize)
-				responseBuffer.Resize(1024, 0)
-				n, err := response.PackBuffer(responseBuffer.FreeBytes())
+				responseBuffer, err := dns.TruncateDNSMessage(&message, response, 1024)
 				if err != nil {
 				if err != nil {
 					cancel(err)
 					cancel(err)
-					responseBuffer.Release()
 					return err
 					return err
 				}
 				}
-				responseBuffer.Truncate(len(n))
 				err = conn.WritePacket(responseBuffer, destination)
 				err = conn.WritePacket(responseBuffer, destination)
 				if err != nil {
 				if err != nil {
 					cancel(err)
 					cancel(err)
@@ -264,22 +253,3 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa
 	})
 	})
 	return group.Run(fastClose)
 	return group.Run(fastClose)
 }
 }
-
-func truncateDNSMessage(response *mDNS.Msg, maxLen int) *mDNS.Msg {
-	responseLen := response.Len()
-	if responseLen <= maxLen {
-		return response
-	}
-	newResponse := *response
-	response = &newResponse
-	for len(response.Answer) > 0 && responseLen > maxLen {
-		response.Answer = response.Answer[:len(response.Answer)-1]
-		response.Truncated = true
-		responseLen = response.Len()
-	}
-	if responseLen > maxLen {
-		response.Ns = nil
-		response.Extra = nil
-	}
-	return response
-}