| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297 |
- //go:build linux
- package resolved
- import (
- "context"
- "net/netip"
- "os"
- "strings"
- "sync"
- "sync/atomic"
- "time"
- "github.com/sagernet/sing-box/adapter"
- "github.com/sagernet/sing-box/common/dialer"
- "github.com/sagernet/sing-box/common/tls"
- C "github.com/sagernet/sing-box/constant"
- "github.com/sagernet/sing-box/dns"
- "github.com/sagernet/sing-box/dns/transport"
- "github.com/sagernet/sing-box/log"
- "github.com/sagernet/sing-box/option"
- "github.com/sagernet/sing/common"
- E "github.com/sagernet/sing/common/exceptions"
- "github.com/sagernet/sing/common/logger"
- M "github.com/sagernet/sing/common/metadata"
- "github.com/sagernet/sing/service"
- mDNS "github.com/miekg/dns"
- )
- func RegisterTransport(registry *dns.TransportRegistry) {
- dns.RegisterTransport[option.ResolvedDNSServerOptions](registry, C.TypeResolved, NewTransport)
- }
- var _ adapter.DNSTransport = (*Transport)(nil)
- type Transport struct {
- dns.TransportAdapter
- ctx context.Context
- logger logger.ContextLogger
- serviceTag string
- acceptDefaultResolvers bool
- ndots int
- timeout time.Duration
- attempts int
- rotate bool
- service *Service
- linkAccess sync.RWMutex
- linkServers map[*TransportLink]*LinkServers
- }
- type LinkServers struct {
- Link *TransportLink
- Servers []adapter.DNSTransport
- serverOffset uint32
- }
- func (c *LinkServers) ServerOffset(rotate bool) uint32 {
- if rotate {
- return atomic.AddUint32(&c.serverOffset, 1) - 1
- }
- return 0
- }
- func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.ResolvedDNSServerOptions) (adapter.DNSTransport, error) {
- return &Transport{
- TransportAdapter: dns.NewTransportAdapter(C.DNSTypeDHCP, tag, nil),
- ctx: ctx,
- logger: logger,
- serviceTag: options.Service,
- acceptDefaultResolvers: options.AcceptDefaultResolvers,
- // ndots: options.NDots,
- // timeout: time.Duration(options.Timeout),
- // attempts: options.Attempts,
- // rotate: options.Rotate,
- ndots: 1,
- timeout: 5 * time.Second,
- attempts: 2,
- linkServers: make(map[*TransportLink]*LinkServers),
- }, nil
- }
- func (t *Transport) Start(stage adapter.StartStage) error {
- if stage != adapter.StartStateInitialize {
- return nil
- }
- serviceManager := service.FromContext[adapter.ServiceManager](t.ctx)
- service, loaded := serviceManager.Get(t.serviceTag)
- if !loaded {
- return E.New("service not found: ", t.serviceTag)
- }
- resolvedInbound, isResolved := service.(*Service)
- if !isResolved {
- return E.New("service is not resolved: ", t.serviceTag)
- }
- resolvedInbound.updateCallback = t.updateTransports
- resolvedInbound.deleteCallback = t.deleteTransport
- t.service = resolvedInbound
- return nil
- }
- func (t *Transport) Close() error {
- t.linkAccess.RLock()
- defer t.linkAccess.RUnlock()
- for _, servers := range t.linkServers {
- for _, server := range servers.Servers {
- server.Close()
- }
- }
- return nil
- }
- func (t *Transport) updateTransports(link *TransportLink) error {
- t.linkAccess.Lock()
- defer t.linkAccess.Unlock()
- if servers, loaded := t.linkServers[link]; loaded {
- for _, server := range servers.Servers {
- server.Close()
- }
- }
- serverDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{
- BindInterface: link.iif.Name,
- UDPFragmentDefault: true,
- }))
- var transports []adapter.DNSTransport
- for _, address := range link.address {
- serverAddr, ok := netip.AddrFromSlice(address.Address)
- if !ok {
- return os.ErrInvalid
- }
- if link.dnsOverTLS {
- tlsConfig := common.Must1(tls.NewClient(t.ctx, t.logger, serverAddr.String(), option.OutboundTLSOptions{
- Enabled: true,
- ServerName: serverAddr.String(),
- }))
- transports = append(transports, transport.NewTLSRaw(t.logger, t.TransportAdapter, serverDialer, M.SocksaddrFrom(serverAddr, 53), tlsConfig))
- } else {
- transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, M.SocksaddrFrom(serverAddr, 53)))
- }
- }
- for _, address := range link.addressEx {
- serverAddr, ok := netip.AddrFromSlice(address.Address)
- if !ok {
- return os.ErrInvalid
- }
- if link.dnsOverTLS {
- var serverName string
- if address.Name != "" {
- serverName = address.Name
- } else {
- serverName = serverAddr.String()
- }
- tlsConfig := common.Must1(tls.NewClient(t.ctx, t.logger, serverAddr.String(), option.OutboundTLSOptions{
- Enabled: true,
- ServerName: serverName,
- }))
- transports = append(transports, transport.NewTLSRaw(t.logger, t.TransportAdapter, serverDialer, M.SocksaddrFrom(serverAddr, address.Port), tlsConfig))
- } else {
- transports = append(transports, transport.NewUDPRaw(t.logger, t.TransportAdapter, serverDialer, M.SocksaddrFrom(serverAddr, address.Port)))
- }
- }
- t.linkServers[link] = &LinkServers{
- Link: link,
- Servers: transports,
- }
- return nil
- }
- func (t *Transport) deleteTransport(link *TransportLink) {
- t.linkAccess.Lock()
- defer t.linkAccess.Unlock()
- servers, loaded := t.linkServers[link]
- if !loaded {
- return
- }
- for _, server := range servers.Servers {
- server.Close()
- }
- delete(t.linkServers, link)
- }
- func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
- question := message.Question[0]
- var selectedLink *TransportLink
- t.service.linkAccess.RLock()
- for _, link := range t.service.links {
- for _, domain := range link.domain {
- if domain.Domain == "." && domain.RoutingOnly && !t.acceptDefaultResolvers {
- continue
- }
- if strings.HasSuffix(question.Name, domain.Domain) {
- selectedLink = link
- }
- }
- }
- if selectedLink == nil && t.acceptDefaultResolvers {
- for l := len(t.service.defaultRouteSequence); l > 0; l-- {
- selectedLink = t.service.links[t.service.defaultRouteSequence[l-1]]
- if len(selectedLink.address) > 0 || len(selectedLink.addressEx) > 0 {
- break
- }
- }
- }
- t.service.linkAccess.RUnlock()
- if selectedLink == nil {
- return dns.FixedResponseStatus(message, mDNS.RcodeNameError), nil
- }
- t.linkAccess.RLock()
- servers := t.linkServers[selectedLink]
- t.linkAccess.RUnlock()
- if len(servers.Servers) == 0 {
- return dns.FixedResponseStatus(message, mDNS.RcodeNameError), nil
- }
- if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
- return t.exchangeParallel(ctx, servers, message)
- } else {
- return t.exchangeSingleRequest(ctx, servers, message)
- }
- }
- func (t *Transport) exchangeSingleRequest(ctx context.Context, servers *LinkServers, message *mDNS.Msg) (*mDNS.Msg, error) {
- var lastErr error
- for _, fqdn := range servers.Link.nameList(t.ndots, message.Question[0].Name) {
- response, err := t.tryOneName(ctx, servers, message, fqdn)
- if err != nil {
- lastErr = err
- continue
- }
- return response, nil
- }
- return nil, lastErr
- }
- func (t *Transport) tryOneName(ctx context.Context, servers *LinkServers, message *mDNS.Msg, fqdn string) (*mDNS.Msg, error) {
- serverOffset := servers.ServerOffset(t.rotate)
- sLen := uint32(len(servers.Servers))
- var lastErr error
- for i := 0; i < t.attempts; i++ {
- for j := uint32(0); j < sLen; j++ {
- server := servers.Servers[(serverOffset+j)%sLen]
- question := message.Question[0]
- question.Name = fqdn
- exchangeMessage := *message
- exchangeMessage.Question = []mDNS.Question{question}
- exchangeCtx, cancel := context.WithTimeout(ctx, t.timeout)
- response, err := server.Exchange(exchangeCtx, &exchangeMessage)
- cancel()
- if err != nil {
- lastErr = err
- continue
- }
- return response, nil
- }
- }
- return nil, E.Cause(lastErr, fqdn)
- }
- func (t *Transport) exchangeParallel(ctx context.Context, servers *LinkServers, message *mDNS.Msg) (*mDNS.Msg, error) {
- returned := make(chan struct{})
- defer close(returned)
- type queryResult struct {
- response *mDNS.Msg
- err error
- }
- results := make(chan queryResult)
- startRacer := func(ctx context.Context, fqdn string) {
- response, err := t.tryOneName(ctx, servers, message, fqdn)
- select {
- case results <- queryResult{response, err}:
- case <-returned:
- }
- }
- queryCtx, queryCancel := context.WithCancel(ctx)
- defer queryCancel()
- var nameCount int
- for _, fqdn := range servers.Link.nameList(t.ndots, message.Question[0].Name) {
- nameCount++
- go startRacer(queryCtx, fqdn)
- }
- var errors []error
- for {
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case result := <-results:
- if result.err == nil {
- return result.response, nil
- }
- errors = append(errors, result.err)
- if len(errors) == nameCount {
- return nil, E.Errors(errors...)
- }
- }
- }
- }
|