소스 검색

Add skipFakeDNS to inbound session

世界 4 년 전
부모
커밋
14aa152a8a
3개의 변경된 파일19개의 추가작업 그리고 4개의 파일을 삭제
  1. 2 0
      common/session/session.go
  2. 7 0
      features/routing/session/context.go
  3. 10 4
      proxy/dns/dns.go

+ 2 - 0
common/session/session.go

@@ -53,6 +53,8 @@ type Inbound struct {
 	Uid uint32
 	// SagerNet private: AppStatus is the android app's status for the inbound connection
 	AppStatus []string
+	// SagerNet private
+	SkipFakeDNS bool
 }
 
 // Outbound is the metadata of an outbound connection.

+ 7 - 0
features/routing/session/context.go

@@ -137,6 +137,13 @@ func (ctx *Context) GetAppStatus() []string {
 	return ctx.Inbound.AppStatus
 }
 
+func (ctx Context) GetSkipFakeDNS() bool {
+	if ctx.Inbound == nil {
+		return false
+	}
+	return ctx.Inbound.SkipFakeDNS
+}
+
 // AsRoutingContext creates a context from context.context with session info.
 func AsRoutingContext(ctx context.Context) routing.Context {
 	return &Context{

+ 10 - 4
proxy/dns/dns.go

@@ -96,6 +96,12 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 		return newError("invalid outbound")
 	}
 
+	fakeDNS := true
+	inbound := session.InboundFromContext(ctx)
+	if inbound != nil && inbound.SkipFakeDNS {
+		fakeDNS = false
+	}
+
 	srcNetwork := outbound.Target.Network
 
 	dest := outbound.Target
@@ -171,7 +177,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 			if !h.isOwnLink(ctx) {
 				isIPQuery, domain, id, qType := parseIPQuery(b.Bytes())
 				if isIPQuery {
-					go h.handleIPQuery(id, qType, domain, writer)
+					go h.handleIPQuery(id, qType, domain, writer, fakeDNS)
 					continue
 				}
 			}
@@ -208,7 +214,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
 	return nil
 }
 
-func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) {
+func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter, fakedns bool) {
 	var ips []net.IP
 	var err error
 
@@ -219,13 +225,13 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
 		ips, err = h.client.LookupIP(domain, dns.IPOption{
 			IPv4Enable: true,
 			IPv6Enable: false,
-			FakeEnable: true,
+			FakeEnable: fakedns,
 		})
 	case dnsmessage.TypeAAAA:
 		ips, err = h.client.LookupIP(domain, dns.IPOption{
 			IPv4Enable: false,
 			IPv6Enable: true,
-			FakeEnable: true,
+			FakeEnable: fakedns,
 		})
 	}