package model

import (
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/bytedance/sonic"
	"github.com/go-viper/mapstructure/v2"
	"github.com/labring/aiproxy/core/common"
	"github.com/labring/aiproxy/core/relay/mode"
	"gorm.io/gorm"
)

const (
	// PriceUnit is the pricing unit: prices are expressed per 1K tokens.
	PriceUnit = 1000
)

// TimeoutConfig holds per-model request timeouts.
// Both values are interpreted as whole seconds (see timeoutSecond).
type TimeoutConfig struct {
	RequestTimeout       int64 `json:"request_timeout,omitempty" yaml:"request_timeout,omitempty"`
	StreamRequestTimeout int64 `json:"stream_request_timeout,omitempty" yaml:"stream_request_timeout,omitempty"`
}

// ModelConfig is the persisted configuration of a single model, keyed by
// Model (the primary key). Map-typed fields are stored as serialized text
// columns; Price and TimeoutConfig are embedded into the same table.
type ModelConfig struct {
	CreatedAt        time.Time                 `gorm:"index;autoCreateTime" json:"created_at" yaml:"-"`
	UpdatedAt        time.Time                 `gorm:"index;autoUpdateTime" json:"updated_at" yaml:"-"`
	Config           map[ModelConfigKey]any    `gorm:"serializer:fastjson;type:text" json:"config,omitempty" yaml:"config,omitempty"`
	Plugin           map[string]map[string]any `gorm:"serializer:fastjson;type:text" json:"plugin,omitempty" yaml:"plugin,omitempty"`
	Model            string                    `gorm:"size:64;primaryKey" json:"model" yaml:"model,omitempty"`
	Owner            ModelOwner                `gorm:"type:varchar(32);index" json:"owner" yaml:"owner,omitempty"`
	Type             mode.Mode                 `json:"type" yaml:"type,omitempty"`
	ExcludeFromTests bool                      `json:"exclude_from_tests,omitempty" yaml:"exclude_from_tests,omitempty"`
	RPM              int64                     `json:"rpm,omitempty" yaml:"rpm,omitempty"`
	TPM              int64                     `json:"tpm,omitempty" yaml:"tpm,omitempty"`
	// map[size]map[quality]price_per_image
	ImageQualityPrices map[string]map[string]float64 `gorm:"serializer:fastjson;type:text" json:"image_quality_prices,omitempty" yaml:"image_quality_prices,omitempty"`
	// map[size]price_per_image
	ImagePrices     map[string]float64 `gorm:"serializer:fastjson;type:text" json:"image_prices,omitempty" yaml:"image_prices,omitempty"`
	Price           Price              `gorm:"embedded" json:"price,omitempty" yaml:"price,omitempty"`
	RetryTimes      int64              `json:"retry_times,omitempty" yaml:"retry_times,omitempty"`
	TimeoutConfig   TimeoutConfig      `gorm:"embedded" json:"timeout_config,omitempty" yaml:"timeout_config,omitempty"`
	WarnErrorRate   float64            `json:"warn_error_rate,omitempty" yaml:"warn_error_rate,omitempty"`
	MaxErrorRate    float64            `json:"max_error_rate,omitempty" yaml:"max_error_rate,omitempty"`
	ForceSaveDetail bool               `json:"force_save_detail,omitempty" yaml:"force_save_detail,omitempty"`
}

// BeforeSave is the GORM hook run before the config is written. It rejects
// configs with an empty Model and configs whose conditional prices fail
// Price.ValidateConditionalPrices.
func (c *ModelConfig) BeforeSave(_ *gorm.DB) (err error) {
	if c.Model == "" {
		return errors.New("model is required")
	}

	if err := c.Price.ValidateConditionalPrices(); err != nil {
		return err
	}

	return nil
}

// NewDefaultModelConfig returns a ModelConfig with only Model set; every
// other field keeps its zero value.
func NewDefaultModelConfig(model string) ModelConfig {
	return ModelConfig{
		Model: model,
	}
}

