Browse Source

Add hijack_dns for tun

世界 3 years ago
parent
commit
29f78248dc

+ 6 - 27
adapter/upstream.go

@@ -51,27 +51,6 @@ func (w *myUpstreamHandlerWrapper) NewError(ctx context.Context, err error) {
 	w.errorHandler.NewError(ctx, err)
 }
 
-var myContextType = (*MetadataContext)(nil)
-
-type MetadataContext struct {
-	context.Context
-	Metadata InboundContext
-}
-
-func (c *MetadataContext) Value(key any) any {
-	if key == myContextType {
-		return c
-	}
-	return c.Context.Value(key)
-}
-
-func ContextWithMetadata(ctx context.Context, metadata InboundContext) context.Context {
-	return &MetadataContext{
-		Context:  ctx,
-		Metadata: metadata,
-	}
-}
-
 func UpstreamMetadata(metadata InboundContext) M.Metadata {
 	return M.Metadata{
 		Source:      metadata.Source,
@@ -98,15 +77,15 @@ func NewUpstreamContextHandler(
 }
 
 func (w *myUpstreamContextHandlerWrapper) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
-	myCtx := ctx.Value(myContextType).(*MetadataContext)
-	myCtx.Metadata.Destination = metadata.Destination
-	return w.connectionHandler(ctx, conn, myCtx.Metadata)
+	myMetadata := ContextFrom(ctx)
+	myMetadata.Destination = metadata.Destination
+	return w.connectionHandler(ctx, conn, *myMetadata)
 }
 
 func (w *myUpstreamContextHandlerWrapper) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
-	myCtx := ctx.Value(myContextType).(*MetadataContext)
-	myCtx.Metadata.Destination = metadata.Destination
-	return w.packetHandler(ctx, conn, myCtx.Metadata)
+	myMetadata := ContextFrom(ctx)
+	myMetadata.Destination = metadata.Destination
+	return w.packetHandler(ctx, conn, *myMetadata)
 }
 
 func (w *myUpstreamContextHandlerWrapper) NewError(ctx context.Context, err error) {

+ 1 - 1
common/dialer/dialer.go

@@ -21,7 +21,7 @@ func New(router adapter.Router, options option.DialerOptions) N.Dialer {
 func NewOutbound(router adapter.Router, options option.OutboundDialerOptions) N.Dialer {
 	dialer := New(router, options.DialerOptions)
 	domainStrategy := C.DomainStrategy(options.DomainStrategy)
-	if domainStrategy != C.DomainStrategyAsIS || options.Detour == "" && !C.CGO_ENABLED {
+	if domainStrategy != C.DomainStrategyAsIS || options.Detour == "" {
 		fallbackDelay := time.Duration(options.FallbackDelay)
 		if fallbackDelay == 0 {
 			fallbackDelay = time.Millisecond * 300

+ 6 - 0
common/dialer/resolve.go

@@ -32,6 +32,9 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina
 	if !destination.IsFqdn() {
 		return d.dialer.DialContext(ctx, network, destination)
 	}
+	ctx, metadata := adapter.AppendContext(ctx)
+	metadata.Destination = destination
+	metadata.Domain = ""
 	var addresses []netip.Addr
 	var err error
 	if d.strategy == C.DomainStrategyAsIS {
@@ -49,6 +52,9 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
 	if !destination.IsFqdn() {
 		return d.dialer.ListenPacket(ctx, destination)
 	}
+	ctx, metadata := adapter.AppendContext(ctx)
+	metadata.Destination = destination
+	metadata.Domain = ""
 	var addresses []netip.Addr
 	var err error
 	if d.strategy == C.DomainStrategyAsIS {

+ 3 - 0
dns/transport.go

@@ -20,6 +20,9 @@ func NewTransport(ctx context.Context, dialer N.Dialer, logger log.Logger, addre
 		return nil, err
 	}
 	host := serverURL.Hostname()
+	if host == "" {
+		host = address
+	}
 	port := serverURL.Port()
 	switch serverURL.Scheme {
 	case "tls":

+ 1 - 1
dns/transport_tcp.go

@@ -87,7 +87,7 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) {
 		}
 	})
 	conn.err = err
-	if err != nil {
+	if err != nil && !E.IsClosed(err) {
 		t.logger.Debug("connection closed: ", err)
 	}
 }

+ 1 - 1
dns/transport_tls.go

@@ -95,7 +95,7 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) {
 		}
 	})
 	conn.err = err
-	if err != nil {
+	if err != nil && !E.IsClosed(err) {
 		t.logger.Debug("connection closed: ", err)
 	}
 }

+ 2 - 1
dns/transport_udp.go

@@ -8,6 +8,7 @@ import (
 	"github.com/sagernet/sing-box/log"
 	"github.com/sagernet/sing/common"
 	"github.com/sagernet/sing/common/buf"
+	E "github.com/sagernet/sing/common/exceptions"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/common/task"
@@ -83,7 +84,7 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) {
 		}
 	})
 	conn.err = err
-	if err != nil {
+	if err != nil && !E.IsClosed(err) {
 		t.logger.Debug("connection closed: ", err)
 	}
 }

+ 1 - 1
inbound/direct.go

@@ -79,6 +79,6 @@ func (d *Direct) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.B
 	case 3:
 		metadata.Destination.Port = d.overrideDestination.Port
 	}
