瀏覽代碼

Rename HTTP start context

世界 1 年之前
父節點
當前提交
327bb35ddd
共有 5 個文件被更改,包括 58 次插入66 次删除
  1. 40 4
      adapter/router.go
  2. 6 3
      route/router.go
  3. 0 45
      route/rule_set.go
  4. 1 2
      route/rule_set_local.go
  5. 11 12
      route/rule_set_remote.go

+ 40 - 4
adapter/router.go

@@ -2,13 +2,17 @@ package adapter
 
 import (
 	"context"
+	"net"
 	"net/http"
 	"net/netip"
+	"sync"
 
 	"github.com/sagernet/sing-box/common/geoip"
+	C "github.com/sagernet/sing-box/constant"
 	"github.com/sagernet/sing-dns"
 	"github.com/sagernet/sing-tun"
 	"github.com/sagernet/sing/common/control"
+	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/common/x/list"
 	"github.com/sagernet/sing/service"
@@ -98,7 +102,7 @@ type DNSRule interface {
 
 type RuleSet interface {
 	Name() string
-	StartContext(ctx context.Context, startContext RuleSetStartContext) error
+	StartContext(ctx context.Context, startContext *HTTPStartContext) error
 	PostStart() error
 	Metadata() RuleSetMetadata
 	ExtractIPSet() []*netipx.IPSet
@@ -118,10 +122,42 @@ type RuleSetMetadata struct {
 	ContainsWIFIRule    bool
 	ContainsIPCIDRRule  bool
 }
+type HTTPStartContext struct {
+	access          sync.Mutex
+	httpClientCache map[string]*http.Client
+}
+
+func NewHTTPStartContext() *HTTPStartContext {
+	return &HTTPStartContext{
+		httpClientCache: make(map[string]*http.Client),
+	}
+}
+
+func (c *HTTPStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Client {
+	c.access.Lock()
+	defer c.access.Unlock()
+	if httpClient, loaded := c.httpClientCache[detour]; loaded {
+		return httpClient
+	}
+	httpClient := &http.Client{
+		Transport: &http.Transport{
+			ForceAttemptHTTP2:   true,
+			TLSHandshakeTimeout: C.TCPTimeout,
+			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+				return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
+			},
+		},
+	}
+	c.httpClientCache[detour] = httpClient
+	return httpClient
+}
 
-type RuleSetStartContext interface {
-	HTTPClient(detour string, dialer N.Dialer) *http.Client
-	Close()
+func (c *HTTPStartContext) Close() {
+	c.access.Lock()
+	defer c.access.Unlock()
+	for _, client := range c.httpClientCache {
+		client.CloseIdleConnections()
+	}
 }
 
 type InterfaceUpdateListener interface {

+ 6 - 3
route/router.go

@@ -659,14 +659,15 @@ func (r *Router) Close() error {
 
 func (r *Router) PostStart() error {
 	monitor := taskmonitor.New(r.logger, C.StopTimeout)
+	var cacheContext *adapter.HTTPStartContext
 	if len(r.ruleSets) > 0 {
 		monitor.Start("initialize rule-set")
-		ruleSetStartContext := NewRuleSetStartContext()
+		cacheContext = adapter.NewHTTPStartContext()
 		var ruleSetStartGroup task.Group
 		for i, ruleSet := range r.ruleSets {
 			ruleSetInPlace := ruleSet
 			ruleSetStartGroup.Append0(func(ctx context.Context) error {
-				err := ruleSetInPlace.StartContext(ctx, ruleSetStartContext)
+				err := ruleSetInPlace.StartContext(ctx, cacheContext)
 				if err != nil {
 					return E.Cause(err, "initialize rule-set[", i, "]")
 				}
@@ -680,7 +681,9 @@ func (r *Router) PostStart() error {
 		if err != nil {
 			return err
 		}
-		ruleSetStartContext.Close()
+	}
+	if cacheContext != nil {
+		cacheContext.Close()
 	}
 	needFindProcess := r.needFindProcess
 	needWIFIState := r.needWIFIState

+ 0 - 45
route/rule_set.go

@@ -2,9 +2,6 @@ package route
 
 import (
 	"context"
-	"net"
-	"net/http"
-	"sync"
 
 	"github.com/sagernet/sing-box/adapter"
 	C "github.com/sagernet/sing-box/constant"
@@ -12,8 +9,6 @@ import (
 	"github.com/sagernet/sing/common"
 	E "github.com/sagernet/sing/common/exceptions"
 	"github.com/sagernet/sing/common/logger"
-	M "github.com/sagernet/sing/common/metadata"
-	N "github.com/sagernet/sing/common/network"
 
 	"go4.org/netipx"
 )
@@ -46,43 +41,3 @@ func extractIPSetFromRule(rawRule adapter.HeadlessRule) []*netipx.IPSet {
 		panic("unexpected rule type")
 	}
 }
-
-var _ adapter.RuleSetStartContext = (*RuleSetStartContext)(nil)
-
-type RuleSetStartContext struct {
-	access          sync.Mutex
-	httpClientCache map[string]*http.Client
-}
-
-func NewRuleSetStartContext() *RuleSetStartContext {
-	return &RuleSetStartContext{
-		httpClientCache: make(map[string]*http.Client),
-	}
-}
-
-func (c *RuleSetStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Client {
-	c.access.Lock()
-	defer c.access.Unlock()
-	if httpClient, loaded := c.httpClientCache[detour]; loaded {
-		return httpClient
-	}
-	httpClient := &http.Client{
-		Transport: &http.Transport{
-			ForceAttemptHTTP2:   true,
-			TLSHandshakeTimeout: C.TCPTimeout,
-			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
-				return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
-			},
-		},
-	}
-	c.httpClientCache[detour] = httpClient
-	return httpClient
-}
-
-func (c *RuleSetStartContext) Close() {
-	c.access.Lock()
-	defer c.access.Unlock()
-	for _, client := range c.httpClientCache {
-		client.CloseIdleConnections()
-	}
-}

+ 1 - 2
route/rule_set_local.go

@@ -58,7 +58,6 @@ func NewLocalRuleSet(ctx context.Context, router adapter.Router, logger logger.L
 		}
 	}
 	if options.Type == C.RuleSetTypeLocal {
-		var watcher *fswatch.Watcher
 		filePath, _ := filepath.Abs(options.LocalOptions.Path)
 		watcher, err := fswatch.NewWatcher(fswatch.Options{
 			Path: []string{filePath},
@@ -85,7 +84,7 @@ func (s *LocalRuleSet) String() string {
 	return strings.Join(F.MapToString(s.rules), " ")
 }
 
-func (s *LocalRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error {
+func (s *LocalRuleSet) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
 	if s.watcher != nil {
 		err := s.watcher.Start()
 		if err != nil {

+ 11 - 12
route/rule_set_remote.go

@@ -45,6 +45,7 @@ type RemoteRuleSet struct {
 	lastUpdated    time.Time
 	lastEtag       string
 	updateTicker   *time.Ticker
+	cacheFile      adapter.CacheFile
 	pauseManager   pause.Manager
 	callbackAccess sync.Mutex
 	callbacks      list.List[adapter.RuleSetUpdateCallback]
@@ -78,7 +79,8 @@ func (s *RemoteRuleSet) String() string {
 	return strings.Join(F.MapToString(s.rules), " ")
 }
 
-func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error {
+func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
+	s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx)
 	var dialer N.Dialer
 	if s.options.RemoteOptions.DownloadDetour != "" {
 		outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour)
@@ -94,9 +96,8 @@ func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext adapter.R
 		dialer = outbound
 	}
 	s.dialer = dialer
-	cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
-	if cacheFile != nil {
-		if savedSet := cacheFile.LoadRuleSet(s.options.Tag); savedSet != nil {
+	if s.cacheFile != nil {
+		if savedSet := s.cacheFile.LoadRuleSet(s.options.Tag); savedSet != nil {
 			err := s.loadBytes(savedSet.Content)
 			if err != nil {
 				return E.Cause(err, "restore cached rule-set")
@@ -226,7 +227,7 @@ func (s *RemoteRuleSet) loopUpdate() {
 	}
 }
 
-func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.RuleSetStartContext) error {
+func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext *adapter.HTTPStartContext) error {
 	s.logger.Debug("updating rule-set ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL)
 	var httpClient *http.Client
 	if startContext != nil {
@@ -257,12 +258,11 @@ func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.Rule
 	case http.StatusOK:
 	case http.StatusNotModified:
 		s.lastUpdated = time.Now()
-		cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
-		if cacheFile != nil {
-			savedRuleSet := cacheFile.LoadRuleSet(s.options.Tag)
+		if s.cacheFile != nil {
+			savedRuleSet := s.cacheFile.LoadRuleSet(s.options.Tag)
 			if savedRuleSet != nil {
 				savedRuleSet.LastUpdated = s.lastUpdated
-				err = cacheFile.SaveRuleSet(s.options.Tag, savedRuleSet)
+				err = s.cacheFile.SaveRuleSet(s.options.Tag, savedRuleSet)
 				if err != nil {
 					s.logger.Error("save rule-set updated time: ", err)
 					return nil
@@ -290,9 +290,8 @@ func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.Rule
 		s.lastEtag = eTagHeader
 	}
 	s.lastUpdated = time.Now()
-	cacheFile := service.FromContext[adapter.CacheFile](s.ctx)
-	if cacheFile != nil {
-		err = cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedRuleSet{
+	if s.cacheFile != nil {
+		err = s.cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedRuleSet{
 			LastUpdated: s.lastUpdated,
 			Content:     content,
 			LastEtag:    s.lastEtag,