// RequestTimeout returns TimeoutConfig.RequestTimeout as a time.Duration
// (seconds), or 0 when it is unset.
func (c *ModelConfig) RequestTimeout() time.Duration {
	return timeoutSecond(c.TimeoutConfig.RequestTimeout)
}

// StreamRequestTimeout returns TimeoutConfig.StreamRequestTimeout as a
// time.Duration (seconds), or 0 when it is unset.
func (c *ModelConfig) StreamRequestTimeout() time.Duration {
	return timeoutSecond(c.TimeoutConfig.StreamRequestTimeout)
}

// timeoutSecond converts a number of seconds into a time.Duration.
func timeoutSecond(second int64) time.Duration {
	if second == 0 {
		return 0
	}

	return time.Duration(second) * time.Second
}

// LoadPluginConfig decodes the settings stored under pluginName in c.Plugin
// into config using mapstructure.Decode. When the model has no plugin
// settings, or none (or an empty map) for pluginName, it returns nil and
// leaves config untouched.
//
// config is handed to mapstructure.Decode as the output value, so it is
// expected to be a pointer to the target struct or map.
func (c *ModelConfig) LoadPluginConfig(pluginName string, config any) error {
	if len(c.Plugin) == 0 {
		return nil
	}

	pluginConfig, ok := c.Plugin[pluginName]
	if !ok || len(pluginConfig) == 0 {
		return nil
	}

	return mapstructure.Decode(pluginConfig, config)
}

// LoadFromGroupModelConfig returns a copy of c with the group-level overrides
// from groupModelConfig applied. Each group of fields is replaced only when
// its matching Override* flag is set. The receiver is not modified; the copy
// is shallow, so map fields that are not overridden are shared with c.
func (c *ModelConfig) LoadFromGroupModelConfig(groupModelConfig GroupModelConfig) ModelConfig {
	newC := *c

	if groupModelConfig.OverrideLimit {
		newC.RPM = groupModelConfig.RPM
		newC.TPM = groupModelConfig.TPM
	}

	if groupModelConfig.OverridePrice {
		newC.ImagePrices = groupModelConfig.ImagePrices
		newC.Price = groupModelConfig.Price
	}

	if groupModelConfig.OverrideRetryTimes {
		newC.RetryTimes = groupModelConfig.RetryTimes
	}

	if groupModelConfig.OverrideForceSaveDetail {
		newC.ForceSaveDetail = groupModelConfig.ForceSaveDetail
	}

	return newC
}

// MarshalJSON encodes the config with CreatedAt and UpdatedAt rendered as
// Unix milliseconds instead of time.Time values. Zero timestamps are left at
// 0 and dropped by omitempty.
func (c *ModelConfig) MarshalJSON() ([]byte, error) {
	// Alias has ModelConfig's fields but none of its methods, so marshalling
	// it does not call back into this MarshalJSON.
	type Alias ModelConfig

	// The outer CreatedAt/UpdatedAt fields are declared at a shallower depth
	// than the embedded ones and use the same JSON names, so they are the
	// ones that end up in the output.
	a := &struct {
		*Alias
		CreatedAt int64 `json:"created_at,omitempty"`
		UpdatedAt int64 `json:"updated_at,omitempty"`
	}{
		Alias: (*Alias)(c),
	}

	if !c.CreatedAt.IsZero() {
		a.CreatedAt = c.CreatedAt.UnixMilli()
	}

	if !c.UpdatedAt.IsZero() {
		a.UpdatedAt = c.UpdatedAt.UnixMilli()
	}

	return sonic.Marshal(a)
}

// MaxContextTokens looks up ModelConfigMaxContextTokensKey in c.Config via
// GetModelConfigInt. The second result is presumably a "found" flag — confirm
// against GetModelConfigInt.
func (c *ModelConfig) MaxContextTokens() (int, bool) {
	return GetModelConfigInt(c.Config, ModelConfigMaxContextTokensKey)
}

// MaxInputTokens looks up ModelConfigMaxInputTokensKey in c.Config via
// GetModelConfigInt.
func (c *ModelConfig) MaxInputTokens() (int, bool) {
	return GetModelConfigInt(c.Config, ModelConfigMaxInputTokensKey)
}

