Procházet zdrojové kódy

Add cache support for ssm-api

世界 před 4 měsíci
rodič
revize
e65b78d1e4

+ 7 - 1
docs/configuration/service/ssm-api.md

@@ -19,6 +19,7 @@ See https://github.com/Shadowsocks-NET/shadowsocks-specs/blob/main/2023-1-shadow
   ... // Listen Fields
   
   "servers": {},
+  "cache_path": "",
   "tls": {}
 }
 ```
@@ -37,7 +38,7 @@ A mapping Object from HTTP endpoints to [Shadowsocks Inbound](/configuration/inb
 
 Selected Shadowsocks inbounds must be configured with [managed](/configuration/inbound/shadowsocks#managed) enabled.
 
-Example: 
+Example:
 
 ```json
 {
@@ -47,6 +48,11 @@ Example:
 }
 ```
 
+#### cache_path
+
+If set, when the server is about to stop, traffic and user state will be saved to the specified JSON file
+to be restored on the next startup.
+
 #### tls
 
 TLS configuration, see [TLS](/configuration/shared/tls/#inbound).

+ 2 - 1
option/ssmapi.go

@@ -6,6 +6,7 @@ import (
 
 type SSMAPIServiceOptions struct {
 	ListenOptions
-	Servers *badjson.TypedMap[string, string] `json:"servers"`
+	Servers   *badjson.TypedMap[string, string] `json:"servers"`
+	CachePath string                            `json:"cache_path,omitempty"`
 	InboundTLSOptionsContainer
 }

+ 1 - 1
service/derp/service.go

@@ -134,7 +134,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
 func (d *Service) Start(stage adapter.StartStage) error {
 	switch stage {
 	case adapter.StartStateStart:
-		config, err := readDERPConfig(d.configPath)
+		config, err := readDERPConfig(filemanager.BasePath(d.ctx, d.configPath))
 		if err != nil {
 			return err
 		}

+ 222 - 0
service/ssmapi/cache.go

@@ -0,0 +1,222 @@
+package ssmapi
+
+import (
+	"bytes"
+	"os"
+	"path/filepath"
+	"sort"
+
+	"github.com/sagernet/sing/common/atomic"
+	"github.com/sagernet/sing/common/json"
+	"github.com/sagernet/sing/common/json/badjson"
+	"github.com/sagernet/sing/service/filemanager"
+)
+
+type Cache struct {
+	Endpoints *badjson.TypedMap[string, *EndpointCache] `json:"endpoints"`
+}
+
+type EndpointCache struct {
+	GlobalUplink          int64                             `json:"global_uplink"`
+	GlobalDownlink        int64                             `json:"global_downlink"`
+	GlobalUplinkPackets   int64                             `json:"global_uplink_packets"`
+	GlobalDownlinkPackets int64                             `json:"global_downlink_packets"`
+	GlobalTCPSessions     int64                             `json:"global_tcp_sessions"`
+	GlobalUDPSessions     int64                             `json:"global_udp_sessions"`
+	UserUplink            *badjson.TypedMap[string, int64]  `json:"user_uplink"`
+	UserDownlink          *badjson.TypedMap[string, int64]  `json:"user_downlink"`
+	UserUplinkPackets     *badjson.TypedMap[string, int64]  `json:"user_uplink_packets"`
+	UserDownlinkPackets   *badjson.TypedMap[string, int64]  `json:"user_downlink_packets"`
+	UserTCPSessions       *badjson.TypedMap[string, int64]  `json:"user_tcp_sessions"`
+	UserUDPSessions       *badjson.TypedMap[string, int64]  `json:"user_udp_sessions"`
+	Users                 *badjson.TypedMap[string, string] `json:"users"`
+}
+
+func (s *Service) loadCache() error {
+	if s.cachePath == "" {
+		return nil
+	}
+	basePath := filemanager.BasePath(s.ctx, s.cachePath)
+	cacheBinary, err := os.ReadFile(basePath)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil
+		}
+		return err
+	}
+	err = s.decodeCache(cacheBinary)
+	if err != nil {
+		os.RemoveAll(basePath)
+		return err
+	}
+	return nil
+}
+
+func (s *Service) saveCache() error {
+	if s.cachePath == "" {
+		return nil
+	}
+	basePath := filemanager.BasePath(s.ctx, s.cachePath)
+	err := os.MkdirAll(filepath.Dir(basePath), 0o777)
+	if err != nil {
+		return err
+	}
+	cacheBinary, err := s.encodeCache()
+	if err != nil {
+		return err
+	}
+	return os.WriteFile(s.cachePath, cacheBinary, 0o644)
+}
+
+func (s *Service) decodeCache(cacheBinary []byte) error {
+	if len(cacheBinary) == 0 {
+		return nil
+	}
+	cache, err := json.UnmarshalExtended[*Cache](cacheBinary)
+	if err != nil {
+		return err
+	}
+	if cache.Endpoints == nil || cache.Endpoints.Size() == 0 {
+		return nil
+	}
+	for _, entry := range cache.Endpoints.Entries() {
+		trafficManager, loaded := s.traffics[entry.Key]
+		if !loaded {
+			continue
+		}
+		trafficManager.globalUplink.Store(entry.Value.GlobalUplink)
+		trafficManager.globalDownlink.Store(entry.Value.GlobalDownlink)
+		trafficManager.globalUplinkPackets.Store(entry.Value.GlobalUplinkPackets)
+		trafficManager.globalDownlinkPackets.Store(entry.Value.GlobalDownlinkPackets)
+		trafficManager.globalTCPSessions.Store(entry.Value.GlobalTCPSessions)
+		trafficManager.globalUDPSessions.Store(entry.Value.GlobalUDPSessions)
+		trafficManager.userUplink = typedAtomicInt64Map(entry.Value.UserUplink)
+		trafficManager.userDownlink = typedAtomicInt64Map(entry.Value.UserDownlink)
+		trafficManager.userUplinkPackets = typedAtomicInt64Map(entry.Value.UserUplinkPackets)
+		trafficManager.userDownlinkPackets = typedAtomicInt64Map(entry.Value.UserDownlinkPackets)
+		trafficManager.userTCPSessions = typedAtomicInt64Map(entry.Value.UserTCPSessions)
+		trafficManager.userUDPSessions = typedAtomicInt64Map(entry.Value.UserUDPSessions)
+		userManager, loaded := s.users[entry.Key]
+		if !loaded {
+			continue
+		}
+		userManager.usersMap = typedMap(entry.Value.Users)
+		_ = userManager.postUpdate(false)
+	}
+	return nil
+}
+
+func (s *Service) encodeCache() ([]byte, error) {
+	endpoints := new(badjson.TypedMap[string, *EndpointCache])
+	for tag, traffic := range s.traffics {
+		var (
+			userUplink          = new(badjson.TypedMap[string, int64])
+			userDownlink        = new(badjson.TypedMap[string, int64])
+			userUplinkPackets   = new(badjson.TypedMap[string, int64])
+			userDownlinkPackets = new(badjson.TypedMap[string, int64])
+			userTCPSessions     = new(badjson.TypedMap[string, int64])
+			userUDPSessions     = new(badjson.TypedMap[string, int64])
+			userMap             = new(badjson.TypedMap[string, string])
+		)
+		for user, uplink := range traffic.userUplink {
+			if uplink.Load() > 0 {
+				userUplink.Put(user, uplink.Load())
+			}
+		}
+		for user, downlink := range traffic.userDownlink {
+			if downlink.Load() > 0 {
+				userDownlink.Put(user, downlink.Load())
+			}
+		}
+		for user, uplinkPackets := range traffic.userUplinkPackets {
+			if uplinkPackets.Load() > 0 {
+				userUplinkPackets.Put(user, uplinkPackets.Load())
+			}
+		}
+		for user, downlinkPackets := range traffic.userDownlinkPackets {
+			if downlinkPackets.Load() > 0 {
+				userDownlinkPackets.Put(user, downlinkPackets.Load())
+			}
+		}
+		for user, tcpSessions := range traffic.userTCPSessions {
+			if tcpSessions.Load() > 0 {
+				userTCPSessions.Put(user, tcpSessions.Load())
+			}
+		}
+		for user, udpSessions := range traffic.userUDPSessions {
+			if udpSessions.Load() > 0 {
+				userUDPSessions.Put(user, udpSessions.Load())
+			}
+		}
+		userManager := s.users[tag]
+		if userManager != nil && len(userManager.usersMap) > 0 {
+			userMap = new(badjson.TypedMap[string, string])
+			for username, password := range userManager.usersMap {
+				if username != "" && password != "" {
+					userMap.Put(username, password)
+				}
+			}
+		}
+		endpoints.Put(tag, &EndpointCache{
+			GlobalUplink:          traffic.globalUplink.Load(),
+			GlobalDownlink:        traffic.globalDownlink.Load(),
+			GlobalUplinkPackets:   traffic.globalUplinkPackets.Load(),
+			GlobalDownlinkPackets: traffic.globalDownlinkPackets.Load(),
+			GlobalTCPSessions:     traffic.globalTCPSessions.Load(),
+			GlobalUDPSessions:     traffic.globalUDPSessions.Load(),
+			UserUplink:            sortTypedMap(userUplink),
+			UserDownlink:          sortTypedMap(userDownlink),
+			UserUplinkPackets:     sortTypedMap(userUplinkPackets),
+			UserDownlinkPackets:   sortTypedMap(userDownlinkPackets),
+			UserTCPSessions:       sortTypedMap(userTCPSessions),
+			UserUDPSessions:       sortTypedMap(userUDPSessions),
+			Users:                 sortTypedMap(userMap),
+		})
+	}
+	var buffer bytes.Buffer
+	encoder := json.NewEncoder(&buffer)
+	encoder.SetIndent("", "  ")
+	err := encoder.Encode(&Cache{
+		Endpoints: sortTypedMap(endpoints),
+	})
+	if err != nil {
+		return nil, err
+	}
+	return buffer.Bytes(), nil
+}
+
+func sortTypedMap[T comparable](trafficMap *badjson.TypedMap[string, T]) *badjson.TypedMap[string, T] {
+	if trafficMap == nil {
+		return nil
+	}
+	keys := trafficMap.Keys()
+	sort.Strings(keys)
+	sortedMap := new(badjson.TypedMap[string, T])
+	for _, key := range keys {
+		value, _ := trafficMap.Get(key)
+		sortedMap.Put(key, value)
+	}
+	return sortedMap
+}
+
+func typedAtomicInt64Map(trafficMap *badjson.TypedMap[string, int64]) map[string]*atomic.Int64 {
+	result := make(map[string]*atomic.Int64)
+	if trafficMap != nil {
+		for _, entry := range trafficMap.Entries() {
+			counter := new(atomic.Int64)
+			counter.Store(entry.Value)
+			result[entry.Key] = counter
+		}
+	}
+	return result
+}
+
+func typedMap[T comparable](trafficMap *badjson.TypedMap[string, T]) map[string]T {
+	result := make(map[string]T)
+	if trafficMap != nil {
+		for _, entry := range trafficMap.Entries() {
+			result[entry.Key] = entry.Value
+		}
+	}
+	return result
+}

+ 17 - 1
service/ssmapi/server.go

@@ -33,6 +33,9 @@ type Service struct {
 	listener   *listener.Listener
 	tlsConfig  tls.ServerConfig
 	httpServer *http.Server
+	traffics   map[string]*TrafficManager
+	users      map[string]*UserManager
+	cachePath  string
 }
 
 func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.SSMAPIServiceOptions) (adapter.Service, error) {
@@ -50,6 +53,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
 		httpServer: &http.Server{
 			Handler: chiRouter,
 		},
+		traffics:  make(map[string]*TrafficManager),
+		users:     make(map[string]*UserManager),
+		cachePath: options.CachePath,
 	}
 	inboundManager := service.FromContext[adapter.InboundManager](ctx)
 	if options.Servers.Size() == 0 {
@@ -68,6 +74,8 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
 		managedServer.SetTracker(traffic)
 		user := NewUserManager(managedServer, traffic)
 		chiRouter.Route(entry.Key, NewAPIServer(logger, traffic, user).Route)
+		s.traffics[entry.Key] = traffic
+		s.users[entry.Key] = user
 	}
 	if options.TLS != nil {
 		tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
@@ -83,8 +91,12 @@ func (s *Service) Start(stage adapter.StartStage) error {
 	if stage != adapter.StartStateStart {
 		return nil
 	}
+	err := s.loadCache()
+	if err != nil {
+		s.logger.Error(E.Cause(err, "load cache"))
+	}
 	if s.tlsConfig != nil {
-		err := s.tlsConfig.Start()
+		err = s.tlsConfig.Start()
 		if err != nil {
 			return E.Cause(err, "create TLS config")
 		}
@@ -109,6 +121,10 @@ func (s *Service) Start(stage adapter.StartStage) error {
 }
 
 func (s *Service) Close() error {
+	err := s.saveCache()
+	if err != nil {
+		s.logger.Error(E.Cause(err, "save cache"))
+	}
 	return common.Close(
 		common.PtrOrNil(s.httpServer),
 		common.PtrOrNil(s.listener),

+ 7 - 5
service/ssmapi/user.go

@@ -22,7 +22,7 @@ func NewUserManager(inbound adapter.ManagedSSMServer, trafficManager *TrafficMan
 	}
 }
 
-func (m *UserManager) postUpdate() error {
+func (m *UserManager) postUpdate(updated bool) error {
 	users := make([]string, 0, len(m.usersMap))
 	uPSKs := make([]string, 0, len(m.usersMap))
 	for username, password := range m.usersMap {
@@ -33,7 +33,9 @@ func (m *UserManager) postUpdate() error {
 	if err != nil {
 		return err
 	}
-	m.trafficManager.UpdateUsers(users)
+	if updated {
+		m.trafficManager.UpdateUsers(users)
+	}
 	return nil
 }
 
@@ -58,7 +60,7 @@ func (m *UserManager) Add(username string, password string) error {
 		return E.New("user ", username, " already exists")
 	}
 	m.usersMap[username] = password
-	return m.postUpdate()
+	return m.postUpdate(true)
 }
 
 func (m *UserManager) Get(username string) (string, bool) {
@@ -74,12 +76,12 @@ func (m *UserManager) Update(username string, password string) error {
 	m.access.Lock()
 	defer m.access.Unlock()
 	m.usersMap[username] = password
-	return m.postUpdate()
+	return m.postUpdate(true)
 }
 
 func (m *UserManager) Delete(username string) error {
 	m.access.Lock()
 	defer m.access.Unlock()
 	delete(m.usersMap, username)
-	return m.postUpdate()
+	return m.postUpdate(true)
 }