Explorar el Código

fix: update token usage calculation

CaIon hace 4 meses
padre
commit
c834694992
Se han modificado 3 ficheros con 13 adiciones y 20 borrados
  1. 7 10
      model/log.go
  2. 0 6
      model/usedata.go
  3. 6 4
      relay/channel/openai/relay_responses.go

+ 7 - 10
model/log.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"one-api/common"
 	"one-api/logger"
+	"one-api/types"
 	"os"
 	"strings"
 	"time"
@@ -150,10 +151,10 @@ type RecordConsumeLogParams struct {
 }
 
 func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) {
-	logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
 	if !common.LogConsumeEnabled {
 		return
 	}
+	logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params)))
 	username := c.GetString("username")
 	otherStr := common.MapToJsonStr(params.Other)
 	// 判断是否需要记录 IP
@@ -236,26 +237,22 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 		return nil, 0, err
 	}
 
-	channelIdsMap := make(map[int]struct{})
-	channelMap := make(map[int]string)
+	channelIds := types.NewSet[int]()
 	for _, log := range logs {
 		if log.ChannelId != 0 {
-			channelIdsMap[log.ChannelId] = struct{}{}
+			channelIds.Add(log.ChannelId)
 		}
 	}
 
-	channelIds := make([]int, 0, len(channelIdsMap))
-	for channelId := range channelIdsMap {
-		channelIds = append(channelIds, channelId)
-	}
-	if len(channelIds) > 0 {
+	if channelIds.Len() > 0 {
 		var channels []struct {
 			Id   int    `gorm:"column:id"`
 			Name string `gorm:"column:name"`
 		}
-		if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds).Find(&channels).Error; err != nil {
+		if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil {
 			return logs, total, err
 		}
+		channelMap := make(map[int]string, len(channels))
 		for _, channel := range channels {
 			channelMap[channel.Id] = channel.Name
 		}

+ 0 - 6
model/usedata.go

@@ -21,12 +21,6 @@ type QuotaData struct {
 }
 
 func UpdateQuotaData() {
-	// recover
-	defer func() {
-		if r := recover(); r != nil {
-			common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r))
-		}
-	}()
 	for {
 		if common.DataExportEnabled {
 			common.SysLog("正在更新数据看板数据...")

+ 6 - 4
relay/channel/openai/relay_responses.go

@@ -103,12 +103,14 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
 			// 非正常结束,使用输出文本的 token 数量
 			completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
 			usage.CompletionTokens = completionTokens
-
-			if usage.PromptTokens == 0 {
-				usage.PromptTokens = info.PromptTokens
-			}
 		}
 	}
 
+	if usage.PromptTokens == 0 && usage.CompletionTokens != 0 {
+		usage.PromptTokens = usage.CompletionTokens
+	} else {
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	}
+
 	return usage, nil
 }