-	d.udpNat.NewPacketDirect(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), metadata.Source.AddrPort(), conn, buffer, adapter.UpstreamMetadata(metadata))
+	d.udpNat.NewPacketDirect(adapter.WithContext(log.ContextWithID(ctx), &metadata), metadata.Source.AddrPort(), conn, buffer, adapter.UpstreamMetadata(metadata))
 	return nil
 }

+ 107 - 0
inbound/dns.go

@@ -0,0 +1,107 @@
+package inbound
+
+import (
+	"context"
+	"encoding/binary"
+	"io"
+	"net"
+
+	"github.com/sagernet/sing-box/adapter"
+	"github.com/sagernet/sing-box/log"
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
+	N "github.com/sagernet/sing/common/network"
+
+	"golang.org/x/net/dns/dnsmessage"
+)
+
+func NewDNSConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn net.Conn, metadata adapter.InboundContext) error {
+	_buffer := buf.StackNewSize(1024)
+	defer common.KeepAlive(_buffer)
+	buffer := common.Dup(_buffer)
+	defer buffer.Release()
+	for {
+		var queryLength uint16
+		err := binary.Read(conn, binary.BigEndian, &queryLength)
+		if err != nil {
+			return err
+		}
+		if queryLength > 1024 {
+			return io.ErrShortBuffer
+		}
+		buffer.FullReset()
+		_, err = buffer.ReadFullFrom(conn, int(queryLength))
+		if err != nil {
+			return err
+		}
+		var message dnsmessage.Message
+		err = message.Unpack(buffer.Bytes())
+		if err != nil {
+			return err
+		}
+		if len(message.Questions) > 0 {
+			question := message.Questions[0]
+			metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
+			logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
+		}
+		response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
+		if err != nil {
+			return err
+		}
+		buffer.FullReset()
+		responseBuffer, err := response.AppendPack(buffer.Index(0))
+		if err != nil {
+			return err
+		}
+		err = binary.Write(conn, binary.BigEndian, uint16(len(responseBuffer)))
+		if err != nil {
+			return err
+		}
+		_, err = conn.Write(responseBuffer)
+		if err != nil {
+			return err
+		}
+	}
+}
+
+func NewDNSPacketConnection(ctx context.Context, router adapter.Router, logger log.Logger, conn N.PacketConn, metadata adapter.InboundContext) error {
+	for {
+		buffer := buf.StackNewSize(1024)
+		destination, err := conn.ReadPacket(buffer)
+		if err != nil {
+			buffer.Release()
+			return err
+		}
+		var message dnsmessage.Message
+		err = message.Unpack(buffer.Bytes())
+		if err != nil {
+			return err
+		}
+		if len(message.Questions) > 0 {
+			question := message.Questions[0]
+			metadata.Domain = string(question.Name.Data[:question.Name.Length-1])
+			logger.WithContext(ctx).Debug("inbound dns query ", formatDNSQuestion(question), " from ", metadata.Source)
+		}
+		go func() error {
+			defer buffer.Release()
+			response, err := router.Exchange(adapter.WithContext(ctx, &metadata), &message)
+			if err != nil {
+				return err
+			}
+			buffer.FullReset()
+			responseBuffer, err := response.AppendPack(buffer.Index(0))
+			if err != nil {
+				return err
+			}
+			buffer.Truncate(len(responseBuffer))
+			err = conn.WritePacket(buffer, destination)
+			return err
+		}()
+	}
+}
+
+func formatDNSQuestion(question dnsmessage.Question) string {
+	domain := question.Name.String()
+	domain = domain[:len(domain)-1]
+	return string(question.Name.Data[:question.Name.Length-1]) + " " + question.Type.String()[4:] + " " + question.Class.String()[5:]
+}

