Explorar o código

Improve udp dns close

世界 %!s(int64=3) %!d(string=hai) anos
pai
achega
9e9e6f7ee6
Modificáronse 2 ficheiros con 57 adicións e 29 borrados
  1. 1 1
      inbound/default.go
  2. 56 28
      outbound/dns.go

+ 1 - 1
inbound/default.go

@@ -325,7 +325,7 @@ func (a *myInboundAdapter) NewError(ctx context.Context, err error) {
 func NewError(logger log.ContextLogger, ctx context.Context, err error) {
 	common.Close(err)
 	if E.IsClosedOrCanceled(err) {
-		logger.DebugContext(ctx, "connection closed")
+		logger.TraceContext(ctx, "connection closed: ", err)
 		return
 	}
 	logger.ErrorContext(ctx, err)

+ 56 - 28
outbound/dns.go

@@ -6,6 +6,8 @@ import (
 	"io"
 	"net"
 	"os"
+	"sync"
+	"time"
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
@@ -14,6 +16,7 @@ import (
 	"github.com/sagernet/sing/common/buf"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/common/task"
 
 	"golang.org/x/net/dns/dnsmessage"
 )
@@ -45,6 +48,7 @@ func (d *DNS) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.Pa
 }
 
 func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
+	defer conn.Close()
 	ctx = adapter.WithContext(ctx, &metadata)
 	_buffer := buf.StackNewSize(1024)
 	defer common.KeepAlive(_buffer)
@@ -97,45 +101,69 @@ func (d *DNS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter
 }
 
 func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
+	defer conn.Close()
 	ctx = adapter.WithContext(ctx, &metadata)
 	_buffer := buf.StackNewSize(1024)
 	defer common.KeepAlive(_buffer)
 	buffer := common.Dup(_buffer)
 	defer buffer.Release()
-	for {
-		buffer.FullReset()
-		destination, err := conn.ReadPacket(buffer)
-		if err != nil {
-			return err
-		}
-		var message dnsmessage.Message
-		err = message.Unpack(buffer.Bytes())
-		if err != nil {
-			return err
-		}
-		if len(message.Questions) > 0 {
-			question := message.Questions[0]
-			metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
-			d.logger.DebugContext(ctx, "inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
-		}
-		go func() error {
-			response, err := d.router.Exchange(ctx, &message)
+	var wg sync.WaitGroup
+	fastClose, cancel := context.WithCancel(ctx)
+	err := task.Run(fastClose, func() error {
+		var count int
+		for {
+			buffer.FullReset()
+			destination, err := conn.ReadPacket(buffer)
 			if err != nil {
 				return err
 			}
-			_responseBuffer := buf.StackNewSize(1024)
-			defer common.KeepAlive(_responseBuffer)
-			responseBuffer := common.Dup(_responseBuffer)
-			defer responseBuffer.Release()
-			n, err := response.AppendPack(responseBuffer.Index(0))
+			var message dnsmessage.Message
+			err = message.Unpack(buffer.Bytes())
 			if err != nil {
 				return err
 			}
-			responseBuffer.Truncate(len(n))
-			err = conn.WritePacket(responseBuffer, destination)
-			return err
-		}()
-	}
+			if len(message.Questions) > 0 {
+				question := message.Questions[0]
+				metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
+				d.logger.DebugContext(ctx, "inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
+			}
+			wg.Add(1)
+			go func() error {
+				defer wg.Done()
+				response, err := d.router.Exchange(ctx, &message)
+				if err != nil {
+					return err
+				}
+				_responseBuffer := buf.StackNewSize(1024)
+				defer common.KeepAlive(_responseBuffer)
+				responseBuffer := common.Dup(_responseBuffer)
+				defer responseBuffer.Release()
+				n, err := response.AppendPack(responseBuffer.Index(0))
+				if err != nil {
+					return err
+				}
+				responseBuffer.Truncate(len(n))
+				err = conn.WritePacket(responseBuffer, destination)
+				return err
+			}()
+			count++
+			if count == 2 {
+				break
+			}
+		}
+		cancel()
+		return nil
+	}, func() error {
+		timer := time.NewTimer(5 * time.Second)
+		select {
+		case <-timer.C:
+			cancel()
+		case <-fastClose.Done():
+		}
+		return nil
+	})
+	wg.Wait()
+	return err
 }
 
 func formatDNSQuestion(question dnsmessage.Question) string {