| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- package model
import (
	"fmt"
	"sort"
	"sync"
	"time"

	"one-api/common"
	"one-api/constant"
	"one-api/setting/ratio_setting"
	"one-api/types"
)
- type Pricing struct {
- ModelName string `json:"model_name"`
- QuotaType int `json:"quota_type"`
- ModelRatio float64 `json:"model_ratio"`
- ModelPrice float64 `json:"model_price"`
- OwnerBy string `json:"owner_by"`
- CompletionRatio float64 `json:"completion_ratio"`
- EnableGroup []string `json:"enable_groups"`
- SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
- }
- var (
- pricingMap []Pricing
- lastGetPricingTime time.Time
- updatePricingLock sync.Mutex
- )
- var (
- modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
- modelSupportEndpointsLock = sync.RWMutex{}
- )
- func GetPricing() []Pricing {
- if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
- updatePricingLock.Lock()
- defer updatePricingLock.Unlock()
- // Double check after acquiring the lock
- if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
- modelSupportEndpointsLock.Lock()
- defer modelSupportEndpointsLock.Unlock()
- updatePricing()
- }
- }
- return pricingMap
- }
- func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
- if model == "" {
- return make([]constant.EndpointType, 0)
- }
- modelSupportEndpointsLock.RLock()
- defer modelSupportEndpointsLock.RUnlock()
- if endpoints, ok := modelSupportEndpointTypes[model]; ok {
- return endpoints
- }
- return make([]constant.EndpointType, 0)
- }
- func updatePricing() {
- //modelRatios := common.GetModelRatios()
- enableAbilities, err := GetAllEnableAbilityWithChannels()
- if err != nil {
- common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
- return
- }
- modelGroupsMap := make(map[string]*types.Set[string])
- for _, ability := range enableAbilities {
- groups, ok := modelGroupsMap[ability.Model]
- if !ok {
- groups = types.NewSet[string]()
- modelGroupsMap[ability.Model] = groups
- }
- groups.Add(ability.Group)
- }
- //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
- modelSupportEndpointsStr := make(map[string][]string)
- for _, ability := range enableAbilities {
- endpoints, ok := modelSupportEndpointsStr[ability.Model]
- if !ok {
- endpoints = make([]string, 0)
- modelSupportEndpointsStr[ability.Model] = endpoints
- }
- channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
- for _, channelType := range channelTypes {
- if !common.StringsContains(endpoints, string(channelType)) {
- endpoints = append(endpoints, string(channelType))
- }
- }
- modelSupportEndpointsStr[ability.Model] = endpoints
- }
- modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
- for model, endpoints := range modelSupportEndpointsStr {
- supportedEndpoints := make([]constant.EndpointType, 0)
- for _, endpointStr := range endpoints {
- endpointType := constant.EndpointType(endpointStr)
- supportedEndpoints = append(supportedEndpoints, endpointType)
- }
- modelSupportEndpointTypes[model] = supportedEndpoints
- }
- pricingMap = make([]Pricing, 0)
- for model, groups := range modelGroupsMap {
- pricing := Pricing{
- ModelName: model,
- EnableGroup: groups.Items(),
- SupportedEndpointTypes: modelSupportEndpointTypes[model],
- }
- modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
- if findPrice {
- pricing.ModelPrice = modelPrice
- pricing.QuotaType = 1
- } else {
- modelRatio, _, _ := ratio_setting.GetModelRatio(model)
- pricing.ModelRatio = modelRatio
- pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
- pricing.QuotaType = 0
- }
- pricingMap = append(pricingMap, pricing)
- }
- lastGetPricingTime = time.Now()
- }
|