| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312 |
- package model
- import (
- "encoding/json"
- "fmt"
- "strings"
- "one-api/common"
- "one-api/constant"
- "one-api/setting/ratio_setting"
- "one-api/types"
- "sync"
- "time"
- )
- type Pricing struct {
- ModelName string `json:"model_name"`
- Description string `json:"description,omitempty"`
- Icon string `json:"icon,omitempty"`
- Tags string `json:"tags,omitempty"`
- VendorID int `json:"vendor_id,omitempty"`
- QuotaType int `json:"quota_type"`
- ModelRatio float64 `json:"model_ratio"`
- ModelPrice float64 `json:"model_price"`
- OwnerBy string `json:"owner_by"`
- CompletionRatio float64 `json:"completion_ratio"`
- EnableGroup []string `json:"enable_groups"`
- SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
- }
// PricingVendor is the front-end view of a vendor row, as exposed alongside
// the pricing list. Description and Icon are omitted from JSON when empty.
type PricingVendor struct {
	ID          int    `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Icon        string `json:"icon,omitempty"`
}
- var (
- pricingMap []Pricing
- vendorsList []PricingVendor
- supportedEndpointMap map[string]common.EndpointInfo
- lastGetPricingTime time.Time
- updatePricingLock sync.Mutex
- // 缓存映射:模型名 -> 启用分组 / 计费类型
- modelEnableGroups = make(map[string][]string)
- modelQuotaTypeMap = make(map[string]int)
- modelEnableGroupsLock = sync.RWMutex{}
- )
- var (
- modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
- modelSupportEndpointsLock = sync.RWMutex{}
- )
- func GetPricing() []Pricing {
- if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
- updatePricingLock.Lock()
- defer updatePricingLock.Unlock()
- // Double check after acquiring the lock
- if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
- modelSupportEndpointsLock.Lock()
- defer modelSupportEndpointsLock.Unlock()
- updatePricing()
- }
- }
- return pricingMap
- }
- // GetVendors 返回当前定价接口使用到的供应商信息
- func GetVendors() []PricingVendor {
- if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
- // 保证先刷新一次
- GetPricing()
- }
- return vendorsList
- }
- func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
- if model == "" {
- return make([]constant.EndpointType, 0)
- }
- modelSupportEndpointsLock.RLock()
- defer modelSupportEndpointsLock.RUnlock()
- if endpoints, ok := modelSupportEndpointTypes[model]; ok {
- return endpoints
- }
- return make([]constant.EndpointType, 0)
- }
- func updatePricing() {
- //modelRatios := common.GetModelRatios()
- enableAbilities, err := GetAllEnableAbilityWithChannels()
- if err != nil {
- common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
- return
- }
- // 预加载模型元数据与供应商一次,避免循环查询
- var allMeta []Model
- _ = DB.Find(&allMeta).Error
- metaMap := make(map[string]*Model)
- prefixList := make([]*Model, 0)
- suffixList := make([]*Model, 0)
- containsList := make([]*Model, 0)
- for i := range allMeta {
- m := &allMeta[i]
- if m.NameRule == NameRuleExact {
- metaMap[m.ModelName] = m
- } else {
- switch m.NameRule {
- case NameRulePrefix:
- prefixList = append(prefixList, m)
- case NameRuleSuffix:
- suffixList = append(suffixList, m)
- case NameRuleContains:
- containsList = append(containsList, m)
- }
- }
- }
- // 将非精确规则模型匹配到 metaMap
- for _, m := range prefixList {
- for _, pricingModel := range enableAbilities {
- if strings.HasPrefix(pricingModel.Model, m.ModelName) {
- if _, exists := metaMap[pricingModel.Model]; !exists {
- metaMap[pricingModel.Model] = m
- }
- }
- }
- }
- for _, m := range suffixList {
- for _, pricingModel := range enableAbilities {
- if strings.HasSuffix(pricingModel.Model, m.ModelName) {
- if _, exists := metaMap[pricingModel.Model]; !exists {
- metaMap[pricingModel.Model] = m
- }
- }
- }
- }
- for _, m := range containsList {
- for _, pricingModel := range enableAbilities {
- if strings.Contains(pricingModel.Model, m.ModelName) {
- if _, exists := metaMap[pricingModel.Model]; !exists {
- metaMap[pricingModel.Model] = m
- }
- }
- }
- }
- // 预加载供应商
- var vendors []Vendor
- _ = DB.Find(&vendors).Error
- vendorMap := make(map[int]*Vendor)
- for i := range vendors {
- vendorMap[vendors[i].Id] = &vendors[i]
- }
- // 初始化默认供应商映射
- initDefaultVendorMapping(metaMap, vendorMap, enableAbilities)
- // 构建对前端友好的供应商列表
- vendorsList = make([]PricingVendor, 0, len(vendorMap))
- for _, v := range vendorMap {
- vendorsList = append(vendorsList, PricingVendor{
- ID: v.Id,
- Name: v.Name,
- Description: v.Description,
- Icon: v.Icon,
- })
- }
- modelGroupsMap := make(map[string]*types.Set[string])
- for _, ability := range enableAbilities {
- groups, ok := modelGroupsMap[ability.Model]
- if !ok {
- groups = types.NewSet[string]()
- modelGroupsMap[ability.Model] = groups
- }
- groups.Add(ability.Group)
- }
- //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
- modelSupportEndpointsStr := make(map[string][]string)
- // 先根据已有能力填充原生端点
- for _, ability := range enableAbilities {
- endpoints := modelSupportEndpointsStr[ability.Model]
- channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
- for _, channelType := range channelTypes {
- if !common.StringsContains(endpoints, string(channelType)) {
- endpoints = append(endpoints, string(channelType))
- }
- }
- modelSupportEndpointsStr[ability.Model] = endpoints
- }
- // 再补充模型自定义端点
- for modelName, meta := range metaMap {
- if strings.TrimSpace(meta.Endpoints) == "" {
- continue
- }
- var raw map[string]interface{}
- if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
- endpoints := modelSupportEndpointsStr[modelName]
- for k := range raw {
- if !common.StringsContains(endpoints, k) {
- endpoints = append(endpoints, k)
- }
- }
- modelSupportEndpointsStr[modelName] = endpoints
- }
- }
- modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
- for model, endpoints := range modelSupportEndpointsStr {
- supportedEndpoints := make([]constant.EndpointType, 0)
- for _, endpointStr := range endpoints {
- endpointType := constant.EndpointType(endpointStr)
- supportedEndpoints = append(supportedEndpoints, endpointType)
- }
- modelSupportEndpointTypes[model] = supportedEndpoints
- }
- // 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
- supportedEndpointMap = make(map[string]common.EndpointInfo)
- // 1. 默认端点
- for _, endpoints := range modelSupportEndpointTypes {
- for _, et := range endpoints {
- if info, ok := common.GetDefaultEndpointInfo(et); ok {
- if _, exists := supportedEndpointMap[string(et)]; !exists {
- supportedEndpointMap[string(et)] = info
- }
- }
- }
- }
- // 2. 自定义端点(models 表)覆盖默认
- for _, meta := range metaMap {
- if strings.TrimSpace(meta.Endpoints) == "" {
- continue
- }
- var raw map[string]interface{}
- if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
- for k, v := range raw {
- switch val := v.(type) {
- case string:
- supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
- case map[string]interface{}:
- ep := common.EndpointInfo{Method: "POST"}
- if p, ok := val["path"].(string); ok {
- ep.Path = p
- }
- if m, ok := val["method"].(string); ok {
- ep.Method = strings.ToUpper(m)
- }
- supportedEndpointMap[k] = ep
- default:
- // ignore unsupported types
- }
- }
- }
- }
- pricingMap = make([]Pricing, 0)
- for model, groups := range modelGroupsMap {
- pricing := Pricing{
- ModelName: model,
- EnableGroup: groups.Items(),
- SupportedEndpointTypes: modelSupportEndpointTypes[model],
- }
- // 补充模型元数据(描述、标签、供应商、状态)
- if meta, ok := metaMap[model]; ok {
- // 若模型被禁用(status!=1),则直接跳过,不返回给前端
- if meta.Status != 1 {
- continue
- }
- pricing.Description = meta.Description
- pricing.Icon = meta.Icon
- pricing.Tags = meta.Tags
- pricing.VendorID = meta.VendorID
- }
- modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
- if findPrice {
- pricing.ModelPrice = modelPrice
- pricing.QuotaType = 1
- } else {
- modelRatio, _, _ := ratio_setting.GetModelRatio(model)
- pricing.ModelRatio = modelRatio
- pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
- pricing.QuotaType = 0
- }
- pricingMap = append(pricingMap, pricing)
- }
- // 刷新缓存映射,供高并发快速查询
- modelEnableGroupsLock.Lock()
- modelEnableGroups = make(map[string][]string)
- modelQuotaTypeMap = make(map[string]int)
- for _, p := range pricingMap {
- modelEnableGroups[p.ModelName] = p.EnableGroup
- modelQuotaTypeMap[p.ModelName] = p.QuotaType
- }
- modelEnableGroupsLock.Unlock()
- lastGetPricingTime = time.Now()
- }
- // GetSupportedEndpointMap 返回全局端点到路径的映射
- func GetSupportedEndpointMap() map[string]common.EndpointInfo {
- return supportedEndpointMap
- }
|