+ 2 - 2
inbound/shadowsocks.go

@@ -73,9 +73,9 @@ func newShadowsocks(ctx context.Context, router adapter.Router, logger log.Logge
 }
 
 func (h *Shadowsocks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
-	return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
+	return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
 }
 
 func (h *Shadowsocks) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
-	return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
+	return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
 }

+ 2 - 2
inbound/shadowsocks_multi.go

@@ -68,11 +68,11 @@ func newShadowsocksMulti(ctx context.Context, router adapter.Router, logger log.
 }
 
 func (h *ShadowsocksMulti) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
-	return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
+	return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
 }
 
 func (h *ShadowsocksMulti) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
-	return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
+	return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
 }
 
 func (h *ShadowsocksMulti) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {

+ 2 - 2
inbound/shadowsocks_relay.go

@@ -68,11 +68,11 @@ func newShadowsocksRelay(ctx context.Context, router adapter.Router, logger log.
 }
 
 func (h *ShadowsocksRelay) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
-	return h.service.NewConnection(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, adapter.UpstreamMetadata(metadata))
+	return h.service.NewConnection(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata))
 }
 
 func (h *ShadowsocksRelay) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata adapter.InboundContext) error {
-	return h.service.NewPacket(adapter.ContextWithMetadata(log.ContextWithID(ctx), metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
+	return h.service.NewPacket(adapter.WithContext(log.ContextWithID(ctx), &metadata), conn, buffer, adapter.UpstreamMetadata(metadata))
 }
 
 func (h *ShadowsocksRelay) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {

+ 58 - 40
inbound/tun.go

@@ -20,6 +20,7 @@ import (
 	F "github.com/sagernet/sing/common/format"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/common/task"
 )
 
 var _ adapter.Inbound = (*Tun)(nil)
@@ -27,23 +28,42 @@ var _ adapter.Inbound = (*Tun)(nil)
 type Tun struct {
 	tag string
 
-	ctx     context.Context
-	router  adapter.Router
-	logger  log.Logger
-	options option.TunInboundOptions
-
-	tunName string
-	tunFd   uintptr
-	tun     *tun.GVisorTun
+	ctx            context.Context
+	router         adapter.Router
+	logger         log.Logger
+	inboundOptions option.InboundOptions
+	tunName        string
+	tunMTU         uint32
+	inet4Address   netip.Prefix
+	inet6Address   netip.Prefix
+	autoRoute      bool
+	hijackDNS      bool
+
+	tunFd uintptr
+	tun   *tun.GVisorTun
 }
 
 func NewTun(ctx context.Context, router adapter.Router, logger log.Logger, tag string, options option.TunInboundOptions) (*Tun, error) {
+	tunName := options.InterfaceName
+	if tunName == "" {
+		tunName = mkInterfaceName()
+	}
+	tunMTU := options.MTU
+	if tunMTU == 0 {
+		tunMTU = 1500
+	}
 	return &Tun{
-		tag:     tag,
-		ctx:     ctx,
-		router:  router,
-		logger:  logger,
-		options: options,
+		tag:            tag,
+		ctx:            ctx,
+		router:         router,
+		logger:         logger,
+		inboundOptions: options.InboundOptions,
+		tunName:        tunName,
+		tunMTU:         tunMTU,
+		inet4Address:   netip.Prefix(options.Inet4Address),
+		inet6Address:   netip.Prefix(options.Inet6Address),
+		autoRoute:      options.AutoRoute,
+		hijackDNS:      options.HijackDNS,
 	}, nil
 }
 
@@ -56,38 +76,26 @@ func (t *Tun) Tag() string {
 }
 
 func (t *Tun) Start() error {
-	tunName := t.options.InterfaceName
-	if tunName == "" {
-		tunName = mkInterfaceName()
-	}
-	var mtu uint32
-	if t.options.MTU != 0 {
-		mtu = t.options.MTU
-	} else {
-		mtu = 1500
-	}
-
-	tunFd, err := tun.Open(tunName)
+	tunFd, err := tun.Open(t.tunName)
 	if err != nil {
 		return E.Cause(err, "create tun interface")
 	}
-	err = tun.Configure(tunName, netip.Prefix(t.options.Inet4Address), netip.Prefix(t.options.Inet6Address), mtu, t.options.AutoRoute)
+	err = tun.Configure(t.tunName, t.inet4Address, t.inet6Address, t.tunMTU, t.autoRoute)
 	if err != nil {
 		return E.Cause(err, "configure tun interface")
 	}
-	t.tunName = tunName
 	t.tunFd = tunFd
-	t.tun = tun.NewGVisor(t.ctx, tunFd, mtu, t)
+	t.tun = tun.NewGVisor(t.ctx, tunFd, t.tunMTU, t)
 	err = t.tun.Start()
 	if err != nil {
 		return err
 	}
-	t.logger.Info("started at ", tunName)
+	t.logger.Info("started at ", t.tunName)
 	return nil
 }
 
 func (t *Tun) Close() error {
-	err := tun.UnConfigure(t.tunName, netip.Prefix(t.options.Inet4Address), netip.Prefix(t.options.Inet6Address), t.options.AutoRoute)
+	err := tun.UnConfigure(t.tunName, t.inet4Address, t.inet6Address, t.autoRoute)
 	if err != nil {
 		return err
 	}
@@ -98,30 +106,40 @@ func (t *Tun) Close() error {
 }
 
 func (t *Tun) NewConnection(ctx context.Context, conn net.Conn, upstreamMetadata M.Metadata) error {
-	t.logger.WithContext(ctx).Info("inbound connection from ", upstreamMetadata.Source)
-	t.logger.WithContext(ctx).Info("inbound connection to ", upstreamMetadata.Destination)
 	var metadata adapter.InboundContext
 	metadata.Inbound = t.tag
 	metadata.Network = C.NetworkTCP
 	metadata.Source = upstreamMetadata.Source
 	metadata.Destination = upstreamMetadata.Destination
-	metadata.SniffEnabled = t.options.SniffEnabled
-	metadata.SniffOverrideDestination = t.options.SniffOverrideDestination
-	metadata.DomainStrategy = C.DomainStrategy(t.options.DomainStrategy)
+	metadata.SniffEnabled = t.inboundOptions.SniffEnabled
+	metadata.SniffOverrideDestination = t.inboundOptions.SniffOverrideDestination
+	metadata.DomainStrategy = C.DomainStrategy(t.inboundOptions.DomainStrategy)
+	if t.hijackDNS && upstreamMetadata.Destination.Port == 53 {
+		return task.Run(ctx, func() error {
+			return NewDNSConnection(ctx, t.router, t.logger, conn, metadata)
+		})
+	}
+	t.logger.WithContext(ctx).Info("inbound connection from ", metadata.Source)
+	t.logger.WithContext(ctx).Info("inbound connection to ", metadata.Destination)
 	return t.router.RouteConnection(ctx, conn, metadata)
 }
 
 func (t *Tun) NewPacketConnection(ctx context.Context, conn N.PacketConn, upstreamMetadata M.Metadata) error {
-	t.logger.WithContext(ctx).Info("inbound packet connection from ", upstreamMetadata.Source)
-	t.logger.WithContext(ctx).Info("inbound packet connection to ", upstreamMetadata.Destination)
 	var metadata adapter.InboundContext
 	metadata.Inbound = t.tag
 	metadata.Network = C.NetworkUDP
 	metadata.Source = upstreamMetadata.Source
 	metadata.Destination = upstreamMetadata.Destination
-	metadata.SniffEnabled = t.options.SniffEnabled
-	metadata.SniffOverrideDestination = t.options.SniffOverrideDestination
-	metadata.DomainStrategy = C.DomainStrategy(t.options.DomainStrategy)
+	metadata.SniffEnabled = t.inboundOptions.SniffEnabled
+	metadata.SniffOverrideDestination = t.inboundOptions.SniffOverrideDestination
+	metadata.DomainStrategy = C.DomainStrategy(t.inboundOptions.DomainStrategy)
+	if t.hijackDNS && upstreamMetadata.Destination.Port == 53 {
+		return task.Run(ctx, func() error {
+			return NewDNSPacketConnection(ctx, t.router, t.logger, conn, metadata)
+		})
+	}
+	t.logger.WithContext(ctx).Info("inbound packet connection from ", metadata.Source)
+	t.logger.WithContext(ctx).Info("inbound packet connection to ", metadata.Destination)
 	return t.router.RoutePacketConnection(ctx, conn, metadata)
 }
 

+ 6 - 5
option/inbound.go

@@ -144,10 +144,11 @@ type ShadowsocksDestination struct {
 }
 
 type TunInboundOptions struct {
-	InterfaceName string       `json:"interface_name"`
-	MTU           uint32       `json:"mtu,omitempty"`
-	Inet4Address  ListenPrefix `json:"inet4_address"`
-	Inet6Address  ListenPrefix `json:"inet6_address"`
-	AutoRoute     bool         `json:"auto_route"`
+	InterfaceName string       `json:"interface_name,omitempty"`
+	MTU           uint32       `json:"mtu,omitempty,omitempty"`
+	Inet4Address  ListenPrefix `json:"inet4_address,omitempty"`
+	Inet6Address  ListenPrefix `json:"inet6_address,omitempty"`
+	AutoRoute     bool         `json:"auto_route,omitempty"`
+	HijackDNS     bool         `json:"hijack_dns,omitempty"`
 	InboundOptions
 }

+ 23 - 17
route/router.go

@@ -9,6 +9,7 @@ import (
 	"net/url"
 	"os"
 	"path/filepath"
+	"reflect"
 	"strings"
 	"time"
 
@@ -128,23 +129,28 @@ func NewRouter(ctx context.Context, logger log.Logger, options option.RouteOptio
 			} else {
 				detour = dialer.NewDetour(router, server.Detour)
 			}
-			serverURL, err := url.Parse(server.Address)
-			if err != nil {
-				return nil, err
-			}
-			serverAddress := serverURL.Hostname()
-			_, notIpAddress := netip.ParseAddr(serverAddress)
-			if server.AddressResolver != "" {
-				if !transportTagMap[server.AddressResolver] {
-					return nil, E.New("parse dns server[", tag, "]: address resolver not found: ", server.AddressResolver)
+			if server.Address != "local" {
+				serverURL, err := url.Parse(server.Address)
+				if err != nil {
+					return nil, err
 				}
-				if upstream, exists := dummyTransportMap[server.AddressResolver]; exists {
-					detour = dns.NewDialerWrapper(detour, C.DomainStrategy(server.AddressStrategy), router.dnsClient, upstream)
-				} else {
-					continue
+				serverAddress := serverURL.Hostname()
+				if serverAddress == "" {
+					serverAddress = server.Address
+				}
+				_, notIpAddress := netip.ParseAddr(serverAddress)
+				if server.AddressResolver != "" {
+					if !transportTagMap[server.AddressResolver] {
+						return nil, E.New("parse dns server[", tag, "]: address resolver not found: ", server.AddressResolver)
+					}
+					if upstream, exists := dummyTransportMap[server.AddressResolver]; exists {
+						detour = dns.NewDialerWrapper(detour, C.DomainStrategy(server.AddressStrategy), router.dnsClient, upstream)
+					} else {
+						continue
+					}
+				} else if notIpAddress != nil {
+					return nil, E.New("parse dns server[", tag, "]: missing address_resolver")
 				}
-			} else if notIpAddress != nil {
-				return nil, E.New("parse dns server[", tag, "]: missing address_resolver")
 			}
 			transport, err := dns.NewTransport(ctx, detour, logger, server.Address)
 			if err != nil {
@@ -419,7 +425,7 @@ func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata ad
 }
 
 func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
-	if metadata.SniffEnabled && metadata.Destination.Port == 443 {
+	if metadata.SniffEnabled {
 		_buffer := buf.StackNewPacket()
 		defer common.KeepAlive(_buffer)
 		buffer := common.Dup(_buffer)
@@ -489,7 +495,7 @@ func (r *Router) match(ctx context.Context, metadata adapter.InboundContext, def
 func (r *Router) matchDNS(ctx context.Context) adapter.DNSTransport {
 	metadata := adapter.ContextFrom(ctx)
 	if metadata == nil {
-		r.dnsLogger.WithContext(ctx).Warn("no context")
+		r.dnsLogger.WithContext(ctx).Warn("no context: ", reflect.TypeOf(ctx))
 		return r.defaultTransport
 	}
 	for i, rule := range r.dnsRules {