// MaxOutputTokens looks up ModelConfigMaxOutputTokensKey in c.Config via
// GetModelConfigInt.
func (c *ModelConfig) MaxOutputTokens() (int, bool) {
	return GetModelConfigInt(c.Config, ModelConfigMaxOutputTokensKey)
}

// SupportVision looks up ModelConfigVisionKey in c.Config via
// GetModelConfigBool.
func (c *ModelConfig) SupportVision() (bool, bool) {
	return GetModelConfigBool(c.Config, ModelConfigVisionKey)
}

// SupportVoices looks up ModelConfigSupportVoicesKey in c.Config via
// GetModelConfigStringSlice.
func (c *ModelConfig) SupportVoices() ([]string, bool) {
	return GetModelConfigStringSlice(c.Config, ModelConfigSupportVoicesKey)
}

// SupportToolChoice looks up ModelConfigToolChoiceKey in c.Config via
// GetModelConfigBool.
func (c *ModelConfig) SupportToolChoice() (bool, bool) {
	return GetModelConfigBool(c.Config, ModelConfigToolChoiceKey)
}

// SupportFormats looks up ModelConfigSupportFormatsKey in c.Config via
// GetModelConfigStringSlice.
func (c *ModelConfig) SupportFormats() ([]string, bool) {
	return GetModelConfigStringSlice(c.Config, ModelConfigSupportFormatsKey)
}

// GetModelConfigs returns one page of model configs ordered by created_at
// descending, together with the total number of matching rows. When model is
// non-empty only the config with exactly that model name is matched. The
// created_at and updated_at columns are omitted from the query.
//
// When nothing matches it returns (nil, 0, nil).
func GetModelConfigs(
	page, perPage int,
	model string,
) (configs []*ModelConfig, total int64, err error) {
	tx := DB.Model(&ModelConfig{})
	if model != "" {
		tx = tx.Where("model = ?", model)
	}

	err = tx.Count(&total).Error
	if err != nil {
		return nil, 0, err
	}

	if total <= 0 {
		return nil, 0, nil
	}

	limit, offset := toLimitOffset(page, perPage)
	err = tx.
		Order("created_at desc").
		Omit("created_at", "updated_at").
		Limit(limit).
		Offset(offset).
		Find(&configs).
		Error

	return configs, total, err
}

// GetAllModelConfigs returns every model config ordered by created_at
// descending. The created_at and updated_at columns are omitted from the
// query.
func GetAllModelConfigs() (configs []ModelConfig, err error) {
	tx := DB.Model(&ModelConfig{})
	err = tx.Order("created_at desc").
		Omit("created_at", "updated_at").
		Find(&configs).
		Error

	return configs, err
}

// GetModelConfigsByModels returns the configs whose model name is in models,
// ordered by created_at descending. The created_at and updated_at columns are
// omitted from the query.
func GetModelConfigsByModels(models []string) (configs []ModelConfig, err error) {
	tx := DB.Model(&ModelConfig{}).Where("model IN (?)", models)
	err = tx.Order("created_at desc").
		Omit("created_at", "updated_at").
		Find(&configs).
		Error

	return configs, err
}

// GetModelConfig loads the config for exactly the given model name. The
// created_at and updated_at columns are omitted from the query. The query
// error is passed through HandleNotFound with ErrModelConfigNotFound before
// being returned.
func GetModelConfig(model string) (ModelConfig, error) {
	config := ModelConfig{}
	err := DB.Model(&ModelConfig{}).
		Where("model = ?", model).
		Omit("created_at", "updated_at").
		// GORM needs an addressable destination; passing config by value
		// would leave the local variable unpopulated.
		First(&config).
		Error

	return config, HandleNotFound(err, ErrModelConfigNotFound)
}

