package model

import (
	"encoding/json"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/setting/ratio_setting"
	"github.com/QuantumNous/new-api/types"
)

// Pricing is the per-model entry returned by GetPricing.
//
// updatePricing sets QuotaType to 1 when a fixed model price is configured
// (ModelPrice is filled in) and to 0 when billing is ratio based (ModelRatio and
// CompletionRatio are filled in).
type Pricing struct {
	ModelName              string                  `json:"model_name"`
	Description            string                  `json:"description,omitempty"`
	Icon                   string                  `json:"icon,omitempty"`
	Tags                   string                  `json:"tags,omitempty"`
	VendorID               int                     `json:"vendor_id,omitempty"`
	QuotaType              int                     `json:"quota_type"`
	ModelRatio             float64                 `json:"model_ratio"`
	ModelPrice             float64                 `json:"model_price"`
	OwnerBy                string                  `json:"owner_by"`
	CompletionRatio        float64                 `json:"completion_ratio"`
	EnableGroup            []string                `json:"enable_groups"`
	SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
}

// PricingVendor is the front-end facing view of a vendor row, returned by GetVendors.
type PricingVendor struct {
	ID          int    `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Icon        string `json:"icon,omitempty"`
}

var (
	pricingMap           []Pricing
	vendorsList          []PricingVendor
	supportedEndpointMap map[string]common.EndpointInfo
	lastGetPricingTime   time.Time
	updatePricingLock    sync.Mutex

	// pricingSnapshotLock guards pricingMap, vendorsList, supportedEndpointMap and
	// lastGetPricingTime. updatePricing builds fresh values in locals and only takes
	// the write lock to publish them, so readers never observe a half-built slice or
	// a map that is still being written.
	pricingSnapshotLock sync.RWMutex

	// Cached lookups: model name -> enabled groups / quota type.
	modelEnableGroups     = make(map[string][]string)
	modelQuotaTypeMap     = make(map[string]int)
	modelEnableGroupsLock = sync.RWMutex{}
)

var (
	modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
	modelSupportEndpointsLock = sync.RWMutex{}
)

// pricingSnapshotStale reports whether the pricing snapshot is older than one
// minute or currently empty. It takes pricingSnapshotLock for reading only.
func pricingSnapshotStale() bool {
	pricingSnapshotLock.RLock()
	defer pricingSnapshotLock.RUnlock()
	return time.Since(lastGetPricingTime) > time.Minute || len(pricingMap) == 0
}

// refreshPricingIfStale rebuilds the pricing caches unless another goroutine already
// did so while this one was waiting for updatePricingLock.
func refreshPricingIfStale() {
	updatePricingLock.Lock()
	defer updatePricingLock.Unlock()

	// Double check after acquiring the lock.
	if !pricingSnapshotStale() {
		return
	}

	// updatePricing replaces modelSupportEndpointTypes without locking, so it must
	// run with modelSupportEndpointsLock held for writing.
	modelSupportEndpointsLock.Lock()
	defer modelSupportEndpointsLock.Unlock()
	updatePricing()
}

// GetPricing returns the cached per-model pricing list, rebuilding it first when it
// is older than one minute or empty. The returned slice is shared with the cache and
// other callers and must be treated as read-only.
func GetPricing() []Pricing {
	if pricingSnapshotStale() {
		refreshPricingIfStale()
	}
	pricingSnapshotLock.RLock()
	defer pricingSnapshotLock.RUnlock()
	return pricingMap
}

// GetVendors returns the vendors used by the pricing API, refreshing the pricing
// caches first when they are stale. The returned slice must be treated as read-only.
func GetVendors() []PricingVendor {
	if pricingSnapshotStale() {
		// Make sure the snapshot has been refreshed at least once.
		GetPricing()
	}
	pricingSnapshotLock.RLock()
	defer pricingSnapshotLock.RUnlock()
	return vendorsList
}

// GetModelSupportEndpointTypes returns the endpoint types cached for model, with
// the preferred endpoint first. It returns an empty (non-nil) slice for an empty
// name or an unknown model.
func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
	if model == "" {
		return make([]constant.EndpointType, 0)
	}
	modelSupportEndpointsLock.RLock()
	defer modelSupportEndpointsLock.RUnlock()
	if endpoints, ok := modelSupportEndpointTypes[model]; ok {
		return endpoints
	}
	return make([]constant.EndpointType, 0)
}

// sortedPricingKeys returns the keys of m in ascending order.
func sortedPricingKeys[V any](m map[string]V) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}

// buildPricingMetaMap maps model names to their metadata row.
//
// Rows with an exact name rule are keyed by their own ModelName, even if no entry
// in modelNames uses that name; a later duplicate overwrites an earlier one. Each
// name in modelNames without an exact row gets the first matching prefix rule, else
// the first matching suffix rule, else the first matching contains rule, where
// "first" follows the row order of allMeta.
func buildPricingMetaMap(allMeta []Model, modelNames []string) map[string]*Model {
	metaMap := make(map[string]*Model)
	var prefixRules, suffixRules, containsRules []*Model
	for i := range allMeta {
		m := &allMeta[i]
		switch m.NameRule {
		case NameRuleExact:
			metaMap[m.ModelName] = m
		case NameRulePrefix:
			prefixRules = append(prefixRules, m)
		case NameRuleSuffix:
			suffixRules = append(suffixRules, m)
		case NameRuleContains:
			containsRules = append(containsRules, m)
		}
	}

	assign := func(rules []*Model, matches func(name, pattern string) bool) {
		for _, rule := range rules {
			for _, name := range modelNames {
				if _, exists := metaMap[name]; exists {
					continue
				}
				if matches(name, rule.ModelName) {
					metaMap[name] = rule
				}
			}
		}
	}
	assign(prefixRules, strings.HasPrefix)
	assign(suffixRules, strings.HasSuffix)
	assign(containsRules, strings.Contains)
	return metaMap
}

// updatePricing rebuilds every pricing-related cache from the enabled abilities,
// the models/vendors tables and the ratio settings.
//
// Callers must hold updatePricingLock and hold modelSupportEndpointsLock for
// writing, because modelSupportEndpointTypes is replaced here without further
// locking. pricingMap, vendorsList, supportedEndpointMap and lastGetPricingTime are
// built in locals and published together under pricingSnapshotLock.
//
// If the enabled abilities cannot be loaded, the error is logged and no cache is
// touched, so the next caller retries.
func updatePricing() {
	enableAbilities, err := GetAllEnableAbilityWithChannels()
	if err != nil {
		common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
		return
	}

	// Group the enabled abilities by model. The sorted distinct model names drive
	// both rule matching and the output order, which keeps both deterministic.
	modelGroupsMap := make(map[string]*types.Set[string])
	for _, ability := range enableAbilities {
		groups, ok := modelGroupsMap[ability.Model]
		if !ok {
			groups = types.NewSet[string]()
			modelGroupsMap[ability.Model] = groups
		}
		groups.Add(ability.Group)
	}
	modelNames := sortedPricingKeys(modelGroupsMap)

	// Preload model metadata and vendors once instead of querying per model. Load
	// failures are logged and the refresh continues with whatever was read.
	var allMeta []Model
	if err := DB.Find(&allMeta).Error; err != nil {
		common.SysLog(fmt.Sprintf("updatePricing: load models error: %v", err))
	}
	metaMap := buildPricingMetaMap(allMeta, modelNames)

	var vendors []Vendor
	if err := DB.Find(&vendors).Error; err != nil {
		common.SysLog(fmt.Sprintf("updatePricing: load vendors error: %v", err))
	}
	vendorMap := make(map[int]*Vendor, len(vendors))
	for i := range vendors {
		vendorMap[vendors[i].Id] = &vendors[i]
	}

	// Apply the default vendor mapping. NOTE(review): initDefaultVendorMapping is
	// defined elsewhere and receives both maps; it is assumed it may add entries,
	// so everything derived from metaMap/vendorMap is computed after this call.
	initDefaultVendorMapping(metaMap, vendorMap, enableAbilities)

	// Front-end friendly vendor list, ordered by ID.
	newVendors := make([]PricingVendor, 0, len(vendorMap))
	for _, v := range vendorMap {
		newVendors = append(newVendors, PricingVendor{
			ID:          v.Id,
			Name:        v.Name,
			Description: v.Description,
			Icon:        v.Icon,
		})
	}
	sort.Slice(newVendors, func(i, j int) bool { return newVendors[i].ID < newVendors[j].ID })

	// Parse each row's custom endpoint JSON once. Non-exact name rules map many model
	// names onto the same *Model. Empty or invalid JSON yields nil and is ignored.
	parsedEndpoints := make(map[*Model]map[string]any)
	customEndpointsOf := func(meta *Model) map[string]any {
		if raw, ok := parsedEndpoints[meta]; ok {
			return raw
		}
		var raw map[string]any
		if strings.TrimSpace(meta.Endpoints) != "" {
			if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err != nil {
				raw = nil
			}
		}
		parsedEndpoints[meta] = raw
		return raw
	}
	metaModelNames := sortedPricingKeys(metaMap)

	// Endpoint names per model. A slice is used rather than a set because the first
	// entry is the preferred endpoint. Native endpoints derived from the channel types
	// come first, followed by custom endpoints from the models table in key order.
	endpointNames := make(map[string][]string)
	for _, ability := range enableAbilities {
		names := endpointNames[ability.Model]
		for _, et := range common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) {
			if !common.StringsContains(names, string(et)) {
				names = append(names, string(et))
			}
		}
		endpointNames[ability.Model] = names
	}
	for _, modelName := range metaModelNames {
		raw := customEndpointsOf(metaMap[modelName])
		if len(raw) == 0 {
			continue
		}
		names := endpointNames[modelName]
		for _, key := range sortedPricingKeys(raw) {
			if !common.StringsContains(names, key) {
				names = append(names, key)
			}
		}
		endpointNames[modelName] = names
	}

	newEndpointTypes := make(map[string][]constant.EndpointType, len(endpointNames))
	for name, names := range endpointNames {
		typed := make([]constant.EndpointType, 0, len(names))
		for _, n := range names {
			typed = append(typed, constant.EndpointType(n))
		}
		newEndpointTypes[name] = typed
	}

	// Global endpoint -> path/method map.
	newEndpointMap := make(map[string]common.EndpointInfo)
	// 1. Default endpoint info for every endpoint type in use.
	for _, ets := range newEndpointTypes {
		for _, et := range ets {
			if info, ok := common.GetDefaultEndpointInfo(et); ok {
				if _, exists := newEndpointMap[string(et)]; !exists {
					newEndpointMap[string(et)] = info
				}
			}
		}
	}
	// 2. Custom endpoints from the models table override the defaults. Names are
	// visited in sorted order, so on conflicts the alphabetically last model wins.
	for _, modelName := range metaModelNames {
		raw := customEndpointsOf(metaMap[modelName])
		for _, key := range sortedPricingKeys(raw) {
			switch val := raw[key].(type) {
			case string:
				newEndpointMap[key] = common.EndpointInfo{Path: val, Method: "POST"}
			case map[string]any:
				ep := common.EndpointInfo{Method: "POST"}
				if p, ok := val["path"].(string); ok {
					ep.Path = p
				}
				if m, ok := val["method"].(string); ok {
					ep.Method = strings.ToUpper(m)
				}
				newEndpointMap[key] = ep
			default:
				// Unsupported value types are ignored.
			}
		}
	}

	newPricing := make([]Pricing, 0, len(modelNames))
	for _, name := range modelNames {
		pricing := Pricing{
			ModelName:              name,
			EnableGroup:            modelGroupsMap[name].Items(),
			SupportedEndpointTypes: newEndpointTypes[name],
		}
		// Attach model metadata (description, icon, tags, vendor).
		if meta, ok := metaMap[name]; ok {
			// Models whose status is not 1 are left out of the result entirely.
			if meta.Status != 1 {
				continue
			}
			pricing.Description = meta.Description
			pricing.Icon = meta.Icon
			pricing.Tags = meta.Tags
			pricing.VendorID = meta.VendorID
		}
		if modelPrice, found := ratio_setting.GetModelPrice(name, false); found {
			pricing.ModelPrice = modelPrice
			pricing.QuotaType = 1
		} else {
			modelRatio, _, _ := ratio_setting.GetModelRatio(name)
			pricing.ModelRatio = modelRatio
			pricing.CompletionRatio = ratio_setting.GetCompletionRatio(name)
			pricing.QuotaType = 0
		}
		newPricing = append(newPricing, pricing)
	}

	// Refresh the lookup caches used on hot paths; swap under the lock only.
	newEnableGroups := make(map[string][]string, len(newPricing))
	newQuotaTypes := make(map[string]int, len(newPricing))
	for _, p := range newPricing {
		newEnableGroups[p.ModelName] = p.EnableGroup
		newQuotaTypes[p.ModelName] = p.QuotaType
	}
	modelEnableGroupsLock.Lock()
	modelEnableGroups = newEnableGroups
	modelQuotaTypeMap = newQuotaTypes
	modelEnableGroupsLock.Unlock()

	// The caller holds modelSupportEndpointsLock for writing.
	modelSupportEndpointTypes = newEndpointTypes

	pricingSnapshotLock.Lock()
	pricingMap = newPricing
	vendorsList = newVendors
	supportedEndpointMap = newEndpointMap
	lastGetPricingTime = time.Now()
	pricingSnapshotLock.Unlock()
}

// GetSupportedEndpointMap returns the global endpoint -> path/method mapping built
// by the last pricing refresh. The map is never written after it is published; treat
// it as read-only.
func GetSupportedEndpointMap() map[string]common.EndpointInfo {
	pricingSnapshotLock.RLock()
	defer pricingSnapshotLock.RUnlock()
	return supportedEndpointMap
}