Ver código fonte

Merge pull request #1346 from QuantumNous/fix-ability

✨ feat(ability): enhance FixAbility function
Calcium-Ion 5 meses atrás
pai
commit
78fb457765

+ 5 - 2
controller/channel.go

@@ -228,7 +228,7 @@ func FetchUpstreamModels(c *gin.Context) {
 }
 
 func FixChannelsAbilities(c *gin.Context) {
-	count, err := model.FixAbility()
+	success, fails, err := model.FixAbility()
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -239,7 +239,10 @@ func FixChannelsAbilities(c *gin.Context) {
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"message": "",
-		"data":    count,
+		"data": gin.H{
+			"success": success,
+			"fails":   fails,
+		},
 	})
 }
 

+ 2 - 2
main.go

@@ -68,9 +68,9 @@ func main() {
 				if r := recover(); r != nil {
 					common.SysError(fmt.Sprintf("InitChannelCache panic: %v, retrying once", r))
 					// Retry once
-					_, fixErr := model.FixAbility()
+					_, _, fixErr := model.FixAbility()
 					if fixErr != nil {
-						common.SysError(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
+						common.FatalLog(fmt.Sprintf("InitChannelCache failed: %s", fixErr.Error()))
 					}
 				}
 			}()

+ 33 - 61
model/ability.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"one-api/common"
 	"strings"
+	"sync"
 
 	"github.com/samber/lo"
 	"gorm.io/gorm"
@@ -272,74 +273,45 @@ func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uin
 	return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
 }
 
-func FixAbility() (int, error) {
-	var channelIds []int
-	count := 0
-	// Find all channel ids from channel table
-	err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
-	if err != nil {
-		common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
-		return 0, err
-	}
+var fixLock = sync.Mutex{}
 
-	// Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
-	if len(channelIds) > 0 {
-		// Process deletion in chunks to avoid "too many placeholders" error
-		for _, chunk := range lo.Chunk(channelIds, 100) {
-			err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
-			if err != nil {
-				common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
-				return 0, err
-			}
-		}
-	} else {
-		// If no channels exist, delete all abilities
-		err = DB.Delete(&Ability{}).Error
-		if err != nil {
-			common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
-			return 0, err
-		}
-		common.SysLog("Delete all abilities successfully")
-		return 0, nil
+func FixAbility() (int, int, error) {
+	lock := fixLock.TryLock()
+	if !lock {
+		return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
 	}
-
-	common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
-	count += len(channelIds)
-
-	// Use channelIds to find channel not in abilities table
-	var abilityChannelIds []int
-	err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
+	defer fixLock.Unlock()
+	var channels []*Channel
+	// Find all channels
+	err := DB.Model(&Channel{}).Find(&channels).Error
 	if err != nil {
-		common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
-		return count, err
+		return 0, 0, err
 	}
-
-	var channels []Channel
-	if len(abilityChannelIds) == 0 {
-		err = DB.Find(&channels).Error
-	} else {
-		// Process query in chunks to avoid "too many placeholders" error
-		err = nil
-		for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
-			var channelsChunk []Channel
-			err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
-			if err != nil {
-				common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
-				return count, err
-			}
-			channels = append(channels, channelsChunk...)
-		}
+	if len(channels) == 0 {
+		return 0, 0, nil
 	}
-
-	for _, channel := range channels {
-		err := channel.UpdateAbilities(nil)
+	successCount := 0
+	failCount := 0
+	for _, chunk := range lo.Chunk(channels, 50) {
+		ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
+		// Delete all abilities of this channel
+		err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
 		if err != nil {
-			common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
-		} else {
-			common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
-			count++
+			common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
+			failCount += len(chunk)
+			continue
+		}
+		// Then add new abilities
+		for _, channel := range chunk {
+			err = channel.AddAbilities()
+			if err != nil {
+				common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
+				failCount++
+			} else {
+				successCount++
+			}
 		}
 	}
 	InitChannelCache()
-	return count, nil
+	return successCount, failCount, nil
 }

+ 2 - 2
web/src/components/table/ChannelsTable.js

@@ -1461,9 +1461,9 @@ const ChannelsTable = () => {
 
   const fixChannelsAbilities = async () => {
     const res = await API.post(`/api/channel/fix`);
-    const { success, message, data } = res.data;
+    const { success, message, data  } = res.data;
     if (success) {
-      showSuccess(t('已修复 ${data} 个通道!').replace('${data}', data));
+      showSuccess(t('已修复 ${success} 个通道,失败 ${fails} 个通道。').replace('${success}', data.success).replace('${fails}', data.fails));
       await refresh();
     } else {
       showError(message);

+ 1 - 1
web/src/pages/Channel/EditChannel.js

@@ -240,7 +240,7 @@ const EditChannel = (props) => {
     if (isEdit) {
       // 如果是编辑模式,使用已有的channel id获取模型列表
       const res = await API.get('/api/channel/fetch_models/' + channelId);
-      if (res.data && res.data?.success) {
+      if (res.data && res.data.success) {
         models.push(...res.data.data);
       } else {
         err = true;