瀏覽代碼

DNS: Retry with EDNS0 when response is truncated (#4516)

风扇滑翔翼 10 月之前
父節點
當前提交
86a225cda1
共有 3 個文件被更改,包括 49 次插入15 次删除
  1. 9 7
      app/dns/dnscommon.go
  2. 5 3
      app/dns/dnscommon_test.go
  3. 35 5
      app/dns/nameserver_udp.go

+ 9 - 7
app/dns/dnscommon.go

@@ -31,10 +31,11 @@ type record struct {
 
 // IPRecord is a cacheable item for a resolved domain
 type IPRecord struct {
-	ReqID  uint16
-	IP     []net.Address
-	Expire time.Time
-	RCode  dnsmessage.RCode
+	ReqID     uint16
+	IP        []net.Address
+	Expire    time.Time
+	RCode     dnsmessage.RCode
+	RawHeader *dnsmessage.Header
 }
 
 func (r *IPRecord) getIPs() ([]net.Address, error) {
@@ -179,9 +180,10 @@ func parseResponse(payload []byte) (*IPRecord, error) {
 
 	now := time.Now()
 	ipRecord := &IPRecord{
-		ReqID:  h.ID,
-		RCode:  h.RCode,
-		Expire: now.Add(time.Second * 600),
+		ReqID:     h.ID,
+		RCode:     h.RCode,
+		Expire:    now.Add(time.Second * 600),
+		RawHeader: &h,
 	}
 
 L:

+ 5 - 3
app/dns/dnscommon_test.go

@@ -51,7 +51,7 @@ func Test_parseResponse(t *testing.T) {
 	}{
 		{
 			"empty",
-			&IPRecord{0, []net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess},
+			&IPRecord{0, []net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess, nil},
 			false,
 		},
 		{
@@ -66,12 +66,13 @@ func Test_parseResponse(t *testing.T) {
 				[]net.Address{net.ParseAddress("8.8.8.8"), net.ParseAddress("8.8.4.4")},
 				time.Time{},
 				dnsmessage.RCodeSuccess,
+				nil,
 			},
 			false,
 		},
 		{
 			"aaaa record",
-			&IPRecord{2, []net.Address{net.ParseAddress("2001::123:8888"), net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess},
+			&IPRecord{2, []net.Address{net.ParseAddress("2001::123:8888"), net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess, nil},
 			false,
 		},
 	}
@@ -84,8 +85,9 @@ func Test_parseResponse(t *testing.T) {
 			}
 
 			if got != nil {
-				// reset the time
+				// reset the time and RawHeader
 				got.Expire = time.Time{}
+				got.RawHeader = nil
 			}
 			if cmp.Diff(got, tt.want) != "" {
 				t.Error(cmp.Diff(got, tt.want))

+ 35 - 5
app/dns/nameserver_udp.go

@@ -27,7 +27,7 @@ type ClassicNameServer struct {
 	name          string
 	address       *net.Destination
 	ips           map[string]*record
-	requests      map[uint16]*dnsRequest
+	requests      map[uint16]*udpDnsRequest
 	pub           *pubsub.Service
 	udpServer     *udp.Dispatcher
 	cleanup       *task.Periodic
@@ -35,6 +35,11 @@ type ClassicNameServer struct {
 	queryStrategy QueryStrategy
 }
 
+type udpDnsRequest struct {
+	dnsRequest
+	ctx context.Context
+}
+
 // NewClassicNameServer creates udp server object for remote resolving.
 func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) *ClassicNameServer {
 	// default to 53 if unspecific
@@ -45,7 +50,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
 	s := &ClassicNameServer{
 		address:       &address,
 		ips:           make(map[string]*record),
-		requests:      make(map[uint16]*dnsRequest),
+		requests:      make(map[uint16]*udpDnsRequest),
 		pub:           pubsub.NewService(),
 		name:          strings.ToUpper(address.String()),
 		queryStrategy: queryStrategy,
@@ -101,7 +106,7 @@ func (s *ClassicNameServer) Cleanup() error {
 	}
 
 	if len(s.requests) == 0 {
-		s.requests = make(map[uint16]*dnsRequest)
+		s.requests = make(map[uint16]*udpDnsRequest)
 	}
 
 	return nil
@@ -128,6 +133,27 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 		return
 	}
 
+	// if truncated, retry with EDNS0 option(udp payload size: 1350)
+	if ipRec.RawHeader.Truncated {
+		// if already has EDNS0 option, no need to retry
+		if ok && len(req.msg.Additionals) == 0 {
+			// copy necessary meta data from original request
+			// and add EDNS0 option
+			opt := new(dnsmessage.Resource)
+			common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
+			opt.Body = &dnsmessage.OPTResource{}
+			newMsg := *req.msg
+			newReq := *req
+			newMsg.Additionals = append(newMsg.Additionals, *opt)
+			newMsg.ID = s.newReqID()
+			newReq.msg = &newMsg
+			s.addPendingRequest(&newReq)
+			b, _ := dns.PackMessage(newReq.msg)
+			s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
+			return
+		}
+	}
+
 	var rec record
 	switch req.reqType {
 	case dnsmessage.TypeA:
@@ -179,7 +205,7 @@ func (s *ClassicNameServer) newReqID() uint16 {
 	return uint16(atomic.AddUint32(&s.reqID, 1))
 }
 
-func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {
+func (s *ClassicNameServer) addPendingRequest(req *udpDnsRequest) {
 	s.Lock()
 	defer s.Unlock()
 
@@ -194,7 +220,11 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client
 	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))
 
 	for _, req := range reqs {
-		s.addPendingRequest(req)
+		udpReq := &udpDnsRequest{
+			dnsRequest: *req,
+			ctx:        ctx,
+		}
+		s.addPendingRequest(udpReq)
 		b, _ := dns.PackMessage(req.msg)
 		s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
 	}