Browse Source

Fix resolved service

世界 2 months ago
parent
commit
354ece2bdf
4 changed files with 46 additions and 26 deletions
  1. 18 9
      adapter/outbound/manager.go
  2. 4 4
      box.go
  3. 17 8
      service/resolved/resolve1.go
  4. 7 5
      service/resolved/service.go

+ 18 - 9
adapter/outbound/manager.go

@@ -30,7 +30,7 @@ type Manager struct {
 	outboundByTag           map[string]adapter.Outbound
 	dependByTag             map[string][]string
 	defaultOutbound         adapter.Outbound
-	defaultOutboundFallback adapter.Outbound
+	defaultOutboundFallback func() (adapter.Outbound, error)
 }
 
 func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, endpoint adapter.EndpointManager, defaultTag string) *Manager {
@@ -44,7 +44,7 @@ func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry,
 	}
 }
 
-func (m *Manager) Initialize(defaultOutboundFallback adapter.Outbound) {
+func (m *Manager) Initialize(defaultOutboundFallback func() (adapter.Outbound, error)) {
 	m.defaultOutboundFallback = defaultOutboundFallback
 }
 
@@ -55,18 +55,31 @@ func (m *Manager) Start(stage adapter.StartStage) error {
 	}
 	m.started = true
 	m.stage = stage
-	outbounds := m.outbounds
-	m.access.Unlock()
 	if stage == adapter.StartStateStart {
+		if m.defaultOutbound == nil {
+			directOutbound, err := m.defaultOutboundFallback()
+			if err != nil {
+				m.access.Unlock()
+				return E.Cause(err, "create direct outbound for fallback")
+			}
+			m.outbounds = append(m.outbounds, directOutbound)
+			m.outboundByTag[directOutbound.Tag()] = directOutbound
+			m.defaultOutbound = directOutbound
+		}
 		if m.defaultTag != "" && m.defaultOutbound == nil {
 			defaultEndpoint, loaded := m.endpoint.Get(m.defaultTag)
 			if !loaded {
+				m.access.Unlock()
 				return E.New("default outbound not found: ", m.defaultTag)
 			}
 			m.defaultOutbound = defaultEndpoint
 		}
+		outbounds := m.outbounds
+		m.access.Unlock()
 		return m.startOutbounds(append(outbounds, common.Map(m.endpoint.Endpoints(), func(it adapter.Endpoint) adapter.Outbound { return it })...))
 	} else {
+		outbounds := m.outbounds
+		m.access.Unlock()
 		for _, outbound := range outbounds {
 			err := adapter.LegacyStart(outbound, stage)
 			if err != nil {
@@ -187,11 +200,7 @@ func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
 func (m *Manager) Default() adapter.Outbound {
 	m.access.RLock()
 	defer m.access.RUnlock()
-	if m.defaultOutbound != nil {
-		return m.defaultOutbound
-	} else {
-		return m.defaultOutboundFallback
-	}
+	return m.defaultOutbound
 }
 
 func (m *Manager) Remove(tag string) error {

+ 4 - 4
box.go

@@ -314,15 +314,15 @@ func New(options Options) (*Box, error) {
 			return nil, E.Cause(err, "initialize service[", i, "]")
 		}
 	}
-	outboundManager.Initialize(common.Must1(
-		direct.NewOutbound(
+	outboundManager.Initialize(func() (adapter.Outbound, error) {
+		return direct.NewOutbound(
 			ctx,
 			router,
 			logFactory.NewLogger("outbound/direct"),
 			"direct",
 			option.DirectOutboundOptions{},
-		),
-	))
+		)
+	})
 	dnsTransportManager.Initialize(common.Must1(
 		local.NewTransport(
 			ctx,

+ 17 - 8
service/resolved/resolve1.go

@@ -182,9 +182,9 @@ func (t *resolve1Manager) logRequest(sender dbus.Sender, message ...any) context
 		} else if metadata.ProcessInfo.UserId != 0 {
 			prefix = F.ToString("uid:", metadata.ProcessInfo.UserId)
 		}
-		t.logger.InfoContext(ctx, "(", prefix, ") ", F.ToString(message...))
+		t.logger.InfoContext(ctx, "(", prefix, ") ", strings.Join(F.MapToString(message), " "))
 	} else {
-		t.logger.InfoContext(ctx, F.ToString(message...))
+		t.logger.InfoContext(ctx, strings.Join(F.MapToString(message), " "))
 	}
 	return adapter.WithContext(ctx, &metadata)
 }
@@ -280,7 +280,10 @@ func (t *resolve1Manager) ResolveAddress(sender dbus.Sender, ifIndex int32, fami
 		},
 	}
 	ctx := t.logRequest(sender, "ResolveAddress ", link.iif.Name, familyToString(family), addr, flags)
-	response, lookupErr := t.dnsRouter.Exchange(ctx, request, adapter.DNSQueryOptions{})
+	var metadata adapter.InboundContext
+	metadata.InboundType = t.Type()
+	metadata.Inbound = t.Tag()
+	response, lookupErr := t.dnsRouter.Exchange(adapter.WithContext(ctx, &metadata), request, adapter.DNSQueryOptions{})
 	if lookupErr != nil {
 		err = wrapError(err)
 		return
@@ -301,7 +304,7 @@ func (t *resolve1Manager) ResolveAddress(sender dbus.Sender, ifIndex int32, fami
 	return
 }
 
-func (t *resolve1Manager) ResolveRecord(sender dbus.Sender, ifIndex int32, family int32, hostname string, qClass uint16, qType uint16, flags uint64) (records []ResourceRecord, outflags uint64, err *dbus.Error) {
+func (t *resolve1Manager) ResolveRecord(sender dbus.Sender, ifIndex int32, hostname string, qClass uint16, qType uint16, flags uint64) (records []ResourceRecord, outflags uint64, err *dbus.Error) {
 	t.linkAccess.Lock()
 	link, err := t.getLink(ifIndex)
 	if err != nil {
@@ -320,8 +323,11 @@ func (t *resolve1Manager) ResolveRecord(sender dbus.Sender, ifIndex int32, famil
 			},
 		},
 	}
-	ctx := t.logRequest(sender, "ResolveRecord ", link.iif.Name, familyToString(family), hostname, mDNS.Class(qClass), mDNS.Type(qType), flags)
-	response, exchangeErr := t.dnsRouter.Exchange(ctx, request, adapter.DNSQueryOptions{})
+	ctx := t.logRequest(sender, "ResolveRecord", link.iif.Name, hostname, mDNS.Class(qClass), mDNS.Type(qType), flags)
+	var metadata adapter.InboundContext
+	metadata.InboundType = t.Type()
+	metadata.Inbound = t.Tag()
+	response, exchangeErr := t.dnsRouter.Exchange(adapter.WithContext(ctx, &metadata), request, adapter.DNSQueryOptions{})
 	if exchangeErr != nil {
 		err = wrapError(exchangeErr)
 		return
@@ -341,6 +347,7 @@ func (t *resolve1Manager) ResolveRecord(sender dbus.Sender, ifIndex int32, famil
 			err = wrapError(unpackErr)
 		}
 		record.Data = data
+		records = append(records, record)
 	}
 	return
 }
@@ -380,8 +387,10 @@ func (t *resolve1Manager) ResolveService(sender dbus.Sender, ifIndex int32, host
 			},
 		},
 	}
-
-	srvResponse, exchangeErr := t.dnsRouter.Exchange(ctx, srvRequest, adapter.DNSQueryOptions{})
+	var metadata adapter.InboundContext
+	metadata.InboundType = t.Type()
+	metadata.Inbound = t.Tag()
+	srvResponse, exchangeErr := t.dnsRouter.Exchange(adapter.WithContext(ctx, &metadata), srvRequest, adapter.DNSQueryOptions{})
 	if exchangeErr != nil {
 		err = wrapError(exchangeErr)
 		return

+ 7 - 5
service/resolved/service.go

@@ -91,11 +91,6 @@ func (i *Service) Start(stage adapter.StartStage) error {
 				return E.New("multiple resolved service are not supported")
 			}
 		}
-	case adapter.StartStateStart:
-		err := i.listener.Start()
-		if err != nil {
-			return err
-		}
 		systemBus, err := dbus.SystemBus()
 		if err != nil {
 			return err
@@ -117,6 +112,11 @@ func (i *Service) Start(stage adapter.StartStage) error {
 			return E.New("unknown request name reply: ", reply)
 		}
 		i.networkUpdateCallback = i.network.NetworkMonitor().RegisterCallback(i.onNetworkUpdate)
+	case adapter.StartStateStart:
+		err := i.listener.Start()
+		if err != nil {
+			return err
+		}
 	}
 	return nil
 }
@@ -167,6 +167,8 @@ func (i *Service) exchangePacket0(ctx context.Context, buffer *buf.Buffer, oob [
 	}
 	var metadata adapter.InboundContext
 	metadata.Source = source
+	metadata.InboundType = i.Type()
+	metadata.Inbound = i.Tag()
 	response, err := i.dnsRouter.Exchange(adapter.WithContext(ctx, &metadata), &message, adapter.DNSQueryOptions{})
 	if err != nil {
 		return err