modelconfig.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. package model
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "time"
  8. "github.com/bytedance/sonic"
  9. "github.com/labring/aiproxy/core/common"
  10. "github.com/labring/aiproxy/core/relay/mode"
  11. "gorm.io/gorm"
  12. )
  13. const (
  14. // /1K tokens
  15. PriceUnit = 1000
  16. )
  17. //nolint:revive
  18. type ModelConfig struct {
  19. CreatedAt time.Time `gorm:"index;autoCreateTime" json:"created_at"`
  20. UpdatedAt time.Time `gorm:"index;autoUpdateTime" json:"updated_at"`
  21. Config map[ModelConfigKey]any `gorm:"serializer:fastjson;type:text" json:"config,omitempty"`
  22. Plugin map[string]json.RawMessage `gorm:"serializer:fastjson;type:text" json:"plugin,omitempty"`
  23. Model string `gorm:"primaryKey" json:"model"`
  24. Owner ModelOwner `gorm:"type:varchar(255);index" json:"owner"`
  25. Type mode.Mode ` json:"type"`
  26. ExcludeFromTests bool ` json:"exclude_from_tests,omitempty"`
  27. RPM int64 ` json:"rpm,omitempty"`
  28. TPM int64 ` json:"tpm,omitempty"`
  29. // map[size]map[quality]price_per_image
  30. ImageQualityPrices map[string]map[string]float64 `gorm:"serializer:fastjson;type:text" json:"image_quality_prices,omitempty"`
  31. // map[size]price_per_image
  32. ImagePrices map[string]float64 `gorm:"serializer:fastjson;type:text" json:"image_prices,omitempty"`
  33. Price Price `gorm:"embedded" json:"price,omitempty"`
  34. RetryTimes int64 ` json:"retry_times,omitempty"`
  35. Timeout int64 ` json:"timeout,omitempty"`
  36. MaxErrorRate float64 ` json:"max_error_rate,omitempty"`
  37. ForceSaveDetail bool ` json:"force_save_detail,omitempty"`
  38. }
  39. func (c *ModelConfig) BeforeSave(_ *gorm.DB) (err error) {
  40. if c.Model == "" {
  41. return errors.New("model is required")
  42. }
  43. return nil
  44. }
  45. func NewDefaultModelConfig(model string) ModelConfig {
  46. return ModelConfig{
  47. Model: model,
  48. }
  49. }
  50. func (c *ModelConfig) LoadPluginConfig(pluginName string, config any) error {
  51. if len(c.Plugin) == 0 {
  52. return nil
  53. }
  54. pluginConfig, ok := c.Plugin[pluginName]
  55. if !ok || len(pluginConfig) == 0 {
  56. return nil
  57. }
  58. return sonic.Unmarshal(pluginConfig, config)
  59. }
  60. func (c *ModelConfig) LoadFromGroupModelConfig(groupModelConfig GroupModelConfig) ModelConfig {
  61. newC := *c
  62. if groupModelConfig.OverrideLimit {
  63. newC.RPM = groupModelConfig.RPM
  64. newC.TPM = groupModelConfig.TPM
  65. }
  66. if groupModelConfig.OverridePrice {
  67. newC.ImagePrices = groupModelConfig.ImagePrices
  68. newC.Price = groupModelConfig.Price
  69. }
  70. if groupModelConfig.OverrideRetryTimes {
  71. newC.RetryTimes = groupModelConfig.RetryTimes
  72. }
  73. if groupModelConfig.OverrideForceSaveDetail {
  74. newC.ForceSaveDetail = groupModelConfig.ForceSaveDetail
  75. }
  76. return newC
  77. }
  78. func (c *ModelConfig) MarshalJSON() ([]byte, error) {
  79. type Alias ModelConfig
  80. a := &struct {
  81. *Alias
  82. CreatedAt int64 `json:"created_at,omitempty"`
  83. UpdatedAt int64 `json:"updated_at,omitempty"`
  84. }{
  85. Alias: (*Alias)(c),
  86. }
  87. if !c.CreatedAt.IsZero() {
  88. a.CreatedAt = c.CreatedAt.UnixMilli()
  89. }
  90. if !c.UpdatedAt.IsZero() {
  91. a.UpdatedAt = c.UpdatedAt.UnixMilli()
  92. }
  93. return sonic.Marshal(a)
  94. }
  95. func (c *ModelConfig) MaxContextTokens() (int, bool) {
  96. return GetModelConfigInt(c.Config, ModelConfigMaxContextTokensKey)
  97. }
  98. func (c *ModelConfig) MaxInputTokens() (int, bool) {
  99. return GetModelConfigInt(c.Config, ModelConfigMaxInputTokensKey)
  100. }
  101. func (c *ModelConfig) MaxOutputTokens() (int, bool) {
  102. return GetModelConfigInt(c.Config, ModelConfigMaxOutputTokensKey)
  103. }
  104. func (c *ModelConfig) SupportVision() (bool, bool) {
  105. return GetModelConfigBool(c.Config, ModelConfigVisionKey)
  106. }
  107. func (c *ModelConfig) SupportVoices() ([]string, bool) {
  108. return GetModelConfigStringSlice(c.Config, ModelConfigSupportVoicesKey)
  109. }
  110. func (c *ModelConfig) SupportToolChoice() (bool, bool) {
  111. return GetModelConfigBool(c.Config, ModelConfigToolChoiceKey)
  112. }
  113. func (c *ModelConfig) SupportFormats() ([]string, bool) {
  114. return GetModelConfigStringSlice(c.Config, ModelConfigSupportFormatsKey)
  115. }
  116. func GetModelConfigs(
  117. page, perPage int,
  118. model string,
  119. ) (configs []*ModelConfig, total int64, err error) {
  120. tx := DB.Model(&ModelConfig{})
  121. if model != "" {
  122. tx = tx.Where("model = ?", model)
  123. }
  124. err = tx.Count(&total).Error
  125. if err != nil {
  126. return nil, 0, err
  127. }
  128. if total <= 0 {
  129. return nil, 0, nil
  130. }
  131. limit, offset := toLimitOffset(page, perPage)
  132. err = tx.
  133. Order("created_at desc").
  134. Omit("created_at", "updated_at").
  135. Limit(limit).
  136. Offset(offset).
  137. Find(&configs).
  138. Error
  139. return configs, total, err
  140. }
  141. func GetAllModelConfigs() (configs []ModelConfig, err error) {
  142. tx := DB.Model(&ModelConfig{})
  143. err = tx.Order("created_at desc").
  144. Omit("created_at", "updated_at").
  145. Find(&configs).
  146. Error
  147. return configs, err
  148. }
  149. func GetModelConfigsByModels(models []string) (configs []ModelConfig, err error) {
  150. tx := DB.Model(&ModelConfig{}).Where("model IN (?)", models)
  151. err = tx.Order("created_at desc").
  152. Omit("created_at", "updated_at").
  153. Find(&configs).
  154. Error
  155. return configs, err
  156. }
  157. func GetModelConfig(model string) (ModelConfig, error) {
  158. config := ModelConfig{}
  159. err := DB.Model(&ModelConfig{}).
  160. Where("model = ?", model).
  161. Omit("created_at", "updated_at").
  162. First(config).
  163. Error
  164. return config, HandleNotFound(err, ErrModelConfigNotFound)
  165. }
  166. func SearchModelConfigs(
  167. keyword string,
  168. page, perPage int,
  169. model string,
  170. owner ModelOwner,
  171. ) (configs []ModelConfig, total int64, err error) {
  172. tx := DB.Model(&ModelConfig{}).Where("model LIKE ?", "%"+keyword+"%")
  173. if model != "" {
  174. tx = tx.Where("model = ?", model)
  175. }
  176. if owner != "" {
  177. tx = tx.Where("owner = ?", owner)
  178. }
  179. if keyword != "" {
  180. var conditions []string
  181. var values []any
  182. if model == "" {
  183. if common.UsingPostgreSQL {
  184. conditions = append(conditions, "model ILIKE ?")
  185. } else {
  186. conditions = append(conditions, "model LIKE ?")
  187. }
  188. values = append(values, "%"+keyword+"%")
  189. }
  190. if owner != "" {
  191. if common.UsingPostgreSQL {
  192. conditions = append(conditions, "owner ILIKE ?")
  193. } else {
  194. conditions = append(conditions, "owner LIKE ?")
  195. }
  196. values = append(values, "%"+string(owner)+"%")
  197. }
  198. if len(conditions) > 0 {
  199. tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
  200. }
  201. }
  202. err = tx.Count(&total).Error
  203. if err != nil {
  204. return nil, 0, err
  205. }
  206. if total <= 0 {
  207. return nil, 0, nil
  208. }
  209. limit, offset := toLimitOffset(page, perPage)
  210. err = tx.Order("created_at desc").
  211. Omit("created_at", "updated_at").
  212. Limit(limit).
  213. Offset(offset).
  214. Find(&configs).
  215. Error
  216. return configs, total, err
  217. }
  218. func SaveModelConfig(config ModelConfig) (err error) {
  219. defer func() {
  220. if err == nil {
  221. _ = InitModelConfigAndChannelCache()
  222. }
  223. }()
  224. return DB.Save(&config).Error
  225. }
  226. func SaveModelConfigs(configs []ModelConfig) (err error) {
  227. defer func() {
  228. if err == nil {
  229. _ = InitModelConfigAndChannelCache()
  230. }
  231. }()
  232. return DB.Transaction(func(tx *gorm.DB) error {
  233. for _, config := range configs {
  234. if err := tx.Save(&config).Error; err != nil {
  235. return err
  236. }
  237. }
  238. return nil
  239. })
  240. }
  241. const ErrModelConfigNotFound = "model config"
  242. func DeleteModelConfig(model string) error {
  243. result := DB.Where("model = ?", model).Delete(&ModelConfig{})
  244. return HandleUpdateResult(result, ErrModelConfigNotFound)
  245. }
  246. func DeleteModelConfigsByModels(models []string) error {
  247. return DB.Transaction(func(tx *gorm.DB) error {
  248. return tx.
  249. Where("model IN (?)", models).
  250. Delete(&ModelConfig{}).
  251. Error
  252. })
  253. }