
Merge remote-tracking branch 'origin/main'

# Conflicts:
#	controller/log.go
#	controller/relay-audio.go
#	controller/relay-image.go
#	controller/relay-text.go
#	controller/relay.go
#	middleware/distributor.go
#	model/log.go
#	web/src/components/OperationSetting.js
CaIon 2 years ago
parent
commit
985e26fd1b

+ 2 - 1
.gitignore

@@ -4,4 +4,5 @@ upload
 *.exe
 *.db
 build
-*.db-journal
+*.db-journal
+logs

+ 1 - 1
README.md

@@ -340,7 +340,7 @@ graph LR
 ### 命令行参数
 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
    + 例子:`--port 3000`
-2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存
+2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下
    + 例子:`--log-dir ./logs`
 3. `--version`: 打印系统版本号并退出。
 4. `--help`: 查看命令的使用帮助和参数说明。

+ 4 - 0
common/constants.go

@@ -101,6 +101,10 @@ var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQU
 var BatchUpdateEnabled = false
 var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 
+const (
+	RequestIdKey = "X-Oneapi-Request-Id"
+)
+
 const (
 	RoleGuestUser  = 0
 	RoleCommonUser = 1

+ 1 - 1
common/init.go

@@ -12,7 +12,7 @@ var (
 	Port         = flag.Int("port", 3000, "the listening port")
 	PrintVersion = flag.Bool("version", false, "print version and exit")
 	PrintHelp    = flag.Bool("help", false, "print help and exit")
-	LogDir       = flag.String("log-dir", "", "specify the log directory")
+	LogDir       = flag.String("log-dir", "./logs", "specify the log directory")
 )
 
 func printHelp() {

+ 57 - 9
common/logger.go

@@ -1,29 +1,47 @@
 package common
 
 import (
+	"context"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"io"
 	"log"
 	"os"
 	"path/filepath"
+	"sync"
 	"time"
 )
 
-func SetupGinLog() {
+const (
+	loggerINFO  = "INFO"
+	loggerWarn  = "WARN"
+	loggerError = "ERR"
+)
+
+const maxLogCount = 1000000
+
+var logCount int
+var setupLogLock sync.Mutex
+var setupLogWorking bool
+
+func SetupLogger() {
 	if *LogDir != "" {
-		commonLogPath := filepath.Join(*LogDir, "common.log")
-		errorLogPath := filepath.Join(*LogDir, "error.log")
-		commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
-		if err != nil {
-			log.Fatal("failed to open log file")
+		ok := setupLogLock.TryLock()
+		if !ok {
+			log.Println("setup log is already working")
+			return
 		}
-		errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+		defer func() {
+			setupLogLock.Unlock()
+			setupLogWorking = false
+		}()
+		logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
+		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
 		if err != nil {
 			log.Fatal("failed to open log file")
 		}
-		gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
-		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
+		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
+		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
 	}
 }
 
@@ -37,6 +55,36 @@ func SysError(s string) {
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }
 
+func LogInfo(ctx context.Context, msg string) {
+	logHelper(ctx, loggerINFO, msg)
+}
+
+func LogWarn(ctx context.Context, msg string) {
+	logHelper(ctx, loggerWarn, msg)
+}
+
+func LogError(ctx context.Context, msg string) {
+	logHelper(ctx, loggerError, msg)
+}
+
+func logHelper(ctx context.Context, level string, msg string) {
+	writer := gin.DefaultErrorWriter
+	if level == loggerINFO {
+		writer = gin.DefaultWriter
+	}
+	id := ctx.Value(RequestIdKey)
+	now := time.Now()
+	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
+	logCount++ // we don't need accurate count, so no lock here
+	if logCount > maxLogCount && !setupLogWorking {
+		logCount = 0
+		setupLogWorking = true
+		go func() {
+			SetupLogger()
+		}()
+	}
+}
+
 func FatalLog(v ...any) {
 	t := time.Now()
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)

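Note: a minimal usage sketch of the new context-aware helpers, assuming the one-api/common and one-api/middleware import paths used elsewhere in this commit; doUpstreamCall and the /example route are hypothetical stand-ins, not part of the project.

package main

import (
	"errors"

	"github.com/gin-gonic/gin"

	"one-api/common"
	"one-api/middleware"
)

// doUpstreamCall is a hypothetical stand-in for real work; it always fails
// so the error path below is exercised.
func doUpstreamCall() error {
	return errors.New("simulated upstream failure")
}

func main() {
	common.SetupLogger() // opens ./logs/oneapi-<date>.log and attaches it to gin's writers

	r := gin.New()
	r.Use(gin.Recovery())
	r.Use(middleware.RequestId()) // stores the id in the gin context and in c.Request.Context()

	r.GET("/example", func(c *gin.Context) {
		ctx := c.Request.Context() // carries common.RequestIdKey; logHelper reads it back
		common.LogInfo(ctx, "starting upstream call")
		if err := doUpstreamCall(); err != nil {
			common.LogError(ctx, "upstream call failed: "+err.Error())
		}
		c.Status(200)
	})

	_ = r.Run(":3000")
}

Because logHelper reads the id with ctx.Value(RequestIdKey), any goroutine that is handed the request context (as the deferred billing goroutines below are) keeps the same id in its log lines.
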
+ 9 - 0
common/utils.go

@@ -171,6 +171,11 @@ func GetTimestamp() int64 {
 	return time.Now().Unix()
 }
 
+func GetTimeString() string {
+	now := time.Now()
+	return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
+}
+
 func Max(a int, b int) int {
 	if a >= b {
 		return a
@@ -190,3 +195,7 @@ func GetOrDefault(env string, defaultValue int) int {
 	}
 	return num
 }
+
+func MessageWithRequestId(message string, id string) string {
+	return fmt.Sprintf("%s (request id: %s)", message, id)
+}

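For illustration only, a small sketch of what these helpers produce when combined the way middleware/request-id.go (added below) combines them; the printed value is a made-up example and changes on every call.

package main

import (
	"fmt"

	"one-api/common"
)

func main() {
	// GetTimeString is a second-resolution timestamp with the sub-second
	// nanoseconds appended; GetRandomString(8), which already exists in
	// common, adds an 8-character suffix. This mirrors how RequestId() builds ids.
	id := common.GetTimeString() + common.GetRandomString(8)
	fmt.Println(common.MessageWithRequestId("无效的请求", id))
	// e.g. 无效的请求 (request id: 20230715123045123456789Ab3xYz9Q)
}
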
+ 1 - 1
controller/billing.go

@@ -29,7 +29,7 @@ func GetSubscription(c *gin.Context) {
 	if err != nil {
 		openAIError := OpenAIError{
 			Message: err.Error(),
-			Type:    "one_api_error",
+			Type:    "upstream_error",
 		}
 		c.JSON(200, gin.H{
 			"error": openAIError,

+ 45 - 12
controller/log.go

@@ -18,19 +18,21 @@ func GetAllLogs(c *gin.Context) {
 	username := c.Query("username")
 	tokenName := c.Query("token_name")
 	modelName := c.Query("model_name")
-	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
+	channel, _ := strconv.Atoi(c.Query("channel"))
+	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
 	if err != nil {
-		c.JSON(200, gin.H{
+		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": err.Error(),
 		})
 		return
 	}
-	c.JSON(200, gin.H{
+	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"message": "",
 		"data":    logs,
 	})
+	return
 }
 
 func GetUserLogs(c *gin.Context) {
@@ -46,34 +48,36 @@ func GetUserLogs(c *gin.Context) {
 	modelName := c.Query("model_name")
 	logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
 	if err != nil {
-		c.JSON(200, gin.H{
+		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": err.Error(),
 		})
 		return
 	}
-	c.JSON(200, gin.H{
+	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"message": "",
 		"data":    logs,
 	})
+	return
 }
 
 func SearchAllLogs(c *gin.Context) {
 	keyword := c.Query("keyword")
 	logs, err := model.SearchAllLogs(keyword)
 	if err != nil {
-		c.JSON(200, gin.H{
+		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": err.Error(),
 		})
 		return
 	}
-	c.JSON(200, gin.H{
+	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"message": "",
 		"data":    logs,
 	})
+	return
 }
 
 func SearchUserLogs(c *gin.Context) {
@@ -81,17 +85,18 @@ func SearchUserLogs(c *gin.Context) {
 	userId := c.GetInt("id")
 	logs, err := model.SearchUserLogs(userId, keyword)
 	if err != nil {
-		c.JSON(200, gin.H{
+		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": err.Error(),
 		})
 		return
 	}
-	c.JSON(200, gin.H{
+	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"message": "",
 		"data":    logs,
 	})
+	return
 }
 
 func GetLogByKey(c *gin.Context) {
@@ -118,9 +123,9 @@ func GetLogsStat(c *gin.Context) {
 	tokenName := c.Query("token_name")
 	username := c.Query("username")
 	modelName := c.Query("model_name")
-	stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
+	stat := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
-	c.JSON(200, gin.H{
+	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"message": "",
 		"data": gin.H{
@@ -129,6 +134,7 @@ func GetLogsStat(c *gin.Context) {
 			"tpm":   stat.Tpm,
 		},
 	})
+	return
 }
 
 func GetLogsSelfStat(c *gin.Context) {
@@ -138,7 +144,8 @@ func GetLogsSelfStat(c *gin.Context) {
 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 	tokenName := c.Query("token_name")
 	modelName := c.Query("model_name")
-	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
+	channel, _ := strconv.Atoi(c.Query("channel"))
+	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 	c.JSON(200, gin.H{
 		"success": true,
@@ -150,4 +157,30 @@ func GetLogsSelfStat(c *gin.Context) {
 			//"token": tokenNum,
 		},
 	})
+	return
+}
+
+func DeleteHistoryLogs(c *gin.Context) {
+	targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
+	if targetTimestamp == 0 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "target timestamp is required",
+		})
+		return
+	}
+	count, err := model.DeleteOldLog(targetTimestamp)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    count,
+	})
+	return
 }

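A hedged sketch of how an admin client might exercise the new cleanup handler (registered as DELETE /api/log/ in router/api-router.go below); the base URL and the Authorization header are placeholders, since how AdminAuth() expects credentials is outside this diff.

package main

import (
	"fmt"
	"io"
	"net/http"
	"time"
)

func main() {
	// Delete consume logs older than 30 days.
	target := time.Now().AddDate(0, 0, -30).Unix()
	url := fmt.Sprintf("http://localhost:3000/api/log/?target_timestamp=%d", target)

	req, err := http.NewRequest(http.MethodDelete, url, nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer <admin-access-token>") // placeholder credentials

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(body)) // expects {"success":true,...,"data":<deleted row count>}
}
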
+ 5 - 3
controller/relay-audio.go

@@ -2,6 +2,7 @@ package controller
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -17,6 +18,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
+	channelId := c.GetInt("channel_id")
 	userId := c.GetInt("id")
 	group := c.GetString("group")
 
@@ -91,7 +93,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	}
 	var audioResponse AudioResponse
 
-	defer func() {
+	defer func(ctx context.Context) {
 		go func() {
 			quota := countTokenText(audioResponse.Text, audioModel)
 			quotaDelta := quota - preConsumedQuota
@@ -106,13 +108,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent, tokenId)
+				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent, tokenId)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)
 			}
 		}()
-	}()
+	}(c.Request.Context())
 
 	responseBody, err := io.ReadAll(resp.Body)
 

+ 5 - 3
controller/relay-image.go

@@ -2,6 +2,7 @@ package controller
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -18,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
+	channelId := c.GetInt("channel_id")
 	userId := c.GetInt("id")
 	consumeQuota := c.GetBool("consume_quota")
 	group := c.GetString("group")
@@ -124,7 +126,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 	}
 	var textResponse ImageResponse
 
-	defer func() {
+	defer func(ctx context.Context) {
 		if consumeQuota {
 			err := model.PostConsumeTokenQuota(tokenId, quota)
 			if err != nil {
@@ -137,13 +139,13 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
+				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)
 			}
 		}
-	}()
+	}(c.Request.Context())
 
 	if consumeQuota {
 		responseBody, err := io.ReadAll(resp.Body)

+ 28 - 26
controller/relay-text.go

@@ -2,6 +2,7 @@ package controller
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -37,6 +38,7 @@ func init() {
 
 func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
+	channelId := c.GetInt("channel_id")
 	tokenId := c.GetInt("token_id")
 	userId := c.GetInt("id")
 	consumeQuota := c.GetBool("consume_quota")
@@ -108,7 +110,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	case common.ChannelTypeAIProxyLibrary:
 		apiType = APITypeAIProxyLibrary
 	}
-
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
 	if c.GetString("base_url") != "" {
@@ -211,6 +212,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		// in this case, we do not pre-consume quota
 		// because the user has enough quota
 		preConsumedQuota = 0
+		common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
 	}
 	if consumeQuota && preConsumedQuota > 0 {
 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
@@ -339,9 +341,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-
 		//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
-		req.Close = true
 		resp, err = httpClient.Do(req)
 		if err != nil {
 			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
@@ -358,13 +358,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 
 		if resp.StatusCode != http.StatusOK {
 			if preConsumedQuota != 0 {
-				go func() {
+				go func(ctx context.Context) {
 					// return pre-consumed quota
 					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
 					if err != nil {
-						common.SysError("error return pre-consumed quota: " + err.Error())
+						common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
 					}
-				}()
+				}(c.Request.Context())
 			}
 			return relayErrorHandler(resp)
 		}
@@ -374,7 +374,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	tokenName := c.GetString("token_name")
 	channelId := c.GetInt("channel_id")
 
-	defer func() {
+	defer func(ctx context.Context) {
 		// c.Writer.Flush()
 		go func() {
 			if consumeQuota {
@@ -397,21 +397,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 				quotaDelta := quota - preConsumedQuota
 				err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 				if err != nil {
-					common.SysError("error consuming token remain quota: " + err.Error())
+					common.LogError(ctx, "error consuming token remain quota: "+err.Error())
 				}
 				err = model.CacheUpdateUserQuota(userId)
 				if err != nil {
-					common.SysError("error update user quota cache: " + err.Error())
+					common.LogError(ctx, "error update user quota cache: "+err.Error())
 				}
 				if quota != 0 {
 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
+					model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 					model.UpdateChannelUsedQuota(channelId, quota)
 				}
 			}
 		}()
-	}()
+	}(c.Request.Context())
 	switch apiType {
 	case APITypeOpenAI:
 		if isStream {
@@ -549,24 +549,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return nil
 		}
 	case APITypeXunfei:
+		auth := c.Request.Header.Get("Authorization")
+		auth = strings.TrimPrefix(auth, "Bearer ")
+		splits := strings.Split(auth, "|")
+		if len(splits) != 3 {
+			return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
+		}
+		var err *OpenAIErrorWithStatusCode
+		var usage *Usage
 		if isStream {
-			auth := c.Request.Header.Get("Authorization")
-			auth = strings.TrimPrefix(auth, "Bearer ")
-			splits := strings.Split(auth, "|")
-			if len(splits) != 3 {
-				return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
-			}
-			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
-			if err != nil {
-				return err
-			}
-			if usage != nil {
-				textResponse.Usage = *usage
-			}
-			return nil
+			err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
 		} else {
-			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
+			err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
+		}
+		if err != nil {
+			return err
+		}
+		if usage != nil {
+			textResponse.Usage = *usage
 		}
+		return nil
 	case APITypeAIProxyLibrary:
 		if isStream {
 			err, usage := aiProxyLibraryStreamHandler(c, resp)

+ 1 - 1
controller/relay-utils.go

@@ -150,7 +150,7 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
 		StatusCode: resp.StatusCode,
 		OpenAIError: OpenAIError{
 			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
-			Type:    "one_api_error",
+			Type:    "upstream_error",
 			Code:    "bad_response_status_code",
 			Param:   strconv.Itoa(resp.StatusCode),
 		},

+ 79 - 66
controller/relay-xunfei.go

@@ -118,6 +118,7 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
 			Role:    "assistant",
 			Content: response.Payload.Choices.Text[0].Content,
 		},
+		FinishReason: stopFinishReason,
 	}
 	fullTextResponse := OpenAITextResponse{
 		Object:  "chat.completion",
@@ -177,33 +178,82 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 }
 
 func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
+	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
+	if err != nil {
+		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
+	}
+	setEventStreamHeaders(c)
 	var usage Usage
-	query := c.Request.URL.Query()
-	apiVersion := query.Get("api-version")
-	if apiVersion == "" {
-		apiVersion = c.GetString("api_version")
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case xunfeiResponse := <-dataChan:
+			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
+			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
+			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
+			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	return nil, &usage
+}
+
+func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
+	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
+	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
+	if err != nil {
+		return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
 	}
-	if apiVersion == "" {
-		apiVersion = "v1.1"
-		common.SysLog("api_version not found, use default: " + apiVersion)
+	var usage Usage
+	var content string
+	var xunfeiResponse XunfeiChatResponse
+	stop := false
+	for !stop {
+		select {
+		case xunfeiResponse = <-dataChan:
+			content += xunfeiResponse.Payload.Choices.Text[0].Content
+			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
+			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
+			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
+		case stop = <-stopChan:
+		}
 	}
-	domain := "general"
-	if apiVersion == "v2.1" {
-		domain = "generalv2"
+
+	xunfeiResponse.Payload.Choices.Text[0].Content = content
+
+	response := responseXunfei2OpenAI(&xunfeiResponse)
+	jsonResponse, err := json.Marshal(response)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
+	c.Writer.Header().Set("Content-Type", "application/json")
+	_, _ = c.Writer.Write(jsonResponse)
+	return nil, &usage
+}
+
+func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
 	d := websocket.Dialer{
 		HandshakeTimeout: 5 * time.Second,
 	}
-	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
+	conn, resp, err := d.Dial(authUrl, nil)
 	if err != nil || resp.StatusCode != 101 {
-		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
+		return nil, nil, err
 	}
 	data := requestOpenAI2Xunfei(textRequest, appId, domain)
 	err = conn.WriteJSON(data)
 	if err != nil {
-		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
+		return nil, nil, err
 	}
+
 	dataChan := make(chan XunfeiChatResponse)
 	stopChan := make(chan bool)
 	go func() {
@@ -230,61 +280,24 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
 		}
 		stopChan <- true
 	}()
-	setEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case xunfeiResponse := <-dataChan:
-			usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
-			usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
-			usage.TotalTokens += xunfeiResponse.Payload.Usage.Text.TotalTokens
-			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
-		}
-	})
-	return nil, &usage
+
+	return dataChan, stopChan, nil
 }
 
-func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
-	var xunfeiResponse XunfeiChatResponse
-	responseBody, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = resp.Body.Close()
-	if err != nil {
-		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
-	}
-	err = json.Unmarshal(responseBody, &xunfeiResponse)
-	if err != nil {
-		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string, string) {
+	query := c.Request.URL.Query()
+	apiVersion := query.Get("api-version")
+	if apiVersion == "" {
+		apiVersion = c.GetString("api_version")
 	}
-	if xunfeiResponse.Header.Code != 0 {
-		return &OpenAIErrorWithStatusCode{
-			OpenAIError: OpenAIError{
-				Message: xunfeiResponse.Header.Message,
-				Type:    "xunfei_error",
-				Param:   "",
-				Code:    xunfeiResponse.Header.Code,
-			},
-			StatusCode: resp.StatusCode,
-		}, nil
+	if apiVersion == "" {
+		apiVersion = "v1.1"
+		common.SysLog("api_version not found, use default: " + apiVersion)
 	}
-	fullTextResponse := responseXunfei2OpenAI(&xunfeiResponse)
-	jsonResponse, err := json.Marshal(fullTextResponse)
-	if err != nil {
-		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	domain := "general"
+	if apiVersion == "v2.1" {
+		domain = "generalv2"
 	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(jsonResponse)
-	return nil, &fullTextResponse.Usage
+	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
+	return domain, authUrl
 }

+ 7 - 1
controller/relay.go

@@ -205,15 +205,20 @@ func Relay(c *gin.Context) {
 		relayMode = RelayModeImagesGenerations
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 		relayMode = RelayModeEdits
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+		relayMode = RelayModeAudio
 	}
 	var err *OpenAIErrorWithStatusCode
 	switch relayMode {
 	case RelayModeImagesGenerations:
 		err = relayImageHelper(c, relayMode)
+	case RelayModeAudio:
+		err = relayAudioHelper(c, relayMode)
 	default:
 		err = relayTextHelper(c, relayMode)
 	}
 	if err != nil {
+		requestId := c.GetString(common.RequestIdKey)
 		retryTimesStr := c.Query("retry")
 		retryTimes, _ := strconv.Atoi(retryTimesStr)
 		if retryTimesStr == "" {
@@ -225,12 +230,13 @@ func Relay(c *gin.Context) {
 			if err.StatusCode == http.StatusTooManyRequests {
 				//err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
 			}
+			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
 			c.JSON(err.StatusCode, gin.H{
 				"error": err.OpenAIError,
 			})
 		}
 		channelId := c.GetInt("channel_id")
-		common.SysError(fmt.Sprintf("relay error (channel #%d): %v ", channelId, err))
+		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 		// https://platform.openai.com/docs/guides/error-codes/api-errors
 		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
 			channelId := c.GetInt("channel_id")

+ 5 - 4
main.go

@@ -21,7 +21,7 @@ var buildFS embed.FS
 var indexPage []byte
 
 func main() {
-	common.SetupGinLog()
+	common.SetupLogger()
 	common.SysLog("One API " + common.Version + " started")
 	if os.Getenv("GIN_MODE") != "debug" {
 		gin.SetMode(gin.ReleaseMode)
@@ -86,11 +86,12 @@ func main() {
 	controller.InitTokenEncoders()
 
 	// Initialize HTTP server
-	server := gin.Default()
+	server := gin.New()
+	server.Use(gin.Recovery())
 	// This will cause SSE not to work!!!
 	//server.Use(gzip.Gzip(gzip.DefaultCompression))
-	server.Use(middleware.CORS())
-
+	server.Use(middleware.RequestId())
+	middleware.SetUpLogger(server)
 	// Initialize session store
 	store := cookie.NewStore([]byte(common.SessionSecret))
 	server.Use(sessions.Sessions("session", store))

+ 4 - 28
middleware/auth.go

@@ -100,34 +100,16 @@ func TokenAuth() func(c *gin.Context) {
 		}
 		token, err := model.ValidateUserToken(key)
 		if err != nil {
-			c.JSON(http.StatusUnauthorized, gin.H{
-				"error": gin.H{
-					"message": err.Error(),
-					"type":    "one_api_error",
-				},
-			})
-			c.Abort()
+			abortWithMessage(c, http.StatusUnauthorized, err.Error())
 			return
 		}
 		userEnabled, err := model.IsUserEnabled(token.UserId)
 		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{
-				"error": gin.H{
-					"message": err.Error(),
-					"type":    "one_api_error",
-				},
-			})
-			c.Abort()
+			abortWithMessage(c, http.StatusInternalServerError, err.Error())
 			return
 		}
 		if !userEnabled {
-			c.JSON(http.StatusForbidden, gin.H{
-				"error": gin.H{
-					"message": "用户已被封禁",
-					"type":    "one_api_error",
-				},
-			})
-			c.Abort()
+			abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
 			return
 		}
 		c.Set("id", token.UserId)
@@ -143,13 +125,7 @@ func TokenAuth() func(c *gin.Context) {
 			if model.IsAdmin(token.UserId) {
 				c.Set("channelId", parts[1])
 			} else {
-				c.JSON(http.StatusForbidden, gin.H{
-					"error": gin.H{
-						"message": "普通用户不支持指定渠道",
-						"type":    "one_api_error",
-					},
-				})
-				c.Abort()
+				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 				return
 			}
 		}

+ 13 - 39
middleware/distributor.go

@@ -2,7 +2,6 @@ package middleware
 
 import (
 	"fmt"
-	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
@@ -26,39 +25,22 @@ func Distribute() func(c *gin.Context) {
 		if ok {
 			id, err := strconv.Atoi(channelId.(string))
 			if err != nil {
-				c.JSON(http.StatusBadRequest, gin.H{
-					"error": gin.H{
-						"message": "无效的渠道 ID",
-						"type":    "one_api_error",
-					},
-				})
-				c.Abort()
+				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
 				return
 			}
 			channel, err = model.GetChannelById(id, true)
 			if err != nil {
-				c.JSON(http.StatusBadRequest, gin.H{
-					"error": gin.H{
-						"message": "无效的渠道 ID",
-						"type":    "one_api_error",
-					},
-				})
-				c.Abort()
+				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
 				return
 			}
 			if channel.Status != common.ChannelStatusEnabled {
-				c.JSON(http.StatusForbidden, gin.H{
-					"error": gin.H{
-						"message": "该渠道已被禁用",
-						"type":    "one_api_error",
-					},
-				})
-				c.Abort()
+				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
 				return
 			}
 		} else {
 			// Select a channel for the user
 			var modelRequest ModelRequest
+			var err error
 			if strings.HasPrefix(c.Request.URL.Path, "/mj") {
 				// Midjourney
 				if modelRequest.Model == "" {
@@ -67,17 +49,17 @@ func Distribute() func(c *gin.Context) {
 			} else {
 				err := common.UnmarshalBodyReusable(c, &modelRequest)
 				if err != nil {
-					log.Println(err)
-					c.JSON(http.StatusBadRequest, gin.H{
-						"error": gin.H{
-							"message": "无效的请求",
-							"type":    "one_api_error",
-						},
-					})
-					c.Abort()
+					abortWithMessage(c, http.StatusBadRequest, "无效的请求")
 					return
 				}
 			}
+			if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+				err = common.UnmarshalBodyReusable(c, &modelRequest)
+			}
+			if err != nil {
+				abortWithMessage(c, http.StatusBadRequest, "无效的请求")
+				return
+			}
 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 				if modelRequest.Model == "" {
 					modelRequest.Model = "text-moderation-stable"
@@ -98,7 +80,6 @@ func Distribute() func(c *gin.Context) {
 					modelRequest.Model = "whisper-1"
 				}
 			}
-			var err error
 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 			if err != nil {
 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
@@ -106,17 +87,10 @@ func Distribute() func(c *gin.Context) {
 					common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
 					message = "数据库一致性已被破坏,请联系管理员"
 				}
-				c.JSON(http.StatusServiceUnavailable, gin.H{
-					"error": gin.H{
-						"message": message,
-						"type":    "one_api_error",
-					},
-				})
-				c.Abort()
+				abortWithMessage(c, http.StatusServiceUnavailable, message)
 				return
 			}
 		}
-		//log.Printf("Using channel %v", channel)
 		c.Set("channel", channel.Type)
 		c.Set("channel_id", channel.Id)
 		c.Set("channel_name", channel.Name)

+ 25 - 0
middleware/logger.go

@@ -0,0 +1,25 @@
+package middleware
+
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+)
+
+func SetUpLogger(server *gin.Engine) {
+	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
+		var requestID string
+		if param.Keys != nil {
+			requestID = param.Keys[common.RequestIdKey].(string)
+		}
+		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
+			param.TimeStamp.Format("2006/01/02 - 15:04:05"),
+			requestID,
+			param.StatusCode,
+			param.Latency,
+			param.ClientIP,
+			param.Method,
+			param.Path,
+		)
+	}))
+}

+ 18 - 0
middleware/request-id.go

@@ -0,0 +1,18 @@
+package middleware
+
+import (
+	"context"
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+)
+
+func RequestId() func(c *gin.Context) {
+	return func(c *gin.Context) {
+		id := common.GetTimeString() + common.GetRandomString(8)
+		c.Set(common.RequestIdKey, id)
+		ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id)
+		c.Request = c.Request.WithContext(ctx)
+		c.Header(common.RequestIdKey, id)
+		c.Next()
+	}
+}

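A short sketch of the two ways downstream code reads the id back, matching what controller/relay.go and common.logHelper do; the /whoami route is hypothetical.

package main

import (
	"fmt"

	"github.com/gin-gonic/gin"

	"one-api/common"
	"one-api/middleware"
)

func main() {
	r := gin.New()
	r.Use(middleware.RequestId())

	// The id can be read from the gin key store (c.GetString, as relay.go does)
	// or from the request context (ctx.Value, as logHelper does). The same
	// value is also returned to the caller in the X-Oneapi-Request-Id header.
	r.GET("/whoami", func(c *gin.Context) {
		fromGin := c.GetString(common.RequestIdKey)
		fromCtx, _ := c.Request.Context().Value(common.RequestIdKey).(string)
		c.String(200, fmt.Sprintf("gin=%s ctx=%s", fromGin, fromCtx))
	})

	_ = r.Run(":3000")
}
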
+ 17 - 0
middleware/utils.go

@@ -0,0 +1,17 @@
+package middleware
+
+import (
+	"github.com/gin-gonic/gin"
+	"one-api/common"
+)
+
+func abortWithMessage(c *gin.Context, statusCode int, message string) {
+	c.JSON(statusCode, gin.H{
+		"error": gin.H{
+			"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
+			"type":    "one_api_error",
+		},
+	})
+	c.Abort()
+	common.LogError(c.Request.Context(), message)
+}

+ 4 - 2
model/ability.go

@@ -10,15 +10,16 @@ type Ability struct {
 	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"`
 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
 	Enabled   bool   `json:"enabled"`
+	Priority  int64  `json:"priority" gorm:"bigint;default:0"`
 }
 
 func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 	ability := Ability{}
 	var err error = nil
 	if common.UsingSQLite {
-		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
+		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error
 	} else {
-		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
+		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error
 	}
 	if err != nil {
 		return nil, err
@@ -40,6 +41,7 @@ func (channel *Channel) AddAbilities() error {
 				Model:     model,
 				ChannelId: channel.Id,
 				Enabled:   channel.Status == common.ChannelStatusEnabled,
+				Priority:  channel.Priority,
 			}
 			abilities = append(abilities, ability)
 		}

+ 17 - 0
model/cache.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"math/rand"
 	"one-api/common"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -159,6 +160,17 @@ func InitChannelCache() {
 			}
 		}
 	}
+
+	// sort by priority
+	for group, model2channels := range newGroup2model2channels {
+		for model, channels := range model2channels {
+			sort.Slice(channels, func(i, j int) bool {
+				return channels[i].Priority > channels[j].Priority
+			})
+			newGroup2model2channels[group][model] = channels
+		}
+	}
+
 	channelSyncLock.Lock()
 	group2model2channels = newGroup2model2channels
 	channelSyncLock.Unlock()
@@ -183,6 +195,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 	if len(channels) == 0 {
 		return nil, errors.New("channel not found")
 	}
+	// choose by priority
+	firstChannel := channels[0]
+	if firstChannel.Priority > 0 {
+		return firstChannel, nil
+	}
 	idx := rand.Intn(len(channels))
 	return channels[idx], nil
 }

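A hedged sketch of the selection rule this adds to the cache path: channels for a (group, model) pair are sorted by descending priority, the top channel is always chosen when its priority is non-zero, and random selection only applies when the best priority is 0. The channel values are illustrative; the real struct lives in model/channel.go.

package main

import (
	"fmt"
	"math/rand"
	"sort"
)

// channel mirrors only the fields relevant to selection.
type channel struct {
	Id       int
	Priority int64
}

// pick reproduces the post-sort behaviour of CacheGetRandomSatisfiedChannel.
func pick(channels []*channel) *channel {
	sort.Slice(channels, func(i, j int) bool {
		return channels[i].Priority > channels[j].Priority
	})
	if channels[0].Priority > 0 {
		// A non-zero top priority pins selection to the first channel.
		return channels[0]
	}
	return channels[rand.Intn(len(channels))]
}

func main() {
	chs := []*channel{{Id: 1, Priority: 0}, {Id: 2, Priority: 10}, {Id: 3, Priority: 5}}
	fmt.Println("picked channel", pick(chs).Id) // always 2 with these priorities
}

As written in this commit, several channels sharing the same non-zero top priority still resolve to whichever one sorts first rather than being picked randomly among themselves.
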
+ 1 - 0
model/channel.go

@@ -24,6 +24,7 @@ type Channel struct {
 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
 	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
+	Priority           int64   `json:"priority" gorm:"bigint;default:0"`
 }
 
 func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {

+ 21 - 4
model/log.go

@@ -1,6 +1,8 @@
 package model
 
 import (
+	"context"
+	"fmt"
 	"gorm.io/gorm"
 	"one-api/common"
 	"strings"
@@ -19,6 +21,7 @@ type Log struct {
 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"`
 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"`
 	TokenId          int    `json:"token_id" gorm:"default:0;index"`
+	Channel          int    `json:"channel" gorm:"default:0"`
 }
 
 const (
@@ -51,7 +54,8 @@ func RecordLog(userId int, logType int, content string) {
 	}
 }
 
-func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int) {
+func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, tokenId int) {
+	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 	if !common.LogConsumeEnabled {
 		return
 	}
@@ -66,15 +70,16 @@ func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelN
 		TokenName:        tokenName,
 		ModelName:        modelName,
 		Quota:            quota,
+		Channel:          channelId,
 		TokenId:          tokenId,
 	}
 	err := DB.Create(log).Error
 	if err != nil {
-		common.SysError("failed to record log: " + err.Error())
+		common.LogError(ctx, "failed to record log: "+err.Error())
 	}
 }
 
-func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
+func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
 	var tx *gorm.DB
 	if logType == LogTypeUnknown {
 		tx = DB
@@ -96,6 +101,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 	if endTimestamp != 0 {
 		tx = tx.Where("created_at <= ?", endTimestamp)
 	}
+	if channel != 0 {
+		tx = tx.Where("channel = ?", channel)
+	}
 	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
 	return logs, err
 }
@@ -139,7 +147,7 @@ type Stat struct {
 	Tpm   int `json:"tpm"`
 }
 
-func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (stat Stat) {
+func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
 	tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
 	if username != "" {
 		tx = tx.Where("username = ?", username)
@@ -156,6 +164,10 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 	if modelName != "" {
 		tx = tx.Where("model_name = ?", modelName)
 	}
+
+	if channel != 0 {
+		tx = tx.Where("channel = ?", channel)
+	}
 	tx.Where("type = ?", LogTypeConsume).Scan(&stat)
 	return stat
 }
@@ -180,3 +192,8 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
 	tx.Where("type = ?", LogTypeConsume).Scan(&token)
 	return token
 }
+
+func DeleteOldLog(targetTimestamp int64) (int64, error) {
+	result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
+	return result.RowsAffected, result.Error
+}

+ 1 - 0
router/api-router.go

@@ -103,6 +103,7 @@ func SetApiRouter(router *gin.Engine) {
 		}
 		logRoute := apiRouter.Group("/log")
 		logRoute.GET("/", middleware.AdminAuth(), controller.GetAllLogs)
+		logRoute.DELETE("/", middleware.AdminAuth(), controller.DeleteHistoryLogs)
 		logRoute.GET("/stat", middleware.AdminAuth(), controller.GetLogsStat)
 		logRoute.GET("/self/stat", middleware.UserAuth(), controller.GetLogsSelfStat)
 		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)

+ 1 - 0
router/relay-router.go

@@ -8,6 +8,7 @@ import (
 )
 
 func SetRelayRouter(router *gin.Engine) {
+	router.Use(middleware.CORS())
 	// https://platform.openai.com/docs/api-reference/introduction
 	modelsRouter := router.Group("/v1/models")
 	modelsRouter.Use(middleware.TokenAuth())

+ 37 - 5
web/src/components/ChannelsTable.js

@@ -1,7 +1,7 @@
 import React, { useEffect, useState } from 'react';
-import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
+import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
 import { Link } from 'react-router-dom';
-import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
+import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
 
 import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
 import {renderGroup, renderNumber, renderQuota} from '../helpers/render';
@@ -24,7 +24,7 @@ function renderType(type) {
     }
     type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
   }
-  return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
+  return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>;
 }
 
 function renderBalance(type, balance) {
@@ -96,7 +96,7 @@ const ChannelsTable = () => {
       });
   }, []);
 
-  const manageChannel = async (id, action, idx) => {
+  const manageChannel = async (id, action, idx, priority) => {
     let data = { id };
     let res;
     switch (action) {
@@ -111,6 +111,13 @@ const ChannelsTable = () => {
         data.status = 2;
         res = await API.put('/api/channel/', data);
         break;
+      case 'priority':
+        if (priority === '') {
+          return;
+        }
+        data.priority = parseInt(priority);
+        res = await API.put('/api/channel/', data);
+        break;
     }
     const { success, message } = res.data;
     if (success) {
@@ -195,6 +202,7 @@ const ChannelsTable = () => {
       showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
     } else {
       showError(message);
+      showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。")
     }
   };
 
@@ -346,6 +354,14 @@ const ChannelsTable = () => {
             >
               余额
             </Table.HeaderCell>
+            <Table.HeaderCell
+                style={{ cursor: 'pointer' }}
+                onClick={() => {
+                  sortChannel('priority');
+                }}
+            >
+              优先级
+            </Table.HeaderCell>
             <Table.HeaderCell>操作</Table.HeaderCell>
           </Table.Row>
         </Table.Header>
@@ -385,6 +401,22 @@ const ChannelsTable = () => {
                       basic
                     />
                   </Table.Cell>
+                  <Table.Cell>
+                    <Popup
+                        trigger={<Input type="number"  defaultValue={channel.priority} onBlur={(event) => {
+                          manageChannel(
+                              channel.id,
+                              'priority',
+                              idx,
+                              event.target.value,
+                          );
+                        }}>
+                          <input style={{maxWidth:'60px'}} />
+                        </Input>}
+                        content='渠道选择优先级,越高越优先'
+                        basic
+                    />
+                  </Table.Cell>
                   <Table.Cell>
                     <div>
                       <Button
@@ -453,7 +485,7 @@ const ChannelsTable = () => {
 
         <Table.Footer>
           <Table.Row>
-            <Table.HeaderCell colSpan='8'>
+            <Table.HeaderCell colSpan='9'>
               <Button size='small' as={Link} to='/channel/add' loading={loading}>
                 添加新的渠道
               </Button>

+ 40 - 17
web/src/components/LogsTable.js

@@ -56,9 +56,10 @@ const LogsTable = () => {
     token_name: '',
     model_name: '',
     start_timestamp: timestamp2string(0),
-    end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
+    end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
+    channel: ''
   });
-  const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
+  const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
 
   const [stat, setStat] = useState({
     quota: 0,
@@ -84,7 +85,7 @@ const LogsTable = () => {
   const getLogStat = async () => {
     let localStartTimestamp = Date.parse(start_timestamp) / 1000;
     let localEndTimestamp = Date.parse(end_timestamp) / 1000;
-    let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
+    let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
     const { success, message, data } = res.data;
     if (success) {
       setStat(data);
@@ -109,7 +110,7 @@ const LogsTable = () => {
     let localStartTimestamp = Date.parse(start_timestamp) / 1000;
     let localEndTimestamp = Date.parse(end_timestamp) / 1000;
     if (isAdminUser) {
-      url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
+      url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
     } else {
       url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
     }
@@ -205,16 +206,9 @@ const LogsTable = () => {
         </Header>
         <Form>
           <Form.Group>
-            {
-              isAdminUser && (
-                <Form.Input fluid label={'用户名称'} width={2} value={username}
-                            placeholder={'可选值'} name='username'
-                            onChange={handleInputChange} />
-              )
-            }
-            <Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name}
+            <Form.Input fluid label={'令牌名称'} width={3} value={token_name}
                         placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
-            <Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值'
+            <Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
                         name='model_name'
                         onChange={handleInputChange} />
             <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
@@ -225,6 +219,19 @@ const LogsTable = () => {
                         onChange={handleInputChange} />
             <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
           </Form.Group>
+          {
+            isAdminUser && <>
+              <Form.Group>
+                <Form.Input fluid label={'渠道 ID'} width={3} value={channel}
+                            placeholder='可选值' name='channel'
+                            onChange={handleInputChange} />
+                <Form.Input fluid label={'用户名称'} width={3} value={username}
+                            placeholder={'可选值'} name='username'
+                            onChange={handleInputChange} />
+
+              </Form.Group>
+            </>
+          }
         </Form>
         <Table basic compact size='small'>
           <Table.Header>
@@ -238,6 +245,17 @@ const LogsTable = () => {
               >
                 时间
               </Table.HeaderCell>
+              {
+                isAdminUser && <Table.HeaderCell
+                  style={{ cursor: 'pointer' }}
+                  onClick={() => {
+                    sortLog('channel');
+                  }}
+                  width={1}
+                >
+                  渠道
+                </Table.HeaderCell>
+              }
               {
                 isAdminUser && <Table.HeaderCell
                   style={{ cursor: 'pointer' }}
@@ -299,16 +317,16 @@ const LogsTable = () => {
                 onClick={() => {
                   sortLog('quota');
                 }}
-                width={2}
+                width={1}
               >
-                消耗额度
+                额度
               </Table.HeaderCell>
               <Table.HeaderCell
                 style={{ cursor: 'pointer' }}
                 onClick={() => {
                   sortLog('content');
                 }}
-                width={isAdminUser ? 4 : 5}
+                width={isAdminUser ? 4 : 6}
               >
                 详情
               </Table.HeaderCell>
@@ -326,6 +344,11 @@ const LogsTable = () => {
                 return (
                   <Table.Row key={log.id}>
                     <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
+                    {
+                      isAdminUser && (
+                        <Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
+                      )
+                    }
                     {
                       isAdminUser && (
                         <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
@@ -345,7 +368,7 @@ const LogsTable = () => {
 
           <Table.Footer>
             <Table.Row>
-              <Table.HeaderCell colSpan={'9'}>
+              <Table.HeaderCell colSpan={'10'}>
                 <Select
                   placeholder='选择明细分类'
                   options={LOG_OPTIONS}

+ 108 - 91
web/src/components/OperationSetting.js

@@ -1,9 +1,9 @@
 import React, {useEffect, useState} from 'react';
 import {Divider, Form, Grid, Header} from 'semantic-ui-react';
-import {API, showError, verifyJSON} from '../helpers';
+import {API, showError, showSuccess, timestamp2string, verifyJSON} from '../helpers';
 
 const OperationSetting = () => {
-    let [inputs, setInputs] = useState({
+    let now = new Date();let [inputs, setInputs] = useState({
         QuotaForNewUser: 0,
         QuotaForInviter: 0,
         QuotaForInvitee: 0,
@@ -20,28 +20,28 @@ const OperationSetting = () => {
         DisplayInCurrencyEnabled: '',
         DisplayTokenStatEnabled: '',
         ApproximateTokenEnabled: '',
-        RetryTimes: 0,
+        RetryTimes: 0
     });
     const [originInputs, setOriginInputs] = useState({});
-    let [loading, setLoading] = useState(false);
+    let [loading, setLoading] = useState(false);let [historyTimestamp, setHistoryTimestamp] = useState(timestamp2string(now.getTime() / 1000 - 30 * 24 * 3600)); // a month ago
 
-    const getOptions = async () => {
-        const res = await API.get('/api/option/');
-        const {success, message, data} = res.data;
-        if (success) {
-            let newInputs = {};
-            data.forEach((item) => {
-                if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
-                    item.value = JSON.stringify(JSON.parse(item.value), null, 2);
-                }
-                newInputs[item.key] = item.value;
-            });
-            setInputs(newInputs);
-            setOriginInputs(newInputs);
-        } else {
-            showError(message);
+  const getOptions = async () => {
+    const res = await API.get('/api/option/');
+    const { success, message, data } = res.data;
+    if (success) {
+      let newInputs = {};
+      data.forEach((item) => {
+        if (item.key === 'ModelRatio' || item.key === 'GroupRatio') {
+          item.value = JSON.stringify(JSON.parse(item.value), null, 2);
         }
-    };
+        newInputs[item.key] = item.value;
+      });
+      setInputs(newInputs);
+      setOriginInputs(newInputs);
+    } else {
+      showError(message);
+    }
+  };
 
     useEffect(() => {
         getOptions().then();
@@ -73,72 +73,73 @@ const OperationSetting = () => {
         }
     };
 
-    const submitConfig = async (group) => {
-        switch (group) {
-            case 'monitor':
-                if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
-                    await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
-                }
-                if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
-                    await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
-                }
-                break;
-            case 'stable':
-                await updateOption('StablePrice', inputs.StablePrice);
-                await updateOption('NormalPrice', inputs.NormalPrice);
-                await updateOption('BasePrice', inputs.BasePrice);
-                localStorage.setItem('stable_price', inputs.StablePrice);
-                localStorage.setItem('normal_price', inputs.NormalPrice);
-                localStorage.setItem('base_price', inputs.BasePrice);
-                break;
-            case 'ratio':
-                if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
-                    if (!verifyJSON(inputs.ModelRatio)) {
-                        showError('模型倍率不是合法的 JSON 字符串');
-                        return;
-                    }
-                    await updateOption('ModelRatio', inputs.ModelRatio);
-                }
-                if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
-                    if (!verifyJSON(inputs.GroupRatio)) {
-                        showError('分组倍率不是合法的 JSON 字符串');
-                        return;
-                    }
-                    await updateOption('GroupRatio', inputs.GroupRatio);
-                }
-                break;
-            case 'quota':
-                if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
-                    await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
-                }
-                if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
-                    await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
-                }
-                if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
-                    await updateOption('QuotaForInviter', inputs.QuotaForInviter);
-                }
-                if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
-                    await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
-                }
-                break;
-            case 'general':
-                if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
-                    await updateOption('TopUpLink', inputs.TopUpLink);
-                }
-                if (originInputs['ChatLink'] !== inputs.ChatLink) {
-                    await updateOption('ChatLink', inputs.ChatLink);
-                }
-                if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
-                    await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
-                }
-                if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
-                    await updateOption('RetryTimes', inputs.RetryTimes);
-                }
-                break;
+  const submitConfig = async (group) => {
+    switch (group) {
+      case 'monitor':
+        if (originInputs['ChannelDisableThreshold'] !== inputs.ChannelDisableThreshold) {
+          await updateOption('ChannelDisableThreshold', inputs.ChannelDisableThreshold);
         }
-    };
+        if (originInputs['QuotaRemindThreshold'] !== inputs.QuotaRemindThreshold) {
+          await updateOption('QuotaRemindThreshold', inputs.QuotaRemindThreshold);
+        }
+        break;
+      case 'ratio':
+        if (originInputs['ModelRatio'] !== inputs.ModelRatio) {
+          if (!verifyJSON(inputs.ModelRatio)) {
+            showError('模型倍率不是合法的 JSON 字符串');
+            return;
+          }
+          await updateOption('ModelRatio', inputs.ModelRatio);
+        }
+        if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
+          if (!verifyJSON(inputs.GroupRatio)) {
+            showError('分组倍率不是合法的 JSON 字符串');
+            return;
+          }
+          await updateOption('GroupRatio', inputs.GroupRatio);
+        }
+        break;
+      case 'quota':
+        if (originInputs['QuotaForNewUser'] !== inputs.QuotaForNewUser) {
+          await updateOption('QuotaForNewUser', inputs.QuotaForNewUser);
+        }
+        if (originInputs['QuotaForInvitee'] !== inputs.QuotaForInvitee) {
+          await updateOption('QuotaForInvitee', inputs.QuotaForInvitee);
+        }
+        if (originInputs['QuotaForInviter'] !== inputs.QuotaForInviter) {
+          await updateOption('QuotaForInviter', inputs.QuotaForInviter);
+        }
+        if (originInputs['PreConsumedQuota'] !== inputs.PreConsumedQuota) {
+          await updateOption('PreConsumedQuota', inputs.PreConsumedQuota);
+        }
+        break;
+      case 'general':
+        if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
+          await updateOption('TopUpLink', inputs.TopUpLink);
+        }
+        if (originInputs['ChatLink'] !== inputs.ChatLink) {
+          await updateOption('ChatLink', inputs.ChatLink);
+        }
+        if (originInputs['QuotaPerUnit'] !== inputs.QuotaPerUnit) {
+          await updateOption('QuotaPerUnit', inputs.QuotaPerUnit);
+        }
+        if (originInputs['RetryTimes'] !== inputs.RetryTimes) {
+          await updateOption('RetryTimes', inputs.RetryTimes);
+        }
+        break;
+    }
+  };
 
-    return (
+    const deleteHistoryLogs = async () => {
+    console.log(inputs);
+    const res = await API.delete(`/api/log/?target_timestamp=${Date.parse(historyTimestamp) / 1000}`);
+    const { success, message, data } = res.data;
+    if (success) {
+      showSuccess(`${data} 条日志已清理!`);
+      return;
+    }
+    showError('日志清理失败:' + message);
+  };return (
         <Grid columns={1}>
             <Grid.Column>
                 <Form loading={loading}>
@@ -187,12 +188,7 @@ const OperationSetting = () => {
                         />
                     </Form.Group>
                     <Form.Group inline>
-                        <Form.Checkbox
-                            checked={inputs.LogConsumeEnabled === 'true'}
-                            label='启用额度消费日志记录'
-                            name='LogConsumeEnabled'
-                            onChange={handleInputChange}
-                        />
+
                         <Form.Checkbox
                             checked={inputs.DisplayInCurrencyEnabled === 'true'}
                             label='以货币形式显示额度'
@@ -214,7 +210,28 @@ const OperationSetting = () => {
                     </Form.Group>
                     <Form.Button onClick={() => {
                         submitConfig('general').then();
-                    }}>保存通用设置</Form.Button>
+                    }}>保存通用设置</Form.Button><Divider />
+          <Header as='h3'>
+            日志设置
+          </Header>
+          <Form.Group inline>
+            <Form.Checkbox
+              checked={inputs.LogConsumeEnabled === 'true'}
+              label='启用额度消费日志记录'
+              name='LogConsumeEnabled'
+              onChange={handleInputChange}
+            />
+          </Form.Group>
+          <Form.Group widths={4}>
+            <Form.Input label='目标时间' value={historyTimestamp} type='datetime-local'
+                        name='history_timestamp'
+                        onChange={(e, { name, value }) => {
+                          setHistoryTimestamp(value);
+                        }} />
+          </Form.Group>
+          <Form.Button onClick={() => {
+            deleteHistoryLogs().then();
+          }}>清理历史日志</Form.Button>
                     <Divider/>
                     <Header as='h3'>
                         监控设置

+ 1 - 1
web/src/components/TokensTable.js

@@ -96,7 +96,7 @@ const TokensTable = () => {
     let nextUrl;
 
     if (nextLink) {
-      nextUrl = nextLink + `/#/?settings={"key":"sk-${key}"}`;
+      nextUrl = nextLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
     } else {
       nextUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
     }