| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247 |
- package model
- import (
- "fmt"
- "strings"
- "time"
- "github.com/bytedance/sonic"
- "github.com/labring/aiproxy/core/common"
- "github.com/labring/aiproxy/core/relay/mode"
- "gorm.io/gorm"
- )
// PriceUnit is the token quantity a price is quoted for: prices are
// expressed per 1K tokens.
const PriceUnit = 1000
- //nolint:revive
- type ModelConfig struct {
- CreatedAt time.Time `gorm:"index;autoCreateTime" json:"created_at"`
- UpdatedAt time.Time `gorm:"index;autoUpdateTime" json:"updated_at"`
- Config map[ModelConfigKey]any `gorm:"serializer:fastjson;type:text" json:"config,omitempty"`
- Model string `gorm:"primaryKey" json:"model"`
- Owner ModelOwner `gorm:"type:varchar(255);index" json:"owner"`
- Type mode.Mode `json:"type"`
- ExcludeFromTests bool `json:"exclude_from_tests,omitempty"`
- RPM int64 `json:"rpm,omitempty"`
- TPM int64 `json:"tpm,omitempty"`
- // map[size]map[quality]price_per_image
- ImageQualityPrices map[string]map[string]float64 `gorm:"serializer:fastjson;type:text" json:"image_quality_prices,omitempty"`
- // map[size]price_per_image
- ImagePrices map[string]float64 `gorm:"serializer:fastjson;type:text" json:"image_prices,omitempty"`
- Price Price `gorm:"embedded" json:"price,omitempty"`
- RetryTimes int64 `json:"retry_times"`
- }
- func NewDefaultModelConfig(model string) *ModelConfig {
- return &ModelConfig{
- Model: model,
- }
- }
- func (c *ModelConfig) LoadFromGroupModelConfig(groupModelConfig GroupModelConfig) ModelConfig {
- newC := *c
- if groupModelConfig.OverrideLimit {
- newC.RPM = groupModelConfig.RPM
- newC.TPM = groupModelConfig.TPM
- }
- if groupModelConfig.OverridePrice {
- newC.ImagePrices = groupModelConfig.ImagePrices
- newC.Price = groupModelConfig.Price
- }
- if groupModelConfig.OverrideRetryTimes {
- newC.RetryTimes = groupModelConfig.RetryTimes
- }
- return newC
- }
- func (c *ModelConfig) MarshalJSON() ([]byte, error) {
- type Alias ModelConfig
- a := &struct {
- *Alias
- CreatedAt int64 `json:"created_at,omitempty"`
- UpdatedAt int64 `json:"updated_at,omitempty"`
- }{
- Alias: (*Alias)(c),
- }
- if !c.CreatedAt.IsZero() {
- a.CreatedAt = c.CreatedAt.UnixMilli()
- }
- if !c.UpdatedAt.IsZero() {
- a.UpdatedAt = c.UpdatedAt.UnixMilli()
- }
- return sonic.Marshal(a)
- }
- func (c *ModelConfig) MaxContextTokens() (int, bool) {
- return GetModelConfigInt(c.Config, ModelConfigMaxContextTokensKey)
- }
- func (c *ModelConfig) MaxInputTokens() (int, bool) {
- return GetModelConfigInt(c.Config, ModelConfigMaxInputTokensKey)
- }
- func (c *ModelConfig) MaxOutputTokens() (int, bool) {
- return GetModelConfigInt(c.Config, ModelConfigMaxOutputTokensKey)
- }
- func (c *ModelConfig) SupportVision() (bool, bool) {
- return GetModelConfigBool(c.Config, ModelConfigVisionKey)
- }
- func (c *ModelConfig) SupportVoices() ([]string, bool) {
- return GetModelConfigStringSlice(c.Config, ModelConfigSupportVoicesKey)
- }
- func (c *ModelConfig) SupportToolChoice() (bool, bool) {
- return GetModelConfigBool(c.Config, ModelConfigToolChoiceKey)
- }
- func (c *ModelConfig) SupportFormats() ([]string, bool) {
- return GetModelConfigStringSlice(c.Config, ModelConfigSupportFormatsKey)
- }
- func GetModelConfigs(page int, perPage int, model string) (configs []*ModelConfig, total int64, err error) {
- tx := DB.Model(&ModelConfig{})
- if model != "" {
- tx = tx.Where("model = ?", model)
- }
- err = tx.Count(&total).Error
- if err != nil {
- return nil, 0, err
- }
- if total <= 0 {
- return nil, 0, nil
- }
- limit, offset := toLimitOffset(page, perPage)
- err = tx.
- Order("created_at desc").
- Omit("created_at", "updated_at").
- Limit(limit).
- Offset(offset).
- Find(&configs).
- Error
- return configs, total, err
- }
- func GetAllModelConfigs() (configs []*ModelConfig, err error) {
- tx := DB.Model(&ModelConfig{})
- err = tx.Order("created_at desc").
- Omit("created_at", "updated_at").
- Find(&configs).
- Error
- return configs, err
- }
- func GetModelConfigsByModels(models []string) (configs []*ModelConfig, err error) {
- tx := DB.Model(&ModelConfig{}).Where("model IN (?)", models)
- err = tx.Order("created_at desc").
- Omit("created_at", "updated_at").
- Find(&configs).
- Error
- return configs, err
- }
- func GetModelConfig(model string) (*ModelConfig, error) {
- config := &ModelConfig{}
- err := DB.Model(&ModelConfig{}).
- Where("model = ?", model).
- Omit("created_at", "updated_at").
- First(config).
- Error
- return config, HandleNotFound(err, ErrModelConfigNotFound)
- }
- func SearchModelConfigs(keyword string, page int, perPage int, model string, owner ModelOwner) (configs []*ModelConfig, total int64, err error) {
- tx := DB.Model(&ModelConfig{}).Where("model LIKE ?", "%"+keyword+"%")
- if model != "" {
- tx = tx.Where("model = ?", model)
- }
- if owner != "" {
- tx = tx.Where("owner = ?", owner)
- }
- if keyword != "" {
- var conditions []string
- var values []interface{}
- if model == "" {
- if common.UsingPostgreSQL {
- conditions = append(conditions, "model ILIKE ?")
- } else {
- conditions = append(conditions, "model LIKE ?")
- }
- values = append(values, "%"+keyword+"%")
- }
- if owner != "" {
- if common.UsingPostgreSQL {
- conditions = append(conditions, "owner ILIKE ?")
- } else {
- conditions = append(conditions, "owner LIKE ?")
- }
- values = append(values, "%"+string(owner)+"%")
- }
- if len(conditions) > 0 {
- tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
- }
- }
- err = tx.Count(&total).Error
- if err != nil {
- return nil, 0, err
- }
- if total <= 0 {
- return nil, 0, nil
- }
- limit, offset := toLimitOffset(page, perPage)
- err = tx.Order("created_at desc").
- Omit("created_at", "updated_at").
- Limit(limit).
- Offset(offset).
- Find(&configs).
- Error
- return configs, total, err
- }
- func SaveModelConfig(config *ModelConfig) (err error) {
- defer func() {
- if err == nil {
- _ = InitModelConfigAndChannelCache()
- }
- }()
- return DB.Save(config).Error
- }
- func SaveModelConfigs(configs []*ModelConfig) (err error) {
- defer func() {
- if err == nil {
- _ = InitModelConfigAndChannelCache()
- }
- }()
- return DB.Transaction(func(tx *gorm.DB) error {
- for _, config := range configs {
- if err := tx.Save(config).Error; err != nil {
- return err
- }
- }
- return nil
- })
- }
// ErrModelConfigNotFound is the label this file passes to HandleNotFound and
// HandleUpdateResult for model config lookups and deletes (presumably used
// to build the not-found error; confirm against those helpers).
const ErrModelConfigNotFound = "model config"
- func DeleteModelConfig(model string) error {
- result := DB.Where("model = ?", model).Delete(&ModelConfig{})
- return HandleUpdateResult(result, ErrModelConfigNotFound)
- }
- func DeleteModelConfigsByModels(models []string) error {
- return DB.Transaction(func(tx *gorm.DB) error {
- return tx.
- Where("model IN (?)", models).
- Delete(&ModelConfig{}).
- Error
- })
- }
|