Browse Source

DNS outbound: Fix some issues (#5081)

patterniha 3 months ago
parent
commit
197b319f9a
1 changed files with 51 additions and 21 deletions
  1. 51 21
      proxy/dns/dns.go

+ 51 - 21
proxy/dns/dns.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	go_errors "errors"
 	"io"
+	"strings"
 	"sync"
 	"time"
 
@@ -168,11 +169,15 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 	}
 
 	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout)
+	terminate := func() {
+		cancel()
+		conn.Close()
+	}
+	timer := signal.CancelAfterInactivity(ctx, terminate, h.timeout)
+	defer timer.SetTimeout(0)
 
 	request := func() error {
-		defer conn.Close()
-
+		defer timer.SetTimeout(0)
 		for {
 			b, err := reader.ReadMessage()
 			if err == io.EOF {
@@ -190,24 +195,33 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 				if len(h.blockTypes) > 0 {
 					for _, blocktype := range h.blockTypes {
 						if blocktype == int32(qType) {
+							b.Release()
+							errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain)
 							if h.nonIPQuery == "reject" {
-								go h.rejectNonIPQuery(id, qType, domain, writer)
+								err := h.rejectNonIPQuery(id, qType, domain, writer)
+								if err != nil {
+									return err
+								}
 							}
-							errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain)
 							return nil
 						}
 					}
 				}
 				if isIPQuery {
-					go h.handleIPQuery(id, qType, domain, writer)
+					b.Release()
+					go h.handleIPQuery(id, qType, domain, writer, timer)
+					continue
 				}
-				if isIPQuery || h.nonIPQuery == "drop" {
+				if h.nonIPQuery == "drop" {
 					b.Release()
 					continue
 				}
 				if h.nonIPQuery == "reject" {
-					go h.rejectNonIPQuery(id, qType, domain, writer)
 					b.Release()
+					err := h.rejectNonIPQuery(id, qType, domain, writer)
+					if err != nil {
+						return err
+					}
 					continue
 				}
 			}
@@ -219,6 +233,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 	}
 
 	response := func() error {
+		defer timer.SetTimeout(0)
 		for {
 			b, err := connReader.ReadMessage()
 			if err == io.EOF {
@@ -244,7 +259,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 	return nil
 }
 
-func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
+func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter, timer *signal.ActivityTimer) {
 	var ips []net.IP
 	var err error
 
@@ -319,16 +334,21 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
 	if err != nil {
 		errors.LogInfoInner(context.Background(), err, "pack message")
 		b.Release()
-		return
+		timer.SetTimeout(0)
 	}
 	b.Resize(0, int32(len(msgBytes)))
 
 	if err := writer.WriteMessage(b); err != nil {
 		errors.LogInfoInner(context.Background(), err, "write IP answer")
+		timer.SetTimeout(0)
 	}
 }
 
-func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
+func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) error {
+	domainT := strings.TrimSuffix(domain, ".")
+	if domainT == "" {
+		return errors.New("empty domain name")
+	}
 	b := buf.New()
 	rawBytes := b.Extend(buf.Size)
 	builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
@@ -349,20 +369,22 @@ func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain stri
 	if err != nil {
 		errors.LogInfo(context.Background(), "unexpected domain ", domain, " when building reject message: ", err)
 		b.Release()
-		return
+		return err
 	}
 
 	msgBytes, err := builder.Finish()
 	if err != nil {
 		errors.LogInfoInner(context.Background(), err, "pack reject message")
 		b.Release()
-		return
+		return err
 	}
 	b.Resize(0, int32(len(msgBytes)))
 
 	if err := writer.WriteMessage(b); err != nil {
 		errors.LogInfoInner(context.Background(), err, "write reject answer")
+		return err
 	}
+	return nil
 }
 
 type outboundConn struct {
@@ -371,6 +393,7 @@ type outboundConn struct {
 
 	conn      net.Conn
 	connReady chan struct{}
+	closed    bool
 }
 
 func (c *outboundConn) dial() error {
@@ -385,12 +408,16 @@ func (c *outboundConn) dial() error {
 
 func (c *outboundConn) Write(b []byte) (int, error) {
 	c.access.Lock()
+	if c.closed {
+		c.access.Unlock()
+		return 0, errors.New("outbound connection closed")
+	}
 
 	if c.conn == nil {
 		if err := c.dial(); err != nil {
 			c.access.Unlock()
 			errors.LogWarningInner(context.Background(), err, "failed to dial outbound connection")
-			return len(b), nil
+			return 0, err
 		}
 	}
 
@@ -400,24 +427,27 @@ func (c *outboundConn) Write(b []byte) (int, error) {
 }
 
 func (c *outboundConn) Read(b []byte) (int, error) {
-	var conn net.Conn
 	c.access.Lock()
-	conn = c.conn
-	c.access.Unlock()
+	if c.closed {
+		c.access.Unlock()
+		return 0, io.EOF
+	}
 
-	if conn == nil {
+	if c.conn == nil {
+		c.access.Unlock()
 		_, open := <-c.connReady
 		if !open {
 			return 0, io.EOF
 		}
-		conn = c.conn
+		return c.conn.Read(b)
 	}
-
-	return conn.Read(b)
+	c.access.Unlock()
+	return c.conn.Read(b)
 }
 
 func (c *outboundConn) Close() error {
 	c.access.Lock()
+	c.closed = true
 	close(c.connReady)
 	if c.conn != nil {
 		c.conn.Close()