pricing.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. package model
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "strings"
  6. "sync"
  7. "time"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/constant"
  10. "github.com/QuantumNous/new-api/setting/ratio_setting"
  11. "github.com/QuantumNous/new-api/types"
  12. )
// Pricing is one model's entry in the pricing list returned by GetPricing.
// Entries are produced by updatePricing.
type Pricing struct {
	// ModelName is the model name as listed in the enabled abilities.
	ModelName string `json:"model_name"`
	// Description, Icon, Tags and VendorID are copied from the model's
	// metadata record when one matched.
	Description string `json:"description,omitempty"`
	Icon        string `json:"icon,omitempty"`
	Tags        string `json:"tags,omitempty"`
	VendorID    int    `json:"vendor_id,omitempty"`
	// QuotaType is 1 when a fixed model price is configured (ModelPrice is
	// set) and 0 when billing is ratio-based (ModelRatio/CompletionRatio set).
	QuotaType       int     `json:"quota_type"`
	ModelRatio      float64 `json:"model_ratio"`
	ModelPrice      float64 `json:"model_price"`
	// OwnerBy is not populated by updatePricing.
	OwnerBy         string  `json:"owner_by"`
	CompletionRatio float64 `json:"completion_ratio"`
	// The optional ratios below are nil (and omitted from JSON) unless
	// ratio_setting has a value configured for the model.
	CacheRatio           *float64 `json:"cache_ratio,omitempty"`
	CreateCacheRatio     *float64 `json:"create_cache_ratio,omitempty"`
	ImageRatio           *float64 `json:"image_ratio,omitempty"`
	AudioRatio           *float64 `json:"audio_ratio,omitempty"`
	AudioCompletionRatio *float64 `json:"audio_completion_ratio,omitempty"`
	// EnableGroup lists the groups that have an enabled ability for the model.
	EnableGroup []string `json:"enable_groups"`
	// SupportedEndpointTypes is ordered; the first entry is the preferred endpoint.
	SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
	// PricingVersion is set only on the first entry of the list.
	PricingVersion string `json:"pricing_version,omitempty"`
}
// PricingVendor is the frontend-facing view of a vendor record, as listed by
// GetVendors. Fields are copied from the Vendor rows loaded by updatePricing.
type PricingVendor struct {
	ID          int    `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Icon        string `json:"icon,omitempty"`
}
  39. var (
  40. pricingMap []Pricing
  41. vendorsList []PricingVendor
  42. supportedEndpointMap map[string]common.EndpointInfo
  43. lastGetPricingTime time.Time
  44. updatePricingLock sync.Mutex
  45. // 缓存映射:模型名 -> 启用分组 / 计费类型
  46. modelEnableGroups = make(map[string][]string)
  47. modelQuotaTypeMap = make(map[string]int)
  48. modelEnableGroupsLock = sync.RWMutex{}
  49. )
  50. var (
  51. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  52. modelSupportEndpointsLock = sync.RWMutex{}
  53. )
  54. func GetPricing() []Pricing {
  55. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  56. updatePricingLock.Lock()
  57. defer updatePricingLock.Unlock()
  58. // Double check after acquiring the lock
  59. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  60. modelSupportEndpointsLock.Lock()
  61. defer modelSupportEndpointsLock.Unlock()
  62. updatePricing()
  63. }
  64. }
  65. return pricingMap
  66. }
  67. // GetVendors 返回当前定价接口使用到的供应商信息
  68. func GetVendors() []PricingVendor {
  69. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  70. // 保证先刷新一次
  71. GetPricing()
  72. }
  73. return vendorsList
  74. }
  75. func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
  76. if model == "" {
  77. return make([]constant.EndpointType, 0)
  78. }
  79. modelSupportEndpointsLock.RLock()
  80. defer modelSupportEndpointsLock.RUnlock()
  81. if endpoints, ok := modelSupportEndpointTypes[model]; ok {
  82. return endpoints
  83. }
  84. return make([]constant.EndpointType, 0)
  85. }
  86. func updatePricing() {
  87. //modelRatios := common.GetModelRatios()
  88. enableAbilities, err := GetAllEnableAbilityWithChannels()
  89. if err != nil {
  90. common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
  91. return
  92. }
  93. // 预加载模型元数据与供应商一次,避免循环查询
  94. var allMeta []Model
  95. _ = DB.Find(&allMeta).Error
  96. metaMap := make(map[string]*Model)
  97. prefixList := make([]*Model, 0)
  98. suffixList := make([]*Model, 0)
  99. containsList := make([]*Model, 0)
  100. for i := range allMeta {
  101. m := &allMeta[i]
  102. if m.NameRule == NameRuleExact {
  103. metaMap[m.ModelName] = m
  104. } else {
  105. switch m.NameRule {
  106. case NameRulePrefix:
  107. prefixList = append(prefixList, m)
  108. case NameRuleSuffix:
  109. suffixList = append(suffixList, m)
  110. case NameRuleContains:
  111. containsList = append(containsList, m)
  112. }
  113. }
  114. }
  115. // 将非精确规则模型匹配到 metaMap
  116. for _, m := range prefixList {
  117. for _, pricingModel := range enableAbilities {
  118. if strings.HasPrefix(pricingModel.Model, m.ModelName) {
  119. if _, exists := metaMap[pricingModel.Model]; !exists {
  120. metaMap[pricingModel.Model] = m
  121. }
  122. }
  123. }
  124. }
  125. for _, m := range suffixList {
  126. for _, pricingModel := range enableAbilities {
  127. if strings.HasSuffix(pricingModel.Model, m.ModelName) {
  128. if _, exists := metaMap[pricingModel.Model]; !exists {
  129. metaMap[pricingModel.Model] = m
  130. }
  131. }
  132. }
  133. }
  134. for _, m := range containsList {
  135. for _, pricingModel := range enableAbilities {
  136. if strings.Contains(pricingModel.Model, m.ModelName) {
  137. if _, exists := metaMap[pricingModel.Model]; !exists {
  138. metaMap[pricingModel.Model] = m
  139. }
  140. }
  141. }
  142. }
  143. // 预加载供应商
  144. var vendors []Vendor
  145. _ = DB.Find(&vendors).Error
  146. vendorMap := make(map[int]*Vendor)
  147. for i := range vendors {
  148. vendorMap[vendors[i].Id] = &vendors[i]
  149. }
  150. // 初始化默认供应商映射
  151. initDefaultVendorMapping(metaMap, vendorMap, enableAbilities)
  152. // 构建对前端友好的供应商列表
  153. vendorsList = make([]PricingVendor, 0, len(vendorMap))
  154. for _, v := range vendorMap {
  155. vendorsList = append(vendorsList, PricingVendor{
  156. ID: v.Id,
  157. Name: v.Name,
  158. Description: v.Description,
  159. Icon: v.Icon,
  160. })
  161. }
  162. modelGroupsMap := make(map[string]*types.Set[string])
  163. for _, ability := range enableAbilities {
  164. groups, ok := modelGroupsMap[ability.Model]
  165. if !ok {
  166. groups = types.NewSet[string]()
  167. modelGroupsMap[ability.Model] = groups
  168. }
  169. groups.Add(ability.Group)
  170. }
  171. //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
  172. modelSupportEndpointsStr := make(map[string][]string)
  173. // 先根据已有能力填充原生端点
  174. for _, ability := range enableAbilities {
  175. endpoints := modelSupportEndpointsStr[ability.Model]
  176. channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
  177. for _, channelType := range channelTypes {
  178. if !common.StringsContains(endpoints, string(channelType)) {
  179. endpoints = append(endpoints, string(channelType))
  180. }
  181. }
  182. modelSupportEndpointsStr[ability.Model] = endpoints
  183. }
  184. // 再补充模型自定义端点:若配置有效则替换默认端点,不做合并
  185. for modelName, meta := range metaMap {
  186. if strings.TrimSpace(meta.Endpoints) == "" {
  187. continue
  188. }
  189. var raw map[string]interface{}
  190. if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
  191. endpoints := make([]string, 0, len(raw))
  192. for k, v := range raw {
  193. switch v.(type) {
  194. case string, map[string]interface{}:
  195. if !common.StringsContains(endpoints, k) {
  196. endpoints = append(endpoints, k)
  197. }
  198. }
  199. }
  200. if len(endpoints) > 0 {
  201. modelSupportEndpointsStr[modelName] = endpoints
  202. }
  203. }
  204. }
  205. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  206. for model, endpoints := range modelSupportEndpointsStr {
  207. supportedEndpoints := make([]constant.EndpointType, 0)
  208. for _, endpointStr := range endpoints {
  209. endpointType := constant.EndpointType(endpointStr)
  210. supportedEndpoints = append(supportedEndpoints, endpointType)
  211. }
  212. modelSupportEndpointTypes[model] = supportedEndpoints
  213. }
  214. // 构建全局 supportedEndpointMap(默认 + 自定义覆盖)
  215. supportedEndpointMap = make(map[string]common.EndpointInfo)
  216. // 1. 默认端点
  217. for _, endpoints := range modelSupportEndpointTypes {
  218. for _, et := range endpoints {
  219. if info, ok := common.GetDefaultEndpointInfo(et); ok {
  220. if _, exists := supportedEndpointMap[string(et)]; !exists {
  221. supportedEndpointMap[string(et)] = info
  222. }
  223. }
  224. }
  225. }
  226. // 2. 自定义端点(models 表)覆盖默认
  227. for _, meta := range metaMap {
  228. if strings.TrimSpace(meta.Endpoints) == "" {
  229. continue
  230. }
  231. var raw map[string]interface{}
  232. if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil {
  233. for k, v := range raw {
  234. switch val := v.(type) {
  235. case string:
  236. supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"}
  237. case map[string]interface{}:
  238. ep := common.EndpointInfo{Method: "POST"}
  239. if p, ok := val["path"].(string); ok {
  240. ep.Path = p
  241. }
  242. if m, ok := val["method"].(string); ok {
  243. ep.Method = strings.ToUpper(m)
  244. }
  245. supportedEndpointMap[k] = ep
  246. default:
  247. // ignore unsupported types
  248. }
  249. }
  250. }
  251. }
  252. pricingMap = make([]Pricing, 0)
  253. for model, groups := range modelGroupsMap {
  254. pricing := Pricing{
  255. ModelName: model,
  256. EnableGroup: groups.Items(),
  257. SupportedEndpointTypes: modelSupportEndpointTypes[model],
  258. }
  259. // 补充模型元数据(描述、标签、供应商、状态)
  260. if meta, ok := metaMap[model]; ok {
  261. // 若模型被禁用(status!=1),则直接跳过,不返回给前端
  262. if meta.Status != 1 {
  263. continue
  264. }
  265. pricing.Description = meta.Description
  266. pricing.Icon = meta.Icon
  267. pricing.Tags = meta.Tags
  268. pricing.VendorID = meta.VendorID
  269. }
  270. modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
  271. if findPrice {
  272. pricing.ModelPrice = modelPrice
  273. pricing.QuotaType = 1
  274. } else {
  275. modelRatio, _, _ := ratio_setting.GetModelRatio(model)
  276. pricing.ModelRatio = modelRatio
  277. pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
  278. pricing.QuotaType = 0
  279. }
  280. if cacheRatio, ok := ratio_setting.GetCacheRatio(model); ok {
  281. pricing.CacheRatio = &cacheRatio
  282. }
  283. if createCacheRatio, ok := ratio_setting.GetCreateCacheRatio(model); ok {
  284. pricing.CreateCacheRatio = &createCacheRatio
  285. }
  286. if imageRatio, ok := ratio_setting.GetImageRatio(model); ok {
  287. pricing.ImageRatio = &imageRatio
  288. }
  289. if ratio_setting.ContainsAudioRatio(model) {
  290. audioRatio := ratio_setting.GetAudioRatio(model)
  291. pricing.AudioRatio = &audioRatio
  292. }
  293. if ratio_setting.ContainsAudioCompletionRatio(model) {
  294. audioCompletionRatio := ratio_setting.GetAudioCompletionRatio(model)
  295. pricing.AudioCompletionRatio = &audioCompletionRatio
  296. }
  297. pricingMap = append(pricingMap, pricing)
  298. }
  299. // 防止大更新后数据不通用
  300. if len(pricingMap) > 0 {
  301. pricingMap[0].PricingVersion = "5a90f2b86c08bd983a9a2e6d66c255f4eaef9c4bc934386d2b6ae84ef0ff1f1f"
  302. }
  303. // 刷新缓存映射,供高并发快速查询
  304. modelEnableGroupsLock.Lock()
  305. modelEnableGroups = make(map[string][]string)
  306. modelQuotaTypeMap = make(map[string]int)
  307. for _, p := range pricingMap {
  308. modelEnableGroups[p.ModelName] = p.EnableGroup
  309. modelQuotaTypeMap[p.ModelName] = p.QuotaType
  310. }
  311. modelEnableGroupsLock.Unlock()
  312. lastGetPricingTime = time.Now()
  313. }
  314. // GetSupportedEndpointMap 返回全局端点到路径的映射
  315. func GetSupportedEndpointMap() map[string]common.EndpointInfo {
  316. return supportedEndpointMap
  317. }