Selaa lähdekoodia

refactor(ratio): replace maps with RWMap for improved concurrency handling

CaIon 4 päivää sitten
vanhempi
sitoutus
44c5fac5ea

+ 14 - 6
common/topup-ratio.go

@@ -2,29 +2,37 @@ package common
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
+	"sync"
 )
 )
 
 
-var TopupGroupRatio = map[string]float64{
+var topupGroupRatio = map[string]float64{
 	"default": 1,
 	"default": 1,
 	"vip":     1,
 	"vip":     1,
 	"svip":    1,
 	"svip":    1,
 }
 }
+var topupGroupRatioMutex sync.RWMutex
 
 
 func TopupGroupRatio2JSONString() string {
 func TopupGroupRatio2JSONString() string {
-	jsonBytes, err := json.Marshal(TopupGroupRatio)
+	topupGroupRatioMutex.RLock()
+	defer topupGroupRatioMutex.RUnlock()
+	jsonBytes, err := json.Marshal(topupGroupRatio)
 	if err != nil {
 	if err != nil {
-		SysError("error marshalling model ratio: " + err.Error())
+		SysError("error marshalling topup group ratio: " + err.Error())
 	}
 	}
 	return string(jsonBytes)
 	return string(jsonBytes)
 }
 }
 
 
 func UpdateTopupGroupRatioByJSONString(jsonStr string) error {
 func UpdateTopupGroupRatioByJSONString(jsonStr string) error {
-	TopupGroupRatio = make(map[string]float64)
-	return json.Unmarshal([]byte(jsonStr), &TopupGroupRatio)
+	topupGroupRatioMutex.Lock()
+	defer topupGroupRatioMutex.Unlock()
+	topupGroupRatio = make(map[string]float64)
+	return json.Unmarshal([]byte(jsonStr), &topupGroupRatio)
 }
 }
 
 
 func GetTopupGroupRatio(name string) float64 {
 func GetTopupGroupRatio(name string) float64 {
-	ratio, ok := TopupGroupRatio[name]
+	topupGroupRatioMutex.RLock()
+	defer topupGroupRatioMutex.RUnlock()
+	ratio, ok := topupGroupRatio[name]
 	if !ok {
 	if !ok {
 		SysError("topup group ratio not found: " + name)
 		SysError("topup group ratio not found: " + name)
 		return 1
 		return 1

+ 13 - 63
setting/ratio_setting/cache_ratio.go

@@ -1,10 +1,7 @@
 package ratio_setting
 package ratio_setting
 
 
 import (
 import (
-	"encoding/json"
-	"sync"
-
-	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/types"
 )
 )
 
 
 var defaultCacheRatio = map[string]float64{
 var defaultCacheRatio = map[string]float64{
@@ -98,70 +95,37 @@ var defaultCreateCacheRatio = map[string]float64{
 
 
 //var defaultCreateCacheRatio = map[string]float64{}
 //var defaultCreateCacheRatio = map[string]float64{}
 
 
-var cacheRatioMap map[string]float64
-var cacheRatioMapMutex sync.RWMutex
-
-var createCacheRatioMap map[string]float64
-var createCacheRatioMapMutex sync.RWMutex
+var cacheRatioMap = types.NewRWMap[string, float64]()
+var createCacheRatioMap = types.NewRWMap[string, float64]()
 
 
-// GetCacheRatioMap returns the cache ratio map
+// GetCacheRatioMap returns a copy of the cache ratio map
 func GetCacheRatioMap() map[string]float64 {
 func GetCacheRatioMap() map[string]float64 {
-	cacheRatioMapMutex.RLock()
-	defer cacheRatioMapMutex.RUnlock()
-	return cacheRatioMap
+	return cacheRatioMap.ReadAll()
 }
 }
 
 
 // CacheRatio2JSONString converts the cache ratio map to a JSON string
 // CacheRatio2JSONString converts the cache ratio map to a JSON string
 func CacheRatio2JSONString() string {
 func CacheRatio2JSONString() string {
-	cacheRatioMapMutex.RLock()
-	defer cacheRatioMapMutex.RUnlock()
-	jsonBytes, err := json.Marshal(cacheRatioMap)
-	if err != nil {
-		common.SysLog("error marshalling cache ratio: " + err.Error())
-	}
-	return string(jsonBytes)
+	return cacheRatioMap.MarshalJSONString()
 }
 }
 
 
 // CreateCacheRatio2JSONString converts the create cache ratio map to a JSON string
 // CreateCacheRatio2JSONString converts the create cache ratio map to a JSON string
 func CreateCacheRatio2JSONString() string {
 func CreateCacheRatio2JSONString() string {
-	createCacheRatioMapMutex.RLock()
-	defer createCacheRatioMapMutex.RUnlock()
-	jsonBytes, err := json.Marshal(createCacheRatioMap)
-	if err != nil {
-		common.SysLog("error marshalling create cache ratio: " + err.Error())
-	}
-	return string(jsonBytes)
+	return createCacheRatioMap.MarshalJSONString()
 }
 }
 
 
 // UpdateCacheRatioByJSONString updates the cache ratio map from a JSON string
 // UpdateCacheRatioByJSONString updates the cache ratio map from a JSON string
 func UpdateCacheRatioByJSONString(jsonStr string) error {
 func UpdateCacheRatioByJSONString(jsonStr string) error {
-	cacheRatioMapMutex.Lock()
-	defer cacheRatioMapMutex.Unlock()
-	cacheRatioMap = make(map[string]float64)
-	err := json.Unmarshal([]byte(jsonStr), &cacheRatioMap)
-	if err == nil {
-		InvalidateExposedDataCache()
-	}
-	return err
+	return types.LoadFromJsonStringWithCallback(cacheRatioMap, jsonStr, InvalidateExposedDataCache)
 }
 }
 
 
 // UpdateCreateCacheRatioByJSONString updates the create cache ratio map from a JSON string
 // UpdateCreateCacheRatioByJSONString updates the create cache ratio map from a JSON string
 func UpdateCreateCacheRatioByJSONString(jsonStr string) error {
 func UpdateCreateCacheRatioByJSONString(jsonStr string) error {
-	createCacheRatioMapMutex.Lock()
-	defer createCacheRatioMapMutex.Unlock()
-	createCacheRatioMap = make(map[string]float64)
-	err := json.Unmarshal([]byte(jsonStr), &createCacheRatioMap)
-	if err == nil {
-		InvalidateExposedDataCache()
-	}
-	return err
+	return types.LoadFromJsonStringWithCallback(createCacheRatioMap, jsonStr, InvalidateExposedDataCache)
 }
 }
 
 
 // GetCacheRatio returns the cache ratio for a model
 // GetCacheRatio returns the cache ratio for a model
 func GetCacheRatio(name string) (float64, bool) {
 func GetCacheRatio(name string) (float64, bool) {
-	cacheRatioMapMutex.RLock()
-	defer cacheRatioMapMutex.RUnlock()
-	ratio, ok := cacheRatioMap[name]
+	ratio, ok := cacheRatioMap.Get(name)
 	if !ok {
 	if !ok {
 		return 1, false // Default to 1 if not found
 		return 1, false // Default to 1 if not found
 	}
 	}
@@ -169,9 +133,7 @@ func GetCacheRatio(name string) (float64, bool) {
 }
 }
 
 
 func GetCreateCacheRatio(name string) (float64, bool) {
 func GetCreateCacheRatio(name string) (float64, bool) {
-	createCacheRatioMapMutex.RLock()
-	defer createCacheRatioMapMutex.RUnlock()
-	ratio, ok := createCacheRatioMap[name]
+	ratio, ok := createCacheRatioMap.Get(name)
 	if !ok {
 	if !ok {
 		return 1.25, false // Default to 1.25 if not found
 		return 1.25, false // Default to 1.25 if not found
 	}
 	}
@@ -179,21 +141,9 @@ func GetCreateCacheRatio(name string) (float64, bool) {
 }
 }
 
 
 func GetCacheRatioCopy() map[string]float64 {
 func GetCacheRatioCopy() map[string]float64 {
-	cacheRatioMapMutex.RLock()
-	defer cacheRatioMapMutex.RUnlock()
-	copyMap := make(map[string]float64, len(cacheRatioMap))
-	for k, v := range cacheRatioMap {
-		copyMap[k] = v
-	}
-	return copyMap
+	return cacheRatioMap.ReadAll()
 }
 }
 
 
 func GetCreateCacheRatioCopy() map[string]float64 {
 func GetCreateCacheRatioCopy() map[string]float64 {
-	createCacheRatioMapMutex.RLock()
-	defer createCacheRatioMapMutex.RUnlock()
-	copyMap := make(map[string]float64, len(createCacheRatioMap))
-	for k, v := range createCacheRatioMap {
-		copyMap[k] = v
-	}
-	return copyMap
+	return createCacheRatioMap.ReadAll()
 }
 }

+ 25 - 62
setting/ratio_setting/group_ratio.go

@@ -3,29 +3,27 @@ package ratio_setting
 import (
 import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
-	"sync"
 
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/setting/config"
 	"github.com/QuantumNous/new-api/setting/config"
 	"github.com/QuantumNous/new-api/types"
 	"github.com/QuantumNous/new-api/types"
 )
 )
 
 
-var groupRatio = map[string]float64{
+var defaultGroupRatio = map[string]float64{
 	"default": 1,
 	"default": 1,
 	"vip":     1,
 	"vip":     1,
 	"svip":    1,
 	"svip":    1,
 }
 }
 
 
-var groupRatioMutex sync.RWMutex
+var groupRatioMap = types.NewRWMap[string, float64]()
 
 
-var (
-	GroupGroupRatio = map[string]map[string]float64{
-		"vip": {
-			"edit_this": 0.9,
-		},
-	}
-	groupGroupRatioMutex sync.RWMutex
-)
+var defaultGroupGroupRatio = map[string]map[string]float64{
+	"vip": {
+		"edit_this": 0.9,
+	},
+}
+
+var groupGroupRatioMap = types.NewRWMap[string, map[string]float64]()
 
 
 var defaultGroupSpecialUsableGroup = map[string]map[string]string{
 var defaultGroupSpecialUsableGroup = map[string]map[string]string{
 	"vip": {
 	"vip": {
@@ -35,9 +33,9 @@ var defaultGroupSpecialUsableGroup = map[string]map[string]string{
 }
 }
 
 
 type GroupRatioSetting struct {
 type GroupRatioSetting struct {
-	GroupRatio              map[string]float64                      `json:"group_ratio"`
-	GroupGroupRatio         map[string]map[string]float64           `json:"group_group_ratio"`
-	GroupSpecialUsableGroup *types.RWMap[string, map[string]string] `json:"group_special_usable_group"`
+	GroupRatio              *types.RWMap[string, float64]            `json:"group_ratio"`
+	GroupGroupRatio         *types.RWMap[string, map[string]float64] `json:"group_group_ratio"`
+	GroupSpecialUsableGroup *types.RWMap[string, map[string]string]  `json:"group_special_usable_group"`
 }
 }
 
 
 var groupRatioSetting GroupRatioSetting
 var groupRatioSetting GroupRatioSetting
@@ -46,10 +44,13 @@ func init() {
 	groupSpecialUsableGroup := types.NewRWMap[string, map[string]string]()
 	groupSpecialUsableGroup := types.NewRWMap[string, map[string]string]()
 	groupSpecialUsableGroup.AddAll(defaultGroupSpecialUsableGroup)
 	groupSpecialUsableGroup.AddAll(defaultGroupSpecialUsableGroup)
 
 
+	groupRatioMap.AddAll(defaultGroupRatio)
+	groupGroupRatioMap.AddAll(defaultGroupGroupRatio)
+
 	groupRatioSetting = GroupRatioSetting{
 	groupRatioSetting = GroupRatioSetting{
 		GroupSpecialUsableGroup: groupSpecialUsableGroup,
 		GroupSpecialUsableGroup: groupSpecialUsableGroup,
-		GroupRatio:              groupRatio,
-		GroupGroupRatio:         GroupGroupRatio,
+		GroupRatio:              groupRatioMap,
+		GroupGroupRatio:         groupGroupRatioMap,
 	}
 	}
 
 
 	config.GlobalConfig.Register("group_ratio_setting", &groupRatioSetting)
 	config.GlobalConfig.Register("group_ratio_setting", &groupRatioSetting)
@@ -64,48 +65,24 @@ func GetGroupRatioSetting() *GroupRatioSetting {
 }
 }
 
 
 func GetGroupRatioCopy() map[string]float64 {
 func GetGroupRatioCopy() map[string]float64 {
-	groupRatioMutex.RLock()
-	defer groupRatioMutex.RUnlock()
-
-	groupRatioCopy := make(map[string]float64)
-	for k, v := range groupRatio {
-		groupRatioCopy[k] = v
-	}
-	return groupRatioCopy
+	return groupRatioMap.ReadAll()
 }
 }
 
 
 func ContainsGroupRatio(name string) bool {
 func ContainsGroupRatio(name string) bool {
-	groupRatioMutex.RLock()
-	defer groupRatioMutex.RUnlock()
-
-	_, ok := groupRatio[name]
+	_, ok := groupRatioMap.Get(name)
 	return ok
 	return ok
 }
 }
 
 
 func GroupRatio2JSONString() string {
 func GroupRatio2JSONString() string {
-	groupRatioMutex.RLock()
-	defer groupRatioMutex.RUnlock()
-
-	jsonBytes, err := json.Marshal(groupRatio)
-	if err != nil {
-		common.SysLog("error marshalling model ratio: " + err.Error())
-	}
-	return string(jsonBytes)
+	return groupRatioMap.MarshalJSONString()
 }
 }
 
 
 func UpdateGroupRatioByJSONString(jsonStr string) error {
 func UpdateGroupRatioByJSONString(jsonStr string) error {
-	groupRatioMutex.Lock()
-	defer groupRatioMutex.Unlock()
-
-	groupRatio = make(map[string]float64)
-	return json.Unmarshal([]byte(jsonStr), &groupRatio)
+	return types.LoadFromJsonString(groupRatioMap, jsonStr)
 }
 }
 
 
 func GetGroupRatio(name string) float64 {
 func GetGroupRatio(name string) float64 {
-	groupRatioMutex.RLock()
-	defer groupRatioMutex.RUnlock()
-
-	ratio, ok := groupRatio[name]
+	ratio, ok := groupRatioMap.Get(name)
 	if !ok {
 	if !ok {
 		common.SysLog("group ratio not found: " + name)
 		common.SysLog("group ratio not found: " + name)
 		return 1
 		return 1
@@ -114,10 +91,7 @@ func GetGroupRatio(name string) float64 {
 }
 }
 
 
 func GetGroupGroupRatio(userGroup, usingGroup string) (float64, bool) {
 func GetGroupGroupRatio(userGroup, usingGroup string) (float64, bool) {
-	groupGroupRatioMutex.RLock()
-	defer groupGroupRatioMutex.RUnlock()
-
-	gp, ok := GroupGroupRatio[userGroup]
+	gp, ok := groupGroupRatioMap.Get(userGroup)
 	if !ok {
 	if !ok {
 		return -1, false
 		return -1, false
 	}
 	}
@@ -129,22 +103,11 @@ func GetGroupGroupRatio(userGroup, usingGroup string) (float64, bool) {
 }
 }
 
 
 func GroupGroupRatio2JSONString() string {
 func GroupGroupRatio2JSONString() string {
-	groupGroupRatioMutex.RLock()
-	defer groupGroupRatioMutex.RUnlock()
-
-	jsonBytes, err := json.Marshal(GroupGroupRatio)
-	if err != nil {
-		common.SysLog("error marshalling group-group ratio: " + err.Error())
-	}
-	return string(jsonBytes)
+	return groupGroupRatioMap.MarshalJSONString()
 }
 }
 
 
 func UpdateGroupGroupRatioByJSONString(jsonStr string) error {
 func UpdateGroupGroupRatioByJSONString(jsonStr string) error {
-	groupGroupRatioMutex.Lock()
-	defer groupGroupRatioMutex.Unlock()
-
-	GroupGroupRatio = make(map[string]map[string]float64)
-	return json.Unmarshal([]byte(jsonStr), &GroupGroupRatio)
+	return types.LoadFromJsonString(groupGroupRatioMap, jsonStr)
 }
 }
 
 
 func CheckGroupRatio(jsonStr string) error {
 func CheckGroupRatio(jsonStr string) error {

+ 42 - 229
setting/ratio_setting/model_ratio.go

@@ -1,12 +1,11 @@
 package ratio_setting
 package ratio_setting
 
 
 import (
 import (
-	"encoding/json"
 	"strings"
 	"strings"
-	"sync"
 
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/setting/operation_setting"
 	"github.com/QuantumNous/new-api/setting/operation_setting"
+	"github.com/QuantumNous/new-api/types"
 )
 )
 
 
 // from songquanpeng/one-api
 // from songquanpeng/one-api
@@ -319,19 +318,9 @@ var defaultAudioCompletionRatio = map[string]float64{
 	"tts-1-hd-1106":        0,
 	"tts-1-hd-1106":        0,
 }
 }
 
 
-var (
-	modelPriceMap      map[string]float64 = nil
-	modelPriceMapMutex                    = sync.RWMutex{}
-)
-var (
-	modelRatioMap      map[string]float64 = nil
-	modelRatioMapMutex                    = sync.RWMutex{}
-)
-
-var (
-	CompletionRatio      map[string]float64 = nil
-	CompletionRatioMutex                    = sync.RWMutex{}
-)
+var modelPriceMap = types.NewRWMap[string, float64]()
+var modelRatioMap = types.NewRWMap[string, float64]()
+var completionRatioMap = types.NewRWMap[string, float64]()
 
 
 var defaultCompletionRatio = map[string]float64{
 var defaultCompletionRatio = map[string]float64{
 	"gpt-4-gizmo-*":  2,
 	"gpt-4-gizmo-*":  2,
@@ -342,84 +331,34 @@ var defaultCompletionRatio = map[string]float64{
 
 
 // InitRatioSettings initializes all model related settings maps
 // InitRatioSettings initializes all model related settings maps
 func InitRatioSettings() {
 func InitRatioSettings() {
-	// Initialize modelPriceMap
-	modelPriceMapMutex.Lock()
-	modelPriceMap = defaultModelPrice
-	modelPriceMapMutex.Unlock()
-
-	// Initialize modelRatioMap
-	modelRatioMapMutex.Lock()
-	modelRatioMap = defaultModelRatio
-	modelRatioMapMutex.Unlock()
-
-	// Initialize CompletionRatio
-	CompletionRatioMutex.Lock()
-	CompletionRatio = defaultCompletionRatio
-	CompletionRatioMutex.Unlock()
-
-	// Initialize cacheRatioMap
-	cacheRatioMapMutex.Lock()
-	cacheRatioMap = defaultCacheRatio
-	cacheRatioMapMutex.Unlock()
-
-	// Initialize createCacheRatioMap (5m cache creation ratio)
-	createCacheRatioMapMutex.Lock()
-	createCacheRatioMap = defaultCreateCacheRatio
-	createCacheRatioMapMutex.Unlock()
-
-	// initialize imageRatioMap
-	imageRatioMapMutex.Lock()
-	imageRatioMap = defaultImageRatio
-	imageRatioMapMutex.Unlock()
-
-	// initialize audioRatioMap
-	audioRatioMapMutex.Lock()
-	audioRatioMap = defaultAudioRatio
-	audioRatioMapMutex.Unlock()
-
-	// initialize audioCompletionRatioMap
-	audioCompletionRatioMapMutex.Lock()
-	audioCompletionRatioMap = defaultAudioCompletionRatio
-	audioCompletionRatioMapMutex.Unlock()
+	modelPriceMap.AddAll(defaultModelPrice)
+	modelRatioMap.AddAll(defaultModelRatio)
+	completionRatioMap.AddAll(defaultCompletionRatio)
+	cacheRatioMap.AddAll(defaultCacheRatio)
+	createCacheRatioMap.AddAll(defaultCreateCacheRatio)
+	imageRatioMap.AddAll(defaultImageRatio)
+	audioRatioMap.AddAll(defaultAudioRatio)
+	audioCompletionRatioMap.AddAll(defaultAudioCompletionRatio)
 }
 }
 
 
 func GetModelPriceMap() map[string]float64 {
 func GetModelPriceMap() map[string]float64 {
-	modelPriceMapMutex.RLock()
-	defer modelPriceMapMutex.RUnlock()
-	return modelPriceMap
+	return modelPriceMap.ReadAll()
 }
 }
 
 
 func ModelPrice2JSONString() string {
 func ModelPrice2JSONString() string {
-	modelPriceMapMutex.RLock()
-	defer modelPriceMapMutex.RUnlock()
-
-	jsonBytes, err := common.Marshal(modelPriceMap)
-	if err != nil {
-		common.SysError("error marshalling model price: " + err.Error())
-	}
-	return string(jsonBytes)
+	return modelPriceMap.MarshalJSONString()
 }
 }
 
 
 func UpdateModelPriceByJSONString(jsonStr string) error {
 func UpdateModelPriceByJSONString(jsonStr string) error {
-	modelPriceMapMutex.Lock()
-	defer modelPriceMapMutex.Unlock()
-	modelPriceMap = make(map[string]float64)
-	err := json.Unmarshal([]byte(jsonStr), &modelPriceMap)
-	if err == nil {
-		InvalidateExposedDataCache()
-	}
-	return err
+	return types.LoadFromJsonStringWithCallback(modelPriceMap, jsonStr, InvalidateExposedDataCache)
 }
 }
 
 
 // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
 // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
 func GetModelPrice(name string, printErr bool) (float64, bool) {
 func GetModelPrice(name string, printErr bool) (float64, bool) {
-	modelPriceMapMutex.RLock()
-	defer modelPriceMapMutex.RUnlock()
-
 	name = FormatMatchingModelName(name)
 	name = FormatMatchingModelName(name)
 
 
 	if strings.HasSuffix(name, CompactModelSuffix) {
 	if strings.HasSuffix(name, CompactModelSuffix) {
-		price, ok := modelPriceMap[CompactWildcardModelKey]
+		price, ok := modelPriceMap.Get(CompactWildcardModelKey)
 		if !ok {
 		if !ok {
 			if printErr {
 			if printErr {
 				common.SysError("model price not found: " + name)
 				common.SysError("model price not found: " + name)
@@ -429,7 +368,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
 		return price, true
 		return price, true
 	}
 	}
 
 
-	price, ok := modelPriceMap[name]
+	price, ok := modelPriceMap.Get(name)
 	if !ok {
 	if !ok {
 		if printErr {
 		if printErr {
 			common.SysError("model price not found: " + name)
 			common.SysError("model price not found: " + name)
@@ -440,14 +379,7 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
 }
 }
 
 
 func UpdateModelRatioByJSONString(jsonStr string) error {
 func UpdateModelRatioByJSONString(jsonStr string) error {
-	modelRatioMapMutex.Lock()
-	defer modelRatioMapMutex.Unlock()
-	modelRatioMap = make(map[string]float64)
-	err := common.Unmarshal([]byte(jsonStr), &modelRatioMap)
-	if err == nil {
-		InvalidateExposedDataCache()
-	}
-	return err
+	return types.LoadFromJsonStringWithCallback(modelRatioMap, jsonStr, InvalidateExposedDataCache)
 }
 }
 
 
 // 处理带有思考预算的模型名称,方便统一定价
 // 处理带有思考预算的模型名称,方便统一定价
@@ -459,15 +391,12 @@ func handleThinkingBudgetModel(name, prefix, wildcard string) string {
 }
 }
 
 
 func GetModelRatio(name string) (float64, bool, string) {
 func GetModelRatio(name string) (float64, bool, string) {
-	modelRatioMapMutex.RLock()
-	defer modelRatioMapMutex.RUnlock()
-
 	name = FormatMatchingModelName(name)
 	name = FormatMatchingModelName(name)
 
 
-	ratio, ok := modelRatioMap[name]
+	ratio, ok := modelRatioMap.Get(name)
 	if !ok {
 	if !ok {
 		if strings.HasSuffix(name, CompactModelSuffix) {
 		if strings.HasSuffix(name, CompactModelSuffix) {
-			if wildcardRatio, ok := modelRatioMap[CompactWildcardModelKey]; ok {
+			if wildcardRatio, ok := modelRatioMap.Get(CompactWildcardModelKey); ok {
 				return wildcardRatio, true, name
 				return wildcardRatio, true, name
 			}
 			}
 			//return 0, true, name
 			//return 0, true, name
@@ -493,54 +422,19 @@ func GetDefaultModelPriceMap() map[string]float64 {
 	return defaultModelPrice
 	return defaultModelPrice
 }
 }
 
 
-func GetDefaultImageRatioMap() map[string]float64 {
-	return defaultImageRatio
-}
-
-func GetDefaultAudioRatioMap() map[string]float64 {
-	return defaultAudioRatio
-}
-
-func GetDefaultAudioCompletionRatioMap() map[string]float64 {
-	return defaultAudioCompletionRatio
-}
-
-func GetCompletionRatioMap() map[string]float64 {
-	CompletionRatioMutex.RLock()
-	defer CompletionRatioMutex.RUnlock()
-	return CompletionRatio
-}
-
 func CompletionRatio2JSONString() string {
 func CompletionRatio2JSONString() string {
-	CompletionRatioMutex.RLock()
-	defer CompletionRatioMutex.RUnlock()
-
-	jsonBytes, err := json.Marshal(CompletionRatio)
-	if err != nil {
-		common.SysError("error marshalling completion ratio: " + err.Error())
-	}
-	return string(jsonBytes)
+	return completionRatioMap.MarshalJSONString()
 }
 }
 
 
 func UpdateCompletionRatioByJSONString(jsonStr string) error {
 func UpdateCompletionRatioByJSONString(jsonStr string) error {
-	CompletionRatioMutex.Lock()
-	defer CompletionRatioMutex.Unlock()
-	CompletionRatio = make(map[string]float64)
-	err := common.Unmarshal([]byte(jsonStr), &CompletionRatio)
-	if err == nil {
-		InvalidateExposedDataCache()
-	}
-	return err
+	return types.LoadFromJsonStringWithCallback(completionRatioMap, jsonStr, InvalidateExposedDataCache)
 }
 }
 
 
 func GetCompletionRatio(name string) float64 {
 func GetCompletionRatio(name string) float64 {
-	CompletionRatioMutex.RLock()
-	defer CompletionRatioMutex.RUnlock()
-
 	name = FormatMatchingModelName(name)
 	name = FormatMatchingModelName(name)
 
 
 	if strings.Contains(name, "/") {
 	if strings.Contains(name, "/") {
-		if ratio, ok := CompletionRatio[name]; ok {
+		if ratio, ok := completionRatioMap.Get(name); ok {
 			return ratio
 			return ratio
 		}
 		}
 	}
 	}
@@ -548,7 +442,7 @@ func GetCompletionRatio(name string) float64 {
 	if contain {
 	if contain {
 		return hardCodedRatio
 		return hardCodedRatio
 	}
 	}
-	if ratio, ok := CompletionRatio[name]; ok {
+	if ratio, ok := completionRatioMap.Get(name); ok {
 		return ratio
 		return ratio
 	}
 	}
 	return hardCodedRatio
 	return hardCodedRatio
@@ -676,88 +570,54 @@ func getHardcodedCompletionModelRatio(name string) (float64, bool) {
 }
 }
 
 
 func GetAudioRatio(name string) float64 {
 func GetAudioRatio(name string) float64 {
-	audioRatioMapMutex.RLock()
-	defer audioRatioMapMutex.RUnlock()
 	name = FormatMatchingModelName(name)
 	name = FormatMatchingModelName(name)
-	if ratio, ok := audioRatioMap[name]; ok {
+	if ratio, ok := audioRatioMap.Get(name); ok {
 		return ratio
 		return ratio
 	}
 	}
 	return 1
 	return 1
 }
 }
 
 
 func GetAudioCompletionRatio(name string) float64 {
 func GetAudioCompletionRatio(name string) float64 {
-	audioCompletionRatioMapMutex.RLock()
-	defer audioCompletionRatioMapMutex.RUnlock()
 	name = FormatMatchingModelName(name)
 	name = FormatMatchingModelName(name)
-	if ratio, ok := audioCompletionRatioMap[name]; ok {
-
+	if ratio, ok := audioCompletionRatioMap.Get(name); ok {
 		return ratio
 		return ratio
 	}
 	}
 	return 1
 	return 1
 }
 }
 
 
 func ContainsAudioRatio(name string) bool {
 func ContainsAudioRatio(name string) bool {
-	audioRatioMapMutex.RLock()
-	defer audioRatioMapMutex.RUnlock()
 	name = FormatMatchingModelName(name)
 	name = FormatMatchingModelName(name)
-	_, ok := audioRatioMap[name]
+	_, ok := audioRatioMap.Get(name)
 	return ok
 	return ok
 }
 }
 
 
 func ContainsAudioCompletionRatio(name string) bool {
 func ContainsAudioCompletionRatio(name string) bool {
-	audioCompletionRatioMapMutex.RLock()
-	defer audioCompletionRatioMapMutex.RUnlock()
 	name = FormatMatchingModelName(name)
 	name = FormatMatchingModelName(name)
-	_, ok := audioCompletionRatioMap[name]
+	_, ok := audioCompletionRatioMap.Get(name)
 	return ok
 	return ok
 }
 }
 
 
 func ModelRatio2JSONString() string {
 func ModelRatio2JSONString() string {
-	modelRatioMapMutex.RLock()
-	defer modelRatioMapMutex.RUnlock()
-
-	jsonBytes, err := common.Marshal(modelRatioMap)
-	if err != nil {
-		common.SysError("error marshalling model ratio: " + err.Error())
-	}
-	return string(jsonBytes)
+	return modelRatioMap.MarshalJSONString()
 }
 }
 
 
 var defaultImageRatio = map[string]float64{
 var defaultImageRatio = map[string]float64{
 	"gpt-image-1": 2,
 	"gpt-image-1": 2,
 }
 }
-var imageRatioMap map[string]float64
-var imageRatioMapMutex sync.RWMutex
-var (
-	audioRatioMap      map[string]float64 = nil
-	audioRatioMapMutex                    = sync.RWMutex{}
-)
-var (
-	audioCompletionRatioMap      map[string]float64 = nil
-	audioCompletionRatioMapMutex                    = sync.RWMutex{}
-)
+var imageRatioMap = types.NewRWMap[string, float64]()
+var audioRatioMap = types.NewRWMap[string, float64]()
+var audioCompletionRatioMap = types.NewRWMap[string, float64]()
 
 
 func ImageRatio2JSONString() string {
 func ImageRatio2JSONString() string {
-	imageRatioMapMutex.RLock()
-	defer imageRatioMapMutex.RUnlock()
-	jsonBytes, err := common.Marshal(imageRatioMap)
-	if err != nil {
-		common.SysError("error marshalling cache ratio: " + err.Error())
-	}
-	return string(jsonBytes)
+	return imageRatioMap.MarshalJSONString()
 }
 }
 
 
 func UpdateImageRatioByJSONString(jsonStr string) error {
 func UpdateImageRatioByJSONString(jsonStr string) error {
-	imageRatioMapMutex.Lock()
-	defer imageRatioMapMutex.Unlock()
-	imageRatioMap = make(map[string]float64)
-	return common.Unmarshal([]byte(jsonStr), &imageRatioMap)
+	return types.LoadFromJsonString(imageRatioMap, jsonStr)
 }
 }
 
 
 func GetImageRatio(name string) (float64, bool) {
 func GetImageRatio(name string) (float64, bool) {
-	imageRatioMapMutex.RLock()
-	defer imageRatioMapMutex.RUnlock()
-	ratio, ok := imageRatioMap[name]
+	ratio, ok := imageRatioMap.Get(name)
 	if !ok {
 	if !ok {
 		return 1, false // Default to 1 if not found
 		return 1, false // Default to 1 if not found
 	}
 	}
@@ -765,78 +625,31 @@ func GetImageRatio(name string) (float64, bool) {
 }
 }
 
 
 func AudioRatio2JSONString() string {
 func AudioRatio2JSONString() string {
-	audioRatioMapMutex.RLock()
-	defer audioRatioMapMutex.RUnlock()
-	jsonBytes, err := common.Marshal(audioRatioMap)
-	if err != nil {
-		common.SysError("error marshalling audio ratio: " + err.Error())
-	}
-	return string(jsonBytes)
+	return audioRatioMap.MarshalJSONString()
 }
 }
 
 
 func UpdateAudioRatioByJSONString(jsonStr string) error {
 func UpdateAudioRatioByJSONString(jsonStr string) error {
-
-	tmp := make(map[string]float64)
-	if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
-		return err
-	}
-	audioRatioMapMutex.Lock()
-	audioRatioMap = tmp
-	audioRatioMapMutex.Unlock()
-	InvalidateExposedDataCache()
-	return nil
+	return types.LoadFromJsonStringWithCallback(audioRatioMap, jsonStr, InvalidateExposedDataCache)
 }
 }
 
 
 func AudioCompletionRatio2JSONString() string {
 func AudioCompletionRatio2JSONString() string {
-	audioCompletionRatioMapMutex.RLock()
-	defer audioCompletionRatioMapMutex.RUnlock()
-	jsonBytes, err := common.Marshal(audioCompletionRatioMap)
-	if err != nil {
-		common.SysError("error marshalling audio completion ratio: " + err.Error())
-	}
-	return string(jsonBytes)
+	return audioCompletionRatioMap.MarshalJSONString()
 }
 }
 
 
 func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
 func UpdateAudioCompletionRatioByJSONString(jsonStr string) error {
-	tmp := make(map[string]float64)
-	if err := common.Unmarshal([]byte(jsonStr), &tmp); err != nil {
-		return err
-	}
-	audioCompletionRatioMapMutex.Lock()
-	audioCompletionRatioMap = tmp
-	audioCompletionRatioMapMutex.Unlock()
-	InvalidateExposedDataCache()
-	return nil
+	return types.LoadFromJsonStringWithCallback(audioCompletionRatioMap, jsonStr, InvalidateExposedDataCache)
 }
 }
 
 
 func GetModelRatioCopy() map[string]float64 {
 func GetModelRatioCopy() map[string]float64 {
-	modelRatioMapMutex.RLock()
-	defer modelRatioMapMutex.RUnlock()
-	copyMap := make(map[string]float64, len(modelRatioMap))
-	for k, v := range modelRatioMap {
-		copyMap[k] = v
-	}
-	return copyMap
+	return modelRatioMap.ReadAll()
 }
 }
 
 
 func GetModelPriceCopy() map[string]float64 {
 func GetModelPriceCopy() map[string]float64 {
-	modelPriceMapMutex.RLock()
-	defer modelPriceMapMutex.RUnlock()
-	copyMap := make(map[string]float64, len(modelPriceMap))
-	for k, v := range modelPriceMap {
-		copyMap[k] = v
-	}
-	return copyMap
+	return modelPriceMap.ReadAll()
 }
 }
 
 
 func GetCompletionRatioCopy() map[string]float64 {
 func GetCompletionRatioCopy() map[string]float64 {
-	CompletionRatioMutex.RLock()
-	defer CompletionRatioMutex.RUnlock()
-	copyMap := make(map[string]float64, len(CompletionRatio))
-	for k, v := range CompletionRatio {
-		copyMap[k] = v
-	}
-	return copyMap
+	return completionRatioMap.ReadAll()
 }
 }
 
 
 // 转换模型名,减少渠道必须配置各种带参数模型
 // 转换模型名,减少渠道必须配置各种带参数模型

+ 21 - 0
types/rw_map.go

@@ -80,3 +80,24 @@ func LoadFromJsonString[K comparable, V any](m *RWMap[K, V], jsonStr string) err
 	m.data = make(map[K]V)
 	m.data = make(map[K]V)
 	return common.Unmarshal([]byte(jsonStr), &m.data)
 	return common.Unmarshal([]byte(jsonStr), &m.data)
 }
 }
+
+// LoadFromJsonStringWithCallback loads a JSON string into the RWMap and calls the callback on success.
+func LoadFromJsonStringWithCallback[K comparable, V any](m *RWMap[K, V], jsonStr string, onSuccess func()) error {
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+	m.data = make(map[K]V)
+	err := common.Unmarshal([]byte(jsonStr), &m.data)
+	if err == nil && onSuccess != nil {
+		onSuccess()
+	}
+	return err
+}
+
+// MarshalJSONString returns the JSON string representation of the RWMap.
+func (m *RWMap[K, V]) MarshalJSONString() string {
+	bytes, err := m.MarshalJSON()
+	if err != nil {
+		return "{}"
+	}
+	return string(bytes)
+}