pricing.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. package model
  2. import (
  3. "fmt"
  4. "one-api/common"
  5. "one-api/constant"
  6. "one-api/setting/ratio_setting"
  7. "one-api/types"
  8. "sync"
  9. "time"
  10. )
  11. type Pricing struct {
  12. ModelName string `json:"model_name"`
  13. QuotaType int `json:"quota_type"`
  14. ModelRatio float64 `json:"model_ratio"`
  15. ModelPrice float64 `json:"model_price"`
  16. OwnerBy string `json:"owner_by"`
  17. CompletionRatio float64 `json:"completion_ratio"`
  18. EnableGroup []string `json:"enable_groups"`
  19. SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
  20. }
  21. var (
  22. pricingMap []Pricing
  23. lastGetPricingTime time.Time
  24. updatePricingLock sync.Mutex
  25. )
  26. var (
  27. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  28. modelSupportEndpointsLock = sync.RWMutex{}
  29. )
  30. func GetPricing() []Pricing {
  31. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  32. updatePricingLock.Lock()
  33. defer updatePricingLock.Unlock()
  34. // Double check after acquiring the lock
  35. if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 {
  36. modelSupportEndpointsLock.Lock()
  37. defer modelSupportEndpointsLock.Unlock()
  38. updatePricing()
  39. }
  40. }
  41. return pricingMap
  42. }
  43. func GetModelSupportEndpointTypes(model string) []constant.EndpointType {
  44. if model == "" {
  45. return make([]constant.EndpointType, 0)
  46. }
  47. modelSupportEndpointsLock.RLock()
  48. defer modelSupportEndpointsLock.RUnlock()
  49. if endpoints, ok := modelSupportEndpointTypes[model]; ok {
  50. return endpoints
  51. }
  52. return make([]constant.EndpointType, 0)
  53. }
  54. func updatePricing() {
  55. //modelRatios := common.GetModelRatios()
  56. enableAbilities, err := GetAllEnableAbilityWithChannels()
  57. if err != nil {
  58. common.SysError(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err))
  59. return
  60. }
  61. modelGroupsMap := make(map[string]*types.Set[string])
  62. for _, ability := range enableAbilities {
  63. groups, ok := modelGroupsMap[ability.Model]
  64. if !ok {
  65. groups = types.NewSet[string]()
  66. modelGroupsMap[ability.Model] = groups
  67. }
  68. groups.Add(ability.Group)
  69. }
  70. //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点
  71. modelSupportEndpointsStr := make(map[string][]string)
  72. for _, ability := range enableAbilities {
  73. endpoints, ok := modelSupportEndpointsStr[ability.Model]
  74. if !ok {
  75. endpoints = make([]string, 0)
  76. modelSupportEndpointsStr[ability.Model] = endpoints
  77. }
  78. channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model)
  79. for _, channelType := range channelTypes {
  80. if !common.StringsContains(endpoints, string(channelType)) {
  81. endpoints = append(endpoints, string(channelType))
  82. }
  83. }
  84. modelSupportEndpointsStr[ability.Model] = endpoints
  85. }
  86. modelSupportEndpointTypes = make(map[string][]constant.EndpointType)
  87. for model, endpoints := range modelSupportEndpointsStr {
  88. supportedEndpoints := make([]constant.EndpointType, 0)
  89. for _, endpointStr := range endpoints {
  90. endpointType := constant.EndpointType(endpointStr)
  91. supportedEndpoints = append(supportedEndpoints, endpointType)
  92. }
  93. modelSupportEndpointTypes[model] = supportedEndpoints
  94. }
  95. pricingMap = make([]Pricing, 0)
  96. for model, groups := range modelGroupsMap {
  97. pricing := Pricing{
  98. ModelName: model,
  99. EnableGroup: groups.Items(),
  100. SupportedEndpointTypes: modelSupportEndpointTypes[model],
  101. }
  102. modelPrice, findPrice := ratio_setting.GetModelPrice(model, false)
  103. if findPrice {
  104. pricing.ModelPrice = modelPrice
  105. pricing.QuotaType = 1
  106. } else {
  107. modelRatio, _, _ := ratio_setting.GetModelRatio(model)
  108. pricing.ModelRatio = modelRatio
  109. pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model)
  110. pricing.QuotaType = 0
  111. }
  112. pricingMap = append(pricingMap, pricing)
  113. }
  114. lastGetPricingTime = time.Now()
  115. }