|
|
@@ -2,9 +2,8 @@ package model
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
- "os"
|
|
|
- "strings"
|
|
|
"time"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
|
@@ -66,16 +65,8 @@ func formatUserLogs(logs []*Log) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func GetLogByKey(key string) (logs []*Log, err error) {
|
|
|
- if os.Getenv("LOG_SQL_DSN") != "" {
|
|
|
- var tk Token
|
|
|
- if err = DB.Model(&Token{}).Where(logKeyCol+"=?", strings.TrimPrefix(key, "sk-")).First(&tk).Error; err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- err = LOG_DB.Model(&Log{}).Where("token_id=?", tk.Id).Find(&logs).Error
|
|
|
- } else {
|
|
|
- err = LOG_DB.Joins("left join tokens on tokens.id = logs.token_id").Where("tokens.key = ?", strings.TrimPrefix(key, "sk-")).Find(&logs).Error
|
|
|
- }
|
|
|
+func GetLogByTokenId(tokenId int) (logs []*Log, err error) {
|
|
|
+ err = LOG_DB.Model(&Log{}).Where("token_id = ?", tokenId).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
|
|
|
formatUserLogs(logs)
|
|
|
return logs, err
|
|
|
}
|
|
|
@@ -276,6 +267,8 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|
|
return logs, total, err
|
|
|
}
|
|
|
|
|
|
+const logSearchCountLimit = 10000
|
|
|
+
|
|
|
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int, group string, requestId string) (logs []*Log, total int64, err error) {
|
|
|
var tx *gorm.DB
|
|
|
if logType == LogTypeUnknown {
|
|
|
@@ -285,7 +278,11 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
|
|
}
|
|
|
|
|
|
if modelName != "" {
|
|
|
- tx = tx.Where("logs.model_name like ?", modelName)
|
|
|
+ modelNamePattern, err := sanitizeLikePattern(modelName)
|
|
|
+ if err != nil {
|
|
|
+ return nil, 0, err
|
|
|
+ }
|
|
|
+ tx = tx.Where("logs.model_name LIKE ? ESCAPE '!'", modelNamePattern)
|
|
|
}
|
|
|
if tokenName != "" {
|
|
|
tx = tx.Where("logs.token_name = ?", tokenName)
|
|
|
@@ -302,37 +299,28 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
|
|
if group != "" {
|
|
|
tx = tx.Where("logs."+logGroupCol+" = ?", group)
|
|
|
}
|
|
|
- err = tx.Model(&Log{}).Count(&total).Error
|
|
|
+ err = tx.Model(&Log{}).Limit(logSearchCountLimit).Count(&total).Error
|
|
|
if err != nil {
|
|
|
- return nil, 0, err
|
|
|
+ common.SysError("failed to count user logs: " + err.Error())
|
|
|
+ return nil, 0, errors.New("查询日志失败")
|
|
|
}
|
|
|
err = tx.Order("logs.id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
|
|
if err != nil {
|
|
|
- return nil, 0, err
|
|
|
+ common.SysError("failed to search user logs: " + err.Error())
|
|
|
+ return nil, 0, errors.New("查询日志失败")
|
|
|
}
|
|
|
|
|
|
formatUserLogs(logs)
|
|
|
return logs, total, err
|
|
|
}
|
|
|
|
|
|
-func SearchAllLogs(keyword string) (logs []*Log, err error) {
|
|
|
- err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
|
|
|
- return logs, err
|
|
|
-}
|
|
|
-
|
|
|
-func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
|
|
|
- err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
|
|
|
- formatUserLogs(logs)
|
|
|
- return logs, err
|
|
|
-}
|
|
|
-
|
|
|
type Stat struct {
|
|
|
Quota int `json:"quota"`
|
|
|
Rpm int `json:"rpm"`
|
|
|
Tpm int `json:"tpm"`
|
|
|
}
|
|
|
|
|
|
-func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat) {
|
|
|
+func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int, group string) (stat Stat, err error) {
|
|
|
tx := LOG_DB.Table("logs").Select("sum(quota) quota")
|
|
|
|
|
|
// 为rpm和tpm创建单独的查询
|
|
|
@@ -353,8 +341,12 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|
|
tx = tx.Where("created_at <= ?", endTimestamp)
|
|
|
}
|
|
|
if modelName != "" {
|
|
|
- tx = tx.Where("model_name like ?", modelName)
|
|
|
- rpmTpmQuery = rpmTpmQuery.Where("model_name like ?", modelName)
|
|
|
+ modelNamePattern, err := sanitizeLikePattern(modelName)
|
|
|
+ if err != nil {
|
|
|
+ return stat, err
|
|
|
+ }
|
|
|
+ tx = tx.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
|
|
|
+ rpmTpmQuery = rpmTpmQuery.Where("model_name LIKE ? ESCAPE '!'", modelNamePattern)
|
|
|
}
|
|
|
if channel != 0 {
|
|
|
tx = tx.Where("channel_id = ?", channel)
|
|
|
@@ -372,10 +364,16 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|
|
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
|
|
|
|
|
|
// 执行查询
|
|
|
- tx.Scan(&stat)
|
|
|
- rpmTpmQuery.Scan(&stat)
|
|
|
+ if err := tx.Scan(&stat).Error; err != nil {
|
|
|
+ common.SysError("failed to query log stat: " + err.Error())
|
|
|
+ return stat, errors.New("查询统计数据失败")
|
|
|
+ }
|
|
|
+ if err := rpmTpmQuery.Scan(&stat).Error; err != nil {
|
|
|
+ common.SysError("failed to query rpm/tpm stat: " + err.Error())
|
|
|
+ return stat, errors.New("查询统计数据失败")
|
|
|
+ }
|
|
|
|
|
|
- return stat
|
|
|
+ return stat, nil
|
|
|
}
|
|
|
|
|
|
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
|