Browse Source

Fix route

世界 3 years ago
parent
commit
43cf0441db
5 changed files with 69 additions and 27 deletions
  1. 1 1
      go.mod
  2. 2 2
      go.sum
  3. 14 8
      route/router.go
  4. 41 13
      route/rule.go
  5. 11 3
      service.go

+ 1 - 1
go.mod

@@ -7,7 +7,7 @@ require (
 	github.com/goccy/go-json v0.9.8
 	github.com/logrusorgru/aurora v2.0.3+incompatible
 	github.com/oschwald/geoip2-golang v1.7.0
-	github.com/sagernet/sing v0.0.0-20220703122912-677c52f01aba
+	github.com/sagernet/sing v0.0.0-20220704113227-8b990551511a
 	github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649
 	github.com/sirupsen/logrus v1.8.1
 	github.com/spf13/cobra v1.5.0

+ 2 - 2
go.sum

@@ -20,8 +20,8 @@ github.com/oschwald/maxminddb-golang v1.9.0/go.mod h1:TK+s/Z2oZq0rSl4PSeAEoP0bgm
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
-github.com/sagernet/sing v0.0.0-20220703122912-677c52f01aba h1:ffb+Es7ddyDDOYUXKoJz5vpA+9C80GK7f7sjYN9rFvY=
-github.com/sagernet/sing v0.0.0-20220703122912-677c52f01aba/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c=
+github.com/sagernet/sing v0.0.0-20220704113227-8b990551511a h1:IvYjuvuPNmZzQfBbCxE/uQqGkNWUa5/KrEMIecRMjZk=
+github.com/sagernet/sing v0.0.0-20220704113227-8b990551511a/go.mod h1:3ZmoGNg/nNJTyHAZFNRSPaXpNIwpDvyIiAUd0KIWV5c=
 github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649 h1:whNDUGOAX5GPZkSy4G3Gv9QyIgk5SXRyjkRuP7ohF8k=
 github.com/sagernet/sing-shadowsocks v0.0.0-20220701084835-2208da1d8649/go.mod h1:MuyT+9fEPjvauAv0fSE0a6Q+l0Tv2ZrAafTkYfnxBFw=
 github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=

+ 14 - 8
route/router.go

@@ -19,6 +19,7 @@ import (
 	F "github.com/sagernet/sing/common/format"
 	M "github.com/sagernet/sing/common/metadata"
 	N "github.com/sagernet/sing/common/network"
+	"github.com/sagernet/sing/common/rw"
 )
 
 var _ adapter.Router = (*Router)(nil)
@@ -82,7 +83,7 @@ func isGeoRule(rule option.DefaultRule) bool {
 }
 
 func notPrivateNode(code string) bool {
-	return code == "private"
+	return code != "private"
 }
 
 func (r *Router) Initialize(outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error {
@@ -156,7 +157,10 @@ func (r *Router) Initialize(outbounds []adapter.Outbound, defaultOutbound func()
 
 func (r *Router) Start() error {
 	if r.needGeoDatabase {
-		go r.prepareGeoIPDatabase()
+		err := r.prepareGeoIPDatabase()
+		if err != nil {
+			return err
+		}
 	}
 	return nil
 }
@@ -171,15 +175,17 @@ func (r *Router) GeoIPReader() *geoip2.Reader {
 	return r.geoReader
 }
 
-func (r *Router) prepareGeoIPDatabase() {
+func (r *Router) prepareGeoIPDatabase() error {
 	var geoPath string
 	if r.geoOptions.Path != "" {
 		geoPath = r.geoOptions.Path
 	} else {
 		geoPath = "Country.mmdb"
+		if foundPath, loaded := C.Find(geoPath); loaded {
+			geoPath = foundPath
+		}
 	}
-	geoPath, loaded := C.Find(geoPath)
-	if !loaded {
+	if !rw.FileExists(geoPath) {
 		r.logger.Warn("geoip database not exists: ", geoPath)
 		var err error
 		for attempts := 0; attempts < 3; attempts++ {
@@ -192,7 +198,7 @@ func (r *Router) prepareGeoIPDatabase() {
 			time.Sleep(10 * time.Second)
 		}
 		if err != nil {
-			return
+			return err
 		}
 	}
 	geoReader, err := geoip2.Open(geoPath)
@@ -200,9 +206,9 @@ func (r *Router) prepareGeoIPDatabase() {
 		r.logger.Info("loaded geoip database")
 		r.geoReader = geoReader
 	} else {
-		r.logger.Error("open geoip database: ", err)
-		return
+		return E.Cause(err, "open geoip database")
 	}
+	return nil
 }
 
 func (r *Router) downloadGeoIPDatabase(savePath string) error {

+ 41 - 13
route/rule.go

@@ -41,9 +41,10 @@ func NewRule(router adapter.Router, logger log.Logger, options option.Rule) (ada
 var _ adapter.Rule = (*DefaultRule)(nil)
 
 type DefaultRule struct {
-	index    int
-	outbound string
-	items    []RuleItem
+	items                   []RuleItem
+	sourceAddressItems      []RuleItem
+	destinationAddressItems []RuleItem
+	outbound                string
 }
 
 type RuleItem interface {
@@ -78,37 +79,37 @@ func NewDefaultRule(router adapter.Router, logger log.Logger, options option.Def
 		rule.items = append(rule.items, NewProtocolItem(options.Protocol))
 	}
 	if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 {
-		rule.items = append(rule.items, NewDomainItem(options.Domain, options.DomainSuffix))
+		rule.destinationAddressItems = append(rule.destinationAddressItems, NewDomainItem(options.Domain, options.DomainSuffix))
 	}
 	if len(options.DomainKeyword) > 0 {
-		rule.items = append(rule.items, NewDomainKeywordItem(options.DomainKeyword))
+		rule.destinationAddressItems = append(rule.destinationAddressItems, NewDomainKeywordItem(options.DomainKeyword))
 	}
 	if len(options.DomainRegex) > 0 {
 		item, err := NewDomainRegexItem(options.DomainRegex)
 		if err != nil {
 			return nil, E.Cause(err, "domain_regex")
 		}
-		rule.items = append(rule.items, item)
+		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
 	}
 	if len(options.SourceGeoIP) > 0 {
-		rule.items = append(rule.items, NewGeoIPItem(router, logger, true, options.SourceGeoIP))
+		rule.sourceAddressItems = append(rule.sourceAddressItems, NewGeoIPItem(router, logger, true, options.SourceGeoIP))
 	}
 	if len(options.GeoIP) > 0 {
-		rule.items = append(rule.items, NewGeoIPItem(router, logger, false, options.GeoIP))
+		rule.destinationAddressItems = append(rule.destinationAddressItems, NewGeoIPItem(router, logger, false, options.GeoIP))
 	}
 	if len(options.SourceIPCIDR) > 0 {
 		item, err := NewIPCIDRItem(true, options.SourceIPCIDR)
 		if err != nil {
 			return nil, E.Cause(err, "source_ipcidr")
 		}
-		rule.items = append(rule.items, item)
+		rule.sourceAddressItems = append(rule.sourceAddressItems, item)
 	}
 	if len(options.IPCIDR) > 0 {
 		item, err := NewIPCIDRItem(false, options.IPCIDR)
 		if err != nil {
 			return nil, E.Cause(err, "ipcidr")
 		}
-		rule.items = append(rule.items, item)
+		rule.destinationAddressItems = append(rule.destinationAddressItems, item)
 	}
 	if len(options.SourcePort) > 0 {
 		rule.items = append(rule.items, NewPortItem(true, options.SourcePort))
@@ -121,11 +122,38 @@ func NewDefaultRule(router adapter.Router, logger log.Logger, options option.Def
 
 func (r *DefaultRule) Match(metadata *adapter.InboundContext) bool {
 	for _, item := range r.items {
-		if item.Match(metadata) {
-			return true
+		if !item.Match(metadata) {
+			return false
 		}
 	}
-	return false
+
+	if len(r.sourceAddressItems) > 0 {
+		var sourceAddressMatch bool
+		for _, item := range r.sourceAddressItems {
+			if item.Match(metadata) {
+				sourceAddressMatch = true
+				break
+			}
+		}
+		if !sourceAddressMatch {
+			return false
+		}
+	}
+
+	if len(r.destinationAddressItems) > 0 {
+		var destinationAddressMatch bool
+		for _, item := range r.destinationAddressItems {
+			if item.Match(metadata) {
+				destinationAddressMatch = true
+				break
+			}
+		}
+		if !destinationAddressMatch {
+			return false
+		}
+	}
+
+	return true
 }
 
 func (r *DefaultRule) Outbound() string {

+ 11 - 3
service.go

@@ -2,6 +2,7 @@ package box
 
 import (
 	"context"
+	"time"
 
 	"github.com/sagernet/sing-box/adapter"
 	"github.com/sagernet/sing-box/inbound"
@@ -11,6 +12,7 @@ import (
 	"github.com/sagernet/sing-box/route"
 	"github.com/sagernet/sing/common"
 	E "github.com/sagernet/sing/common/exceptions"
+	F "github.com/sagernet/sing/common/format"
 )
 
 var _ adapter.Service = (*Service)(nil)
@@ -20,9 +22,11 @@ type Service struct {
 	logger    log.Logger
 	inbounds  []adapter.Inbound
 	outbounds []adapter.Outbound
+	createdAt time.Time
 }
 
 func NewService(ctx context.Context, options option.Options) (*Service, error) {
+	createdAt := time.Now()
 	logger, err := log.NewLogger(common.PtrValueOrDefault(options.Log))
 	if err != nil {
 		return nil, E.Cause(err, "parse log options")
@@ -63,6 +67,7 @@ func NewService(ctx context.Context, options option.Options) (*Service, error) {
 		logger:    logger,
 		inbounds:  inbounds,
 		outbounds: outbounds,
+		createdAt: createdAt,
 	}, nil
 }
 
@@ -71,15 +76,18 @@ func (s *Service) Start() error {
 	if err != nil {
 		return err
 	}
+	err = s.router.Start()
+	if err != nil {
+		return err
+	}
 	for _, in := range s.inbounds {
 		err = in.Start()
 		if err != nil {
 			return err
 		}
 	}
-	return common.AnyError(
-		s.router.Start(),
-	)
+	s.logger.Info("sing-box started (", F.Seconds(time.Since(s.createdAt).Seconds()), "s)")
+	return nil
 }
 
 func (s *Service) Close() error {