Quellcode durchsuchen

DNS: Add new nonIPQuery "reject" (#4824)

风扇滑翔翼 vor 7 Monaten
Ursprung
Commit
38ed2cc387
2 geänderte Dateien mit 41 neuen und 1 gelöschten Zeilen
  1. 1 1
      infra/conf/dns_proxy.go
  2. 40 0
      proxy/dns/dns.go

+ 1 - 1
infra/conf/dns_proxy.go

@@ -30,7 +30,7 @@ func (c *DNSOutboundConfig) Build() (proto.Message, error) {
 	switch c.NonIPQuery {
 	case "":
 		c.NonIPQuery = "drop"
-	case "drop", "skip":
+	case "drop", "skip", "reject":
 	default:
 		return nil, errors.New(`unknown "nonIPQuery": `, c.NonIPQuery)
 	}

+ 40 - 0
proxy/dns/dns.go

@@ -187,6 +187,9 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 				if len(h.blockTypes) > 0 {
 					for _, blocktype := range h.blockTypes {
 						if blocktype == int32(qType) {
+							if h.nonIPQuery == "reject" {
+								go h.rejectNonIPQuery(id, qType, domain, writer)
+							}
 							errors.LogInfo(ctx, "blocked type ", qType, " query for domain ", domain)
 							return nil
 						}
@@ -199,6 +202,11 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 					b.Release()
 					continue
 				}
+				if h.nonIPQuery == "reject" {
+					go h.rejectNonIPQuery(id, qType, domain, writer)
+					b.Release()
+					continue
+				}
 			}
 
 			if err := connWriter.WriteMessage(b); err != nil {
@@ -317,6 +325,38 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
 	}
 }
 
+func (h *Handler) rejectNonIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
+	b := buf.New()
+	rawBytes := b.Extend(buf.Size)
+	builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
+		ID:                 id,
+		RCode:              dnsmessage.RCodeRefused,
+		RecursionAvailable: true,
+		RecursionDesired:   true,
+		Response:           true,
+		Authoritative:      true,
+	})
+	builder.EnableCompression()
+	common.Must(builder.StartQuestions())
+	common.Must(builder.Question(dnsmessage.Question{
+		Name:  dnsmessage.MustNewName(domain),
+		Class: dnsmessage.ClassINET,
+		Type:  qType,
+	}))
+
+	msgBytes, err := builder.Finish()
+	if err != nil {
+		errors.LogInfoInner(context.Background(), err, "pack reject message")
+		b.Release()
+		return
+	}
+	b.Resize(0, int32(len(msgBytes)))
+
+	if err := writer.WriteMessage(b); err != nil {
+		errors.LogInfoInner(context.Background(), err, "write reject answer")
+	}
+}
+
 type outboundConn struct {
 	access sync.Mutex
 	dialer func() (stat.Connection, error)