| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547 |
- package model
- import (
- "context"
- "fmt"
- "slices"
- "strings"
- "time"
- "github.com/bytedance/sonic"
- "github.com/labring/aiproxy/core/common"
- "github.com/labring/aiproxy/core/common/config"
- "github.com/labring/aiproxy/core/monitor"
- "github.com/labring/aiproxy/core/relay/mode"
- "gorm.io/gorm"
- "gorm.io/gorm/clause"
- )
// ErrChannelNotFound is the resource name handed to HandleNotFound and
// HandleUpdateResult when a channel query or update targets no existing row.
const ErrChannelNotFound = "channel"

// Values stored in Channel.Status. The column's database default is
// ChannelStatusEnabled (see the gorm "default:1" tag on Channel.Status).
const (
	ChannelStatusUnknown  = 0
	ChannelStatusEnabled  = 1
	ChannelStatusDisabled = 2
)

// ChannelDefaultSet is the set name reported by Channel.GetSets for a channel
// whose Sets list is empty.
const ChannelDefaultSet = "default"
// Channel is the GORM-persisted configuration and usage record of a channel:
// its endpoint (BaseURL), credential (Key), served Models and usage counters.
// JSON encoding goes through (*Channel).MarshalJSON, which emits the time
// fields as Unix milliseconds. Fields tagged yaml:"-" are excluded from YAML.
type Channel struct {
	// DeletedAt is GORM's soft-delete marker for channel rows.
	DeletedAt gorm.DeletedAt `gorm:"index" json:"-" yaml:"-"`
	CreatedAt time.Time `gorm:"index" json:"created_at" yaml:"-"`
	// LastTestErrorAt is set by UpdateModelTest when a test fails. It is reset
	// to NULL by a later successful test once more than an hour has passed, or
	// by ClearLastTestErrorAt.
	LastTestErrorAt time.Time ` json:"last_test_error_at" yaml:"-"`
	// ChannelTests are the test results linked via ChannelTest.ChannelID;
	// BeforeDelete removes them when a channel is deleted by ID.
	ChannelTests []*ChannelTest `gorm:"foreignKey:ChannelID;references:ID" json:"channel_tests,omitempty" yaml:"-"`
	// BalanceUpdatedAt is written together with Balance by UpdateBalance.
	BalanceUpdatedAt time.Time ` json:"balance_updated_at" yaml:"-"`
	// ModelMapping is stored as JSON text. NOTE(review): presumably maps a
	// requested model name to the upstream model name — confirm with the relay code.
	ModelMapping map[string]string `gorm:"serializer:fastjson;type:text" json:"model_mapping" yaml:"model_mapping,omitempty"`
	// Key is the channel credential (presumably the upstream API key).
	Key string `gorm:"type:text;index:,length:191" json:"key" yaml:"key,omitempty"`
	// Name is only written by UpdateChannel when non-empty.
	Name string `gorm:"size:64;index" json:"name" yaml:"name,omitempty"`
	BaseURL string `gorm:"size:128;index" json:"base_url" yaml:"base_url,omitempty"`
	// Models is stored as JSON text. UpdateChannel and BatchInsertChannels
	// require every entry to have a ModelConfig (CheckModelConfigExist).
	Models []string `gorm:"serializer:fastjson;type:text" json:"models" yaml:"models,omitempty"`
	// Balance is written by UpdateBalance.
	Balance float64 ` json:"balance" yaml:"balance,omitempty"`
	ID int `gorm:"primaryKey" json:"id" yaml:"id,omitempty"`
	// UsedAmount, RequestCount and RetryCount are incremented in the database
	// (column = column + delta) by UpdateChannelUsedAmount.
	UsedAmount float64 `gorm:"index" json:"used_amount" yaml:"-"`
	RequestCount int `gorm:"index" json:"request_count" yaml:"-"`
	RetryCount int `gorm:"index" json:"retry_count" yaml:"-"`
	// Status holds one of the ChannelStatus* values; the DB default is 1
	// (ChannelStatusEnabled).
	Status int `gorm:"default:1;index" json:"status" yaml:"status,omitempty"`
	// Type is only written by UpdateChannel when non-zero.
	Type ChannelType `gorm:"default:0;index" json:"type" yaml:"type,omitempty"`
	// Priority 0 is read as DefaultPriority by GetPriority.
	Priority int32 ` json:"priority" yaml:"priority,omitempty"`
	EnabledAutoBalanceCheck bool ` json:"enabled_auto_balance_check" yaml:"enabled_auto_balance_check,omitempty"`
	// BalanceThreshold values below zero are read as 0 by GetBalanceThreshold.
	BalanceThreshold float64 ` json:"balance_threshold" yaml:"balance_threshold,omitempty"`
	// Configs is free-form configuration stored as JSON text; decode it into a
	// typed struct with ChannelConfigs.LoadConfig.
	Configs ChannelConfigs `gorm:"serializer:fastjson;type:text" json:"configs,omitempty" yaml:"configs,omitempty"`
	// Sets is stored as JSON text; an empty list means ChannelDefaultSet (GetSets).
	Sets []string `gorm:"serializer:fastjson;type:text" json:"sets,omitempty" yaml:"sets,omitempty"`
}
- func (c *Channel) GetSets() []string {
- if len(c.Sets) == 0 {
- return []string{ChannelDefaultSet}
- }
- return c.Sets
- }
- func (c *Channel) BeforeDelete(tx *gorm.DB) (err error) {
- return tx.Model(&ChannelTest{}).Where("channel_id = ?", c.ID).Delete(&ChannelTest{}).Error
- }
- func (c *Channel) GetBalanceThreshold() float64 {
- if c.BalanceThreshold < 0 {
- return 0
- }
- return c.BalanceThreshold
- }
- const (
- DefaultPriority = 10
- )
- func (c *Channel) GetPriority() int32 {
- if c.Priority == 0 {
- return DefaultPriority
- }
- return c.Priority
- }
- type ChannelConfigs map[string]any
- func (c ChannelConfigs) LoadConfig(config any) error {
- if len(c) == 0 {
- return nil
- }
- v, err := sonic.Marshal(c)
- if err != nil {
- return err
- }
- return sonic.Unmarshal(v, config)
- }
- func GetModelConfigWithModels(models []string) ([]string, []string, error) {
- if len(models) == 0 || config.DisableModelConfig {
- return models, nil, nil
- }
- where := DB.Model(&ModelConfig{}).Where("model IN ?", models)
- var count int64
- if err := where.Count(&count).Error; err != nil {
- return nil, nil, err
- }
- if count == 0 {
- return nil, models, nil
- }
- if count == int64(len(models)) {
- return models, nil, nil
- }
- var foundModels []string
- if err := where.Pluck("model", &foundModels).Error; err != nil {
- return nil, nil, err
- }
- if len(foundModels) == len(models) {
- return models, nil, nil
- }
- foundModelsMap := make(map[string]struct{}, len(foundModels))
- for _, model := range foundModels {
- foundModelsMap[model] = struct{}{}
- }
- if len(models)-len(foundModels) > 0 {
- missingModels := make([]string, 0, len(models)-len(foundModels))
- for _, model := range models {
- if _, exists := foundModelsMap[model]; !exists {
- missingModels = append(missingModels, model)
- }
- }
- return foundModels, missingModels, nil
- }
- return foundModels, nil, nil
- }
- func CheckModelConfigExist(models []string) error {
- _, missingModels, err := GetModelConfigWithModels(models)
- if err != nil {
- return err
- }
- if len(missingModels) > 0 {
- slices.Sort(missingModels)
- return fmt.Errorf("model config not found: %v", missingModels)
- }
- return nil
- }
- func (c *Channel) MarshalJSON() ([]byte, error) {
- type Alias Channel
- return sonic.Marshal(&struct {
- *Alias
- CreatedAt int64 `json:"created_at"`
- BalanceUpdatedAt int64 `json:"balance_updated_at"`
- LastTestErrorAt int64 `json:"last_test_error_at"`
- }{
- Alias: (*Alias)(c),
- CreatedAt: c.CreatedAt.UnixMilli(),
- BalanceUpdatedAt: c.BalanceUpdatedAt.UnixMilli(),
- LastTestErrorAt: c.LastTestErrorAt.UnixMilli(),
- })
- }
// getChannelOrder turns a client-supplied sort key of the form
// "<column>[-<direction>]" into an ORDER BY expression for the channels table.
// Only whitelisted columns are accepted, which keeps the result safe to pass
// to Order. Any other column falls back to "id desc". The direction is "asc"
// only when the suffix is exactly "asc"; anything else (or no suffix) means
// "desc".
//
// Every whitelisted name must be a real column of Channel. The former
// "test_at" entry matched no column, so ordering by it made the query fail;
// it now falls back to the default order.
//
//nolint:goconst
func getChannelOrder(order string) string {
	column, direction, _ := strings.Cut(order, "-")
	switch column {
	case "name",
		"type",
		"created_at",
		"status",
		"last_test_error_at",
		"balance",
		"balance_updated_at",
		"used_amount",
		"request_count",
		"retry_count",
		"priority",
		"id":
		if direction == "asc" {
			return column + " asc"
		}
		return column + " desc"
	default:
		return "id desc"
	}
}
- func GetAllChannels() (channels []*Channel, err error) {
- tx := DB.Model(&Channel{})
- err = tx.Order("id desc").Find(&channels).Error
- return channels, err
- }
- func GetChannels(
- page, perPage, id int,
- name, key string,
- channelType int,
- baseURL, order string,
- ) (channels []*Channel, total int64, err error) {
- tx := DB.Model(&Channel{})
- if id != 0 {
- tx = tx.Where("id = ?", id)
- }
- if name != "" {
- tx = tx.Where("name = ?", name)
- }
- if key != "" {
- tx = tx.Where("key = ?", key)
- }
- if channelType != 0 {
- tx = tx.Where("type = ?", channelType)
- }
- if baseURL != "" {
- tx = tx.Where("base_url = ?", baseURL)
- }
- err = tx.Count(&total).Error
- if err != nil {
- return nil, 0, err
- }
- if total <= 0 {
- return nil, 0, nil
- }
- limit, offset := toLimitOffset(page, perPage)
- err = tx.Order(getChannelOrder(order)).Limit(limit).Offset(offset).Find(&channels).Error
- return channels, total, err
- }
- func SearchChannels(
- keyword string,
- page, perPage, id int,
- name, key string,
- channelType int,
- baseURL, order string,
- ) (channels []*Channel, total int64, err error) {
- tx := DB.Model(&Channel{})
- // Handle exact match conditions for non-zero values
- if id != 0 {
- tx = tx.Where("id = ?", id)
- }
- if name != "" {
- tx = tx.Where("name = ?", name)
- }
- if key != "" {
- tx = tx.Where("key = ?", key)
- }
- if channelType != 0 {
- tx = tx.Where("type = ?", channelType)
- }
- if baseURL != "" {
- tx = tx.Where("base_url = ?", baseURL)
- }
- // Handle keyword search for zero value fields
- if keyword != "" {
- var (
- conditions []string
- values []any
- )
- keywordInt := String2Int(keyword)
- if keywordInt != 0 {
- if id == 0 {
- conditions = append(conditions, "id = ?")
- values = append(values, keywordInt)
- }
- }
- if name == "" {
- if !common.UsingSQLite {
- conditions = append(conditions, "name ILIKE ?")
- } else {
- conditions = append(conditions, "name LIKE ?")
- }
- values = append(values, "%"+keyword+"%")
- }
- if key == "" {
- if !common.UsingSQLite {
- conditions = append(conditions, "key ILIKE ?")
- } else {
- conditions = append(conditions, "key LIKE ?")
- }
- values = append(values, "%"+keyword+"%")
- }
- if baseURL == "" {
- if !common.UsingSQLite {
- conditions = append(conditions, "base_url ILIKE ?")
- } else {
- conditions = append(conditions, "base_url LIKE ?")
- }
- values = append(values, "%"+keyword+"%")
- }
- if !common.UsingSQLite {
- conditions = append(conditions, "models ILIKE ?")
- } else {
- conditions = append(conditions, "models LIKE ?")
- }
- values = append(values, "%"+keyword+"%")
- if !common.UsingSQLite {
- conditions = append(conditions, "sets ILIKE ?")
- } else {
- conditions = append(conditions, "sets LIKE ?")
- }
- values = append(values, "%"+keyword+"%")
- if len(conditions) > 0 {
- tx = tx.Where(fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")), values...)
- }
- }
- err = tx.Count(&total).Error
- if err != nil {
- return nil, 0, err
- }
- if total <= 0 {
- return nil, 0, nil
- }
- limit, offset := toLimitOffset(page, perPage)
- err = tx.Order(getChannelOrder(order)).Limit(limit).Offset(offset).Find(&channels).Error
- return channels, total, err
- }
- func GetChannelByID(id int) (*Channel, error) {
- channel := Channel{ID: id}
- err := DB.First(&channel, "id = ?", id).Error
- return &channel, HandleNotFound(err, ErrChannelNotFound)
- }
- func BatchInsertChannels(channels []*Channel) (err error) {
- defer func() {
- if err == nil {
- _ = InitModelConfigAndChannelCache()
- }
- }()
- for _, channel := range channels {
- if err := CheckModelConfigExist(channel.Models); err != nil {
- return err
- }
- }
- return DB.Transaction(func(tx *gorm.DB) error {
- return tx.Create(&channels).Error
- })
- }
- func UpdateChannel(channel *Channel) (err error) {
- defer func() {
- if err == nil {
- _ = InitModelConfigAndChannelCache()
- _ = monitor.ClearChannelAllModelErrors(context.Background(), channel.ID)
- }
- }()
- if err := CheckModelConfigExist(channel.Models); err != nil {
- return err
- }
- selects := []string{
- "model_mapping",
- "key",
- "base_url",
- "models",
- "priority",
- "config",
- "enabled_auto_balance_check",
- "balance_threshold",
- "sets",
- }
- if channel.Type != 0 {
- selects = append(selects, "type")
- }
- if channel.Name != "" {
- selects = append(selects, "name")
- }
- result := DB.
- Select(selects).
- Clauses(clause.Returning{}).
- Where("id = ?", channel.ID).
- Updates(channel)
- return HandleUpdateResult(result, ErrChannelNotFound)
- }
- func ClearLastTestErrorAt(id int) error {
- result := DB.Model(&Channel{}).
- Where("id = ?", id).
- Update("last_test_error_at", gorm.Expr("NULL"))
- return HandleUpdateResult(result, ErrChannelNotFound)
- }
- func (c *Channel) UpdateModelTest(
- testAt time.Time,
- model, actualModel string,
- mode mode.Mode,
- took float64,
- success bool,
- response string,
- code int,
- ) (*ChannelTest, error) {
- var ct *ChannelTest
- err := DB.Transaction(func(tx *gorm.DB) error {
- if !success {
- result := tx.Model(&Channel{}).
- Where("id = ?", c.ID).
- Update("last_test_error_at", testAt)
- if err := HandleUpdateResult(result, ErrChannelNotFound); err != nil {
- return err
- }
- } else if !c.LastTestErrorAt.IsZero() && time.Since(c.LastTestErrorAt) > time.Hour {
- result := tx.Model(&Channel{}).Where("id = ?", c.ID).Update("last_test_error_at", gorm.Expr("NULL"))
- if err := HandleUpdateResult(result, ErrChannelNotFound); err != nil {
- return err
- }
- }
- ct = &ChannelTest{
- ChannelID: c.ID,
- ChannelType: c.Type,
- ChannelName: c.Name,
- Model: model,
- ActualModel: actualModel,
- Mode: mode,
- TestAt: testAt,
- Took: took,
- Success: success,
- Response: response,
- Code: code,
- }
- result := tx.Save(ct)
- return HandleUpdateResult(result, ErrChannelNotFound)
- })
- if err != nil {
- return nil, err
- }
- return ct, nil
- }
- func (c *Channel) UpdateBalance(balance float64) error {
- result := DB.Model(&Channel{}).
- Select("balance_updated_at", "balance").
- Where("id = ?", c.ID).
- Updates(Channel{
- BalanceUpdatedAt: time.Now(),
- Balance: balance,
- })
- return HandleUpdateResult(result, ErrChannelNotFound)
- }
- func DeleteChannelByID(id int) (err error) {
- defer func() {
- if err == nil {
- _ = InitModelConfigAndChannelCache()
- _ = monitor.ClearChannelAllModelErrors(context.Background(), id)
- }
- }()
- result := DB.Delete(&Channel{ID: id})
- return HandleUpdateResult(result, ErrChannelNotFound)
- }
- func DeleteChannelsByIDs(ids []int) (err error) {
- defer func() {
- if err == nil {
- _ = InitModelConfigAndChannelCache()
- for _, id := range ids {
- _ = monitor.ClearChannelAllModelErrors(context.Background(), id)
- }
- }
- }()
- return DB.Transaction(func(tx *gorm.DB) error {
- return tx.
- Where("id IN (?)", ids).
- Delete(&Channel{}).
- Error
- })
- }
- func UpdateChannelStatusByID(id, status int) error {
- result := DB.Model(&Channel{}).
- Where("id = ?", id).
- Update("status", status)
- return HandleUpdateResult(result, ErrChannelNotFound)
- }
- func UpdateChannelUsedAmount(id int, amount float64, requestCount, retryCount int) error {
- result := DB.Model(&Channel{}).
- Where("id = ?", id).
- Updates(map[string]any{
- "used_amount": gorm.Expr("used_amount + ?", amount),
- "request_count": gorm.Expr("request_count + ?", requestCount),
- "retry_count": gorm.Expr("retry_count + ?", retryCount),
- })
- return HandleUpdateResult(result, ErrChannelNotFound)
- }
|