Browse Source

feat: 支持未开启缓存下本地重试

CaIon 1 year ago
parent
commit
462c328d4b
2 changed files with 50 additions and 4 deletions
  1. 49 3
      model/ability.go
  2. 1 1
      model/cache.go

+ 49 - 3
model/ability.go

@@ -3,6 +3,7 @@ package model
 import (
 	"errors"
 	"fmt"
+	"gorm.io/gorm"
 	"one-api/common"
 	"strings"
 )
@@ -27,8 +28,7 @@ func GetGroupModels(group string) []string {
 	return models
 }
 
-func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
-	var abilities []Ability
+func getPriority(group string, model string, retry int) (int, error) {
 	groupCol := "`group`"
 	trueVal := "1"
 	if common.UsingPostgreSQL {
@@ -36,9 +36,55 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 		trueVal = "true"
 	}
 
-	var err error = nil
+	var priorities []int
+	err := DB.Model(&Ability{}).
+		Select("DISTINCT(priority)").
+		Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model).
+		Order("priority DESC").              // 按优先级降序排序
+		Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
+
+	if err != nil {
+		// 处理错误
+		return 0, err
+	}
+
+	// 确定要使用的优先级
+	var priorityToUse int
+	if retry >= len(priorities) {
+		// 如果重试次数大于优先级数,则使用最小的优先级
+		priorityToUse = priorities[len(priorities)-1]
+	} else {
+		priorityToUse = priorities[retry]
+	}
+	return priorityToUse, nil
+}
+
+func getChannelQuery(group string, model string, retry int) *gorm.DB {
+	groupCol := "`group`"
+	trueVal := "1"
+	if common.UsingPostgreSQL {
+		groupCol = `"group"`
+		trueVal = "true"
+	}
 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
 	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
+	if retry != 0 {
+		priority, err := getPriority(group, model, retry)
+		if err != nil {
+			common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
+		} else {
+			channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = ?", group, model, priority)
+		}
+	}
+
+	return channelQuery
+}
+
+func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
+	var abilities []Ability
+
+	var err error = nil
+	channelQuery := getChannelQuery(group, model, retry)
 	if common.UsingSQLite || common.UsingPostgreSQL {
 		err = channelQuery.Order("weight DESC").Find(&abilities).Error
 	} else {

+ 1 - 1
model/cache.go

@@ -272,7 +272,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Cha
 
 	// if memory cache is disabled, get channel directly from database
 	if !common.MemoryCacheEnabled {
-		return GetRandomSatisfiedChannel(group, model)
+		return GetRandomSatisfiedChannel(group, model, retry)
 	}
 	channelSyncLock.RLock()
 	defer channelSyncLock.RUnlock()