package model import ( "context" "encoding/json" "fmt" "slices" "strings" "time" "github.com/bytedance/sonic" "github.com/bytedance/sonic/ast" "github.com/labring/aiproxy/core/common" "github.com/labring/aiproxy/core/common/config" "github.com/labring/aiproxy/core/monitor" "github.com/labring/aiproxy/core/relay/mode" "gorm.io/gorm" "gorm.io/gorm/clause" ) const ( ErrChannelNotFound = "channel" ) const ( ChannelStatusUnknown = 0 ChannelStatusEnabled = 1 ChannelStatusDisabled = 2 ) const ( ChannelDefaultSet = "default" ) type ChannelConfig struct { Spec json.RawMessage `json:"spec"` } // validate spec json is map[string]any func (c *ChannelConfig) UnmarshalJSON(data []byte) error { type Alias ChannelConfig alias := (*Alias)(c) if err := sonic.Unmarshal(data, alias); err != nil { return err } if len(alias.Spec) > 0 { var spec map[string]any if err := sonic.Unmarshal(alias.Spec, &spec); err != nil { return fmt.Errorf("invalid spec json: %w", err) } } return nil } func (c *ChannelConfig) SpecConfig(obj any) error { if c == nil || len(c.Spec) == 0 { return nil } return sonic.Unmarshal(c.Spec, obj) } func (c *ChannelConfig) Get(key ...any) (ast.Node, error) { if c == nil || len(c.Spec) == 0 { return ast.Node{}, ast.ErrNotExist } return sonic.Get(c.Spec, key...) } type Channel struct { DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` CreatedAt time.Time `gorm:"index" json:"created_at"` LastTestErrorAt time.Time ` json:"last_test_error_at"` ChannelTests []*ChannelTest `gorm:"foreignKey:ChannelID;references:ID" json:"channel_tests,omitempty"` BalanceUpdatedAt time.Time ` json:"balance_updated_at"` ModelMapping map[string]string `gorm:"serializer:fastjson;type:text" json:"model_mapping"` Key string `gorm:"type:text;index:,length:191" json:"key"` Name string `gorm:"size:64;index" json:"name"` BaseURL string `gorm:"size:128;index" json:"base_url"` Models []string `gorm:"serializer:fastjson;type:text" json:"models"` Balance float64 ` json:"balance"` ID int `gorm:"primaryKey" json:"id"` UsedAmount float64 `gorm:"index" json:"used_amount"` RequestCount int `gorm:"index" json:"request_count"` RetryCount int `gorm:"index" json:"retry_count"` Status int `gorm:"default:1;index" json:"status"` Type ChannelType `gorm:"default:0;index" json:"type"` Priority int32 ` json:"priority"` EnabledAutoBalanceCheck bool ` json:"enabled_auto_balance_check"` BalanceThreshold float64 ` json:"balance_threshold"` Config *ChannelConfig `gorm:"serializer:fastjson;type:text" json:"config,omitempty"` Sets []string `gorm:"serializer:fastjson;type:text" json:"sets,omitempty"` } func (c *Channel) GetSets() []string { if len(c.Sets) == 0 { return []string{ChannelDefaultSet} } return c.Sets } func (c *Channel) BeforeDelete(tx *gorm.DB) (err error) { return tx.Model(&ChannelTest{}).Where("channel_id = ?", c.ID).Delete(&ChannelTest{}).Error } func (c *Channel) GetBalanceThreshold() float64 { if c.BalanceThreshold < 0 { return 0 } return c.BalanceThreshold } const ( DefaultPriority = 10 ) func (c *Channel) GetPriority() int32 { if c.Priority == 0 { return DefaultPriority } return c.Priority } func GetModelConfigWithModels(models []string) ([]string, []string, error) { if len(models) == 0 || config.DisableModelConfig { return models, nil, nil } where := DB.Model(&ModelConfig{}).Where("model IN ?", models) var count int64 if err := where.Count(&count).Error; err != nil { return nil, nil, err } if count == 0 { return nil, models, nil } if count == int64(len(models)) { return models, nil, nil } var foundModels []string if err := where.Pluck("model", &foundModels).Error; err != nil { return nil, nil, err } if len(foundModels) == len(models) { return models, nil, nil } foundModelsMap := make(map[string]struct{}, len(foundModels)) for _, model := range foundModels { foundModelsMap[model] = struct{}{} } if len(models)-len(foundModels) > 0 { missingModels := make([]string, 0, len(models)-len(foundModels)) for _, model := range models { if _, exists := foundModelsMap[model]; !exists { missingModels = append(missingModels, model) } } return foundModels, missingModels, nil } return foundModels, nil, nil } func CheckModelConfigExist(models []string) error { _, missingModels, err := GetModelConfigWithModels(models) if err != nil { return err } if len(missingModels) > 0 { slices.Sort(missingModels) return fmt.Errorf("model config not found: %v", missingModels) } return nil } func (c *Channel) MarshalJSON() ([]byte, error) { type Alias Channel return sonic.Marshal(&struct { *Alias CreatedAt int64 `json:"created_at"` BalanceUpdatedAt int64 `json:"balance_updated_at"` LastTestErrorAt int64 `json:"last_test_error_at"` }{ Alias: (*Alias)(c), CreatedAt: c.CreatedAt.UnixMilli(), BalanceUpdatedAt: c.BalanceUpdatedAt.UnixMilli(), LastTestErrorAt: c.LastTestErrorAt.UnixMilli(), }) } //nolint:goconst func getChannelOrder(order string) string { prefix, suffix, _ := strings.Cut(order, "-") switch prefix { case "name", "type", "created_at", "status", "test_at", "balance_updated_at", "used_amount", "request_count", "priority", "id": switch suffix { case "asc": return prefix + " asc" default: return prefix + " desc" } default: return "id desc" } } func GetAllChannels() (channels []*Channel, err error) { tx := DB.Model(&Channel{}) err = tx.Order("id desc").Find(&channels).Error return channels, err } func GetChannels( page, perPage, id int, name, key string, channelType int, baseURL, order string, ) (channels []*Channel, total int64, err error) { tx := DB.Model(&Channel{}) if id != 0 { tx = tx.Where("id = ?", id) } if name != "" { tx = tx.Where("name = ?", name) } if key != "" { tx = tx.Where("key = ?", key) } if channelType != 0 { tx = tx.Where("type = ?", channelType) } if baseURL != "" { tx = tx.Where("base_url = ?", baseURL) } err = tx.Count(&total).Error if err != nil { return nil, 0, err } if total <= 0 { return nil, 0, nil } limit, offset := toLimitOffset(page, perPage) err = tx.Order(getChannelOrder(order)).Limit(limit).Offset(offset).Find(&channels).Error return channels, total, err } func SearchChannels( keyword string, page, perPage, id int, name, key string, channelType int, baseURL, order string, ) (channels []*Channel, total int64, err error) { tx := DB.Model(&Channel{}) // Handle exact match conditions for non-zero values if id != 0 { tx = tx.Where("id = ?", id) } if name != "" { tx = tx.Where("name = ?", name) } if key != "" { tx = tx.Where("key = ?", key) } if channelType != 0 { tx = tx.Where("type = ?", channelType) } if baseURL != "" { tx = tx.Where("base_url = ?", baseURL) } // Handle keyword search for zero value fields if keyword != "" { var ( conditions []string values []any ) keywordInt := String2Int(keyword) if keywordInt != 0 { if id == 0 { conditions = append(conditions, "id = ?") values = append(values, keywordInt) } } if name == "" { if !common.UsingSQLite { conditions = append(conditions, "name ILIKE ?") } else { conditions = append(conditions, "name LIKE ?") } values = append(values, "%"+keyword+"%") } if key == "" { if !common.UsingSQLite { conditions = append(conditions, "key ILIKE ?") } else { conditions = append(conditions, "key LIKE ?") } values = append(values, "%"+keyword+"%") } if baseURL == "" { if !common.UsingSQLite { conditions = append(conditions, "base_url ILIKE ?") } else { conditions = append(conditions, "base_url LIKE ?") } values = append(values, "%"+keyword+"%") } if !common.UsingSQLite { conditions = append(conditions, "models ILIKE ?") } else { conditions = append(conditions, "models LIKE ?") } values = append(values, "%"+keyword+"%") if !common.UsingSQLite { conditions = append(conditions, "sets ILIKE ?") } else { conditions = append(conditions, "sets LIKE ?") } values = append(values, "%"+keyword+"%") if len(conditions) > 0 { tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...) } } err = tx.Count(&total).Error if err != nil { return nil, 0, err } if total <= 0 { return nil, 0, nil } limit, offset := toLimitOffset(page, perPage) err = tx.Order(getChannelOrder(order)).Limit(limit).Offset(offset).Find(&channels).Error return channels, total, err } func GetChannelByID(id int) (*Channel, error) { channel := Channel{ID: id} err := DB.First(&channel, "id = ?", id).Error return &channel, HandleNotFound(err, ErrChannelNotFound) } func BatchInsertChannels(channels []*Channel) (err error) { defer func() { if err == nil { _ = InitModelConfigAndChannelCache() } }() for _, channel := range channels { if err := CheckModelConfigExist(channel.Models); err != nil { return err } } return DB.Transaction(func(tx *gorm.DB) error { return tx.Create(&channels).Error }) } func UpdateChannel(channel *Channel) (err error) { defer func() { if err == nil { _ = InitModelConfigAndChannelCache() _ = monitor.ClearChannelAllModelErrors(context.Background(), channel.ID) } }() if err := CheckModelConfigExist(channel.Models); err != nil { return err } selects := []string{ "model_mapping", "key", "base_url", "models", "priority", "config", "enabled_auto_balance_check", "balance_threshold", "sets", } if channel.Type != 0 { selects = append(selects, "type") } if channel.Name != "" { selects = append(selects, "name") } result := DB. Select(selects). Clauses(clause.Returning{}). Where("id = ?", channel.ID). Updates(channel) return HandleUpdateResult(result, ErrChannelNotFound) } func ClearLastTestErrorAt(id int) error { result := DB.Model(&Channel{}). Where("id = ?", id). Update("last_test_error_at", gorm.Expr("NULL")) return HandleUpdateResult(result, ErrChannelNotFound) } func (c *Channel) UpdateModelTest( testAt time.Time, model, actualModel string, mode mode.Mode, took float64, success bool, response string, code int, ) (*ChannelTest, error) { var ct *ChannelTest err := DB.Transaction(func(tx *gorm.DB) error { if !success { result := tx.Model(&Channel{}). Where("id = ?", c.ID). Update("last_test_error_at", testAt) if err := HandleUpdateResult(result, ErrChannelNotFound); err != nil { return err } } else if !c.LastTestErrorAt.IsZero() && time.Since(c.LastTestErrorAt) > time.Hour { result := tx.Model(&Channel{}).Where("id = ?", c.ID).Update("last_test_error_at", gorm.Expr("NULL")) if err := HandleUpdateResult(result, ErrChannelNotFound); err != nil { return err } } ct = &ChannelTest{ ChannelID: c.ID, ChannelType: c.Type, ChannelName: c.Name, Model: model, ActualModel: actualModel, Mode: mode, TestAt: testAt, Took: took, Success: success, Response: response, Code: code, } result := tx.Save(ct) return HandleUpdateResult(result, ErrChannelNotFound) }) if err != nil { return nil, err } return ct, nil } func (c *Channel) UpdateBalance(balance float64) error { result := DB.Model(&Channel{}). Select("balance_updated_at", "balance"). Where("id = ?", c.ID). Updates(Channel{ BalanceUpdatedAt: time.Now(), Balance: balance, }) return HandleUpdateResult(result, ErrChannelNotFound) } func DeleteChannelByID(id int) (err error) { defer func() { if err == nil { _ = InitModelConfigAndChannelCache() _ = monitor.ClearChannelAllModelErrors(context.Background(), id) } }() result := DB.Delete(&Channel{ID: id}) return HandleUpdateResult(result, ErrChannelNotFound) } func DeleteChannelsByIDs(ids []int) (err error) { defer func() { if err == nil { _ = InitModelConfigAndChannelCache() for _, id := range ids { _ = monitor.ClearChannelAllModelErrors(context.Background(), id) } } }() return DB.Transaction(func(tx *gorm.DB) error { return tx. Where("id IN (?)", ids). Delete(&Channel{}). Error }) } func UpdateChannelStatusByID(id, status int) error { result := DB.Model(&Channel{}). Where("id = ?", id). Update("status", status) return HandleUpdateResult(result, ErrChannelNotFound) } func UpdateChannelUsedAmount(id int, amount float64, requestCount, retryCount int) error { result := DB.Model(&Channel{}). Where("id = ?", id). Updates(map[string]any{ "used_amount": gorm.Expr("used_amount + ?", amount), "request_count": gorm.Expr("request_count + ?", requestCount), "retry_count": gorm.Expr("retry_count + ?", retryCount), }) return HandleUpdateResult(result, ErrChannelNotFound) }