// SearchModelConfigs returns one page of model configs matching keyword,
// ordered by created_at descending, together with the total match count.
//
// Filters applied:
//   - model name contains keyword (LIKE), always;
//   - model equals the model argument, when non-empty;
//   - owner equals the owner argument, when non-empty;
//   - when keyword is non-empty, an extra OR-group of substring matches that
//     uses ILIKE unless common.UsingSQLite is set, in which case LIKE is used.
//
// When nothing matches it returns (nil, 0, nil).
func SearchModelConfigs(
	keyword string,
	page, perPage int,
	model string,
	owner ModelOwner,
) (configs []ModelConfig, total int64, err error) {
	tx := DB.Model(&ModelConfig{}).Where("model LIKE ?", "%"+keyword+"%")
	if model != "" {
		tx = tx.Where("model = ?", model)
	}

	if owner != "" {
		tx = tx.Where("owner = ?", owner)
	}

	if keyword != "" {
		var (
			conditions []string
			values     []any
		)

		if model == "" {
			if !common.UsingSQLite {
				conditions = append(conditions, "model ILIKE ?")
			} else {
				conditions = append(conditions, "model LIKE ?")
			}

			values = append(values, "%"+keyword+"%")
		}

		// NOTE(review): this branch matches owner against the owner argument
		// itself, not against keyword, and is already implied by the exact
		// owner filter above — so the OR-group is always true when owner is
		// set. Confirm whether keyword was intended here.
		if owner != "" {
			if !common.UsingSQLite {
				conditions = append(conditions, "owner ILIKE ?")
			} else {
				conditions = append(conditions, "owner LIKE ?")
			}

			values = append(values, "%"+string(owner)+"%")
		}

		if len(conditions) > 0 {
			tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
		}
	}

	err = tx.Count(&total).Error
	if err != nil {
		return nil, 0, err
	}

	if total <= 0 {
		return nil, 0, nil
	}

	limit, offset := toLimitOffset(page, perPage)
	err = tx.Order("created_at desc").
		Omit("created_at", "updated_at").
		Limit(limit).
		Offset(offset).
		Find(&configs).
		Error

	return configs, total, err
}

// SaveModelConfig inserts or updates config. After a successful save it calls
// InitModelConfigAndChannelCache.
func SaveModelConfig(config ModelConfig) (err error) {
	defer func() {
		if err == nil {
			// Best effort: the save already succeeded, so an error from the
			// cache re-initialisation is deliberately not reported.
			_ = InitModelConfigAndChannelCache()
		}
	}()

	return DB.Save(&config).Error
}

// SaveModelConfigs inserts or updates all configs inside one transaction; the
// first failing save aborts the transaction and is returned. After a
// successful commit it calls InitModelConfigAndChannelCache.
func SaveModelConfigs(configs []ModelConfig) (err error) {
	defer func() {
		if err == nil {
			// Best effort: the save already succeeded, so an error from the
			// cache re-initialisation is deliberately not reported.
			_ = InitModelConfigAndChannelCache()
		}
	}()

	return DB.Transaction(func(tx *gorm.DB) error {
		// Each element is saved through a copy, so the caller's slice is not
		// modified by the save.
		for _, config := range configs {
			if err := tx.Save(&config).Error; err != nil {
				return err
			}
		}

		return nil
	})
}

// ErrModelConfigNotFound is the label handed to HandleNotFound and
// HandleUpdateResult for model-config lookups and deletions.
const ErrModelConfigNotFound = "model config"

// DeleteModelConfig deletes the config for exactly the given model name and
// passes the result through HandleUpdateResult with ErrModelConfigNotFound.
func DeleteModelConfig(model string) error {
	result := DB.Where("model = ?", model).Delete(&ModelConfig{})
	return HandleUpdateResult(result, ErrModelConfigNotFound)
}

// DeleteModelConfigsByModels deletes, in one transaction, every config whose
// model name is in models. Only the database error is returned; the number of
// deleted rows is not checked.
func DeleteModelConfigsByModels(models []string) error {
	return DB.Transaction(func(tx *gorm.DB) error {
		return tx.
			Where("model IN (?)", models).
			Delete(&ModelConfig{}).
			Error
	})
}