Browse Source

Merge remote-tracking branch 'origin/main'

# Conflicts:
#	controller/relay.go
#	main.go
#	middleware/distributor.go
CaIon 2 years ago
parent
commit
9c08d78349

+ 14 - 1
README.md

@@ -68,6 +68,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
+   + [x] [360 智脑](https://ai.360.cn)
 2. 支持配置镜像以及众多第三方代理服务:
    + [x] [OpenAI-SB](https://openai-sb.com)
    + [x] [API2D](https://api2d.com/r/197971)
@@ -108,6 +109,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 
 数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
 
+如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
+
 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
 
 如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
@@ -274,8 +277,9 @@ graph LR
 不加的话将会使用负载均衡的方式使用多个渠道。
 
 ### 环境变量
-1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储
+1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用
    + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
+   + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。
 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
    + 例子:`SESSION_SECRET=random_string`
 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。
@@ -302,6 +306,14 @@ graph LR
    + 例子:`CHANNEL_TEST_FREQUENCY=1440`
 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
    + 例子:`POLLING_INTERVAL=5`
+10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+    + 例子:`BATCH_UPDATE_ENABLED=true`
+    + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
+11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+    + 例子:`BATCH_UPDATE_INTERVAL=5`
+12. 请求频率限制:
+    + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+    + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
 
 ### 命令行参数
 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
@@ -338,6 +350,7 @@ https://openai.justsong.cn
 5. ChatGPT Next Web 报错:`Failed to fetch`
    + 部署的时候不要设置 `BASE_URL`。
    + 检查你的接口地址和 API Key 有没有填对。
+   + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
 6. 报错:`当前分组负载已饱和,请稍后再试`
    + 上游通道 429 了。
 

+ 51 - 40
common/constants.go

@@ -98,6 +98,9 @@ var RequestInterval = time.Duration(requestInterval) * time.Second
 
 var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
 
+var BatchUpdateEnabled = false
+var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
+
 const (
 	RoleGuestUser  = 0
 	RoleCommonUser = 1
@@ -115,10 +118,10 @@ var (
 // All duration's unit is seconds
 // Shouldn't larger then RateLimitKeyExpirationDuration
 var (
-	GlobalApiRateLimitNum            = 180
+	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
 	GlobalApiRateLimitDuration int64 = 3 * 60
 
-	GlobalWebRateLimitNum            = 60
+	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
 	GlobalWebRateLimitDuration int64 = 3 * 60
 
 	UploadRateLimitNum            = 10
@@ -158,45 +161,53 @@ const (
 )
 
 const (
-	ChannelTypeUnknown   = 0
-	ChannelTypeOpenAI    = 1
-	ChannelTypeAPI2D     = 2
-	ChannelTypeAzure     = 3
-	ChannelTypeCloseAI   = 4
-	ChannelTypeOpenAISB  = 5
-	ChannelTypeOpenAIMax = 6
-	ChannelTypeOhMyGPT   = 7
-	ChannelTypeCustom    = 8
-	ChannelTypeAILS      = 9
-	ChannelTypeAIProxy   = 10
-	ChannelTypePaLM      = 11
-	ChannelTypeAPI2GPT   = 12
-	ChannelTypeAIGC2D    = 13
-	ChannelTypeAnthropic = 14
-	ChannelTypeBaidu     = 15
-	ChannelTypeZhipu     = 16
-	ChannelTypeAli       = 17
-	ChannelTypeXunfei    = 18
+	ChannelTypeUnknown        = 0
+	ChannelTypeOpenAI         = 1
+	ChannelTypeAPI2D          = 2
+	ChannelTypeAzure          = 3
+	ChannelTypeCloseAI        = 4
+	ChannelTypeOpenAISB       = 5
+	ChannelTypeOpenAIMax      = 6
+	ChannelTypeOhMyGPT        = 7
+	ChannelTypeCustom         = 8
+	ChannelTypeAILS           = 9
+	ChannelTypeAIProxy        = 10
+	ChannelTypePaLM           = 11
+	ChannelTypeAPI2GPT        = 12
+	ChannelTypeAIGC2D         = 13
+	ChannelTypeAnthropic      = 14
+	ChannelTypeBaidu          = 15
+	ChannelTypeZhipu          = 16
+	ChannelTypeAli            = 17
+	ChannelTypeXunfei         = 18
+	ChannelType360            = 19
+	ChannelTypeOpenRouter     = 20
+	ChannelTypeAIProxyLibrary = 21
+	ChannelTypeFastGPT        = 22
 )
 
 var ChannelBaseURLs = []string{
-	"",                               // 0
-	"https://api.openai.com",         // 1
-	"https://oa.api2d.net",           // 2
-	"",                               // 3
-	"https://api.closeai-proxy.xyz",  // 4
-	"https://api.openai-sb.com",      // 5
-	"https://api.openaimax.com",      // 6
-	"https://api.ohmygpt.com",        // 7
-	"",                               // 8
-	"https://api.caipacity.com",      // 9
-	"https://api.aiproxy.io",         // 10
-	"",                               // 11
-	"https://api.api2gpt.com",        // 12
-	"https://api.aigc2d.com",         // 13
-	"https://api.anthropic.com",      // 14
-	"https://aip.baidubce.com",       // 15
-	"https://open.bigmodel.cn",       // 16
-	"https://dashscope.aliyuncs.com", // 17
-	"",                               // 18
+	"",                                // 0
+	"https://api.openai.com",          // 1
+	"https://oa.api2d.net",            // 2
+	"",                                // 3
+	"https://api.closeai-proxy.xyz",   // 4
+	"https://api.openai-sb.com",       // 5
+	"https://api.openaimax.com",       // 6
+	"https://api.ohmygpt.com",         // 7
+	"",                                // 8
+	"https://api.caipacity.com",       // 9
+	"https://api.aiproxy.io",          // 10
+	"",                                // 11
+	"https://api.api2gpt.com",         // 12
+	"https://api.aigc2d.com",          // 13
+	"https://api.anthropic.com",       // 14
+	"https://aip.baidubce.com",        // 15
+	"https://open.bigmodel.cn",        // 16
+	"https://dashscope.aliyuncs.com",  // 17
+	"",                                // 18
+	"https://ai.360.cn",               // 19
+	"https://openrouter.ai/api",       // 20
+	"https://api.aiproxy.io",          // 21
+	"https://fastgpt.run/api/openapi", // 22
 }

+ 46 - 40
common/model-ratio.go

@@ -13,46 +13,52 @@ import (
 // 1 === $0.002 / 1K tokens
 // 1 === ¥0.014 / 1k tokens
 var ModelRatio = map[string]float64{
-	"gpt-4":                   15,
-	"gpt-4-0314":              15,
-	"gpt-4-0613":              15,
-	"gpt-4-32k":               30,
-	"gpt-4-32k-0314":          30,
-	"gpt-4-32k-0613":          30,
-	"gpt-3.5-turbo":           0.75, // $0.0015 / 1K tokens
-	"gpt-3.5-turbo-0301":      0.75,
-	"gpt-3.5-turbo-0613":      0.75,
-	"gpt-3.5-turbo-16k":       1.5, // $0.003 / 1K tokens
-	"gpt-3.5-turbo-16k-0613":  1.5,
-	"text-ada-001":            0.2,
-	"text-babbage-001":        0.25,
-	"text-curie-001":          1,
-	"text-davinci-002":        10,
-	"text-davinci-003":        10,
-	"text-davinci-edit-001":   10,
-	"code-davinci-edit-001":   10,
-	"whisper-1":               10,
-	"davinci":                 10,
-	"curie":                   10,
-	"babbage":                 10,
-	"ada":                     10,
-	"text-embedding-ada-002":  0.05,
-	"text-search-ada-doc-001": 10,
-	"text-moderation-stable":  0.1,
-	"text-moderation-latest":  0.1,
-	"dall-e":                  8,
-	"claude-instant-1":        0.815,  // $1.63 / 1M tokens
-	"claude-2":                5.51,   // $11.02 / 1M tokens
-	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens
-	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens
-	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens
-	"PaLM-2":                  1,
-	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens
-	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens
-	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens
-	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
-	"qwen-plus-v1":            0.5715, // Same as above
-	"SparkDesk":               0.8572, // TBD
+	"gpt-4":                     15,
+	"gpt-4-0314":                15,
+	"gpt-4-0613":                15,
+	"gpt-4-32k":                 30,
+	"gpt-4-32k-0314":            30,
+	"gpt-4-32k-0613":            30,
+	"gpt-3.5-turbo":             0.75, // $0.0015 / 1K tokens
+	"gpt-3.5-turbo-0301":        0.75,
+	"gpt-3.5-turbo-0613":        0.75,
+	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens
+	"gpt-3.5-turbo-16k-0613":    1.5,
+	"text-ada-001":              0.2,
+	"text-babbage-001":          0.25,
+	"text-curie-001":            1,
+	"text-davinci-002":          10,
+	"text-davinci-003":          10,
+	"text-davinci-edit-001":     10,
+	"code-davinci-edit-001":     10,
+	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
+	"davinci":                   10,
+	"curie":                     10,
+	"babbage":                   10,
+	"ada":                       10,
+	"text-embedding-ada-002":    0.05,
+	"text-search-ada-doc-001":   10,
+	"text-moderation-stable":    0.1,
+	"text-moderation-latest":    0.1,
+	"dall-e":                    8,
+	"claude-instant-1":          0.815,  // $1.63 / 1M tokens
+	"claude-2":                  5.51,   // $11.02 / 1M tokens
+	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens
+	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens
+	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens
+	"PaLM-2":                    1,
+	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
+	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
+	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens
+	"qwen-v1":                   0.8572, // ¥0.012 / 1k tokens
+	"qwen-plus-v1":              1,      // ¥0.014 / 1k tokens
+	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
+	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
+	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
+	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
+	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
+	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
+	"360GPT_S2_V9.4":            0.8572, // ¥0.012 / 1k tokens
 }
 
 func ModelRatio2JSONString() string {

+ 10 - 1
controller/channel-test.go

@@ -14,7 +14,7 @@ import (
 	"time"
 )
 
-func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
+func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
 	switch channel.Type {
 	case common.ChannelTypePaLM:
 		fallthrough
@@ -24,10 +24,19 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
 		fallthrough
 	case common.ChannelTypeZhipu:
 		fallthrough
+	case common.ChannelTypeAli:
+		fallthrough
+	case common.ChannelType360:
+		fallthrough
 	case common.ChannelTypeXunfei:
 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
 	case common.ChannelTypeAzure:
 		request.Model = "gpt-35-turbo"
+		defer func() {
+			if err != nil {
+				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
+			}
+		}()
 	default:
 		request.Model = "gpt-3.5-turbo"
 	}

+ 1 - 1
controller/channel.go

@@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) {
 	}
 	channel.CreatedTime = common.GetTimestamp()
 	keys := strings.Split(channel.Key, "\n")
-	channels := make([]model.Channel, 0)
+	channels := make([]model.Channel, 0, len(keys))
 	for _, key := range keys {
 		if key == "" {
 			continue

+ 63 - 0
controller/model.go

@@ -63,6 +63,15 @@ func init() {
 			Root:       "dall-e",
 			Parent:     nil,
 		},
+		{
+			Id:         "whisper-1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "whisper-1",
+			Parent:     nil,
+		},
 		{
 			Id:         "gpt-3.5-turbo",
 			Object:     "model",
@@ -351,6 +360,15 @@ func init() {
 			Root:       "qwen-plus-v1",
 			Parent:     nil,
 		},
+		{
+			Id:         "text-embedding-v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "ali",
+			Permission: permission,
+			Root:       "text-embedding-v1",
+			Parent:     nil,
+		},
 		{
 			Id:         "SparkDesk",
 			Object:     "model",
@@ -360,6 +378,51 @@ func init() {
 			Root:       "SparkDesk",
 			Parent:     nil,
 		},
+		{
+			Id:         "360GPT_S2_V9",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "360",
+			Permission: permission,
+			Root:       "360GPT_S2_V9",
+			Parent:     nil,
+		},
+		{
+			Id:         "embedding-bert-512-v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "360",
+			Permission: permission,
+			Root:       "embedding-bert-512-v1",
+			Parent:     nil,
+		},
+		{
+			Id:         "embedding_s1_v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "360",
+			Permission: permission,
+			Root:       "embedding_s1_v1",
+			Parent:     nil,
+		},
+		{
+			Id:         "semantic_similarity_s1_v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "360",
+			Permission: permission,
+			Root:       "semantic_similarity_s1_v1",
+			Parent:     nil,
+		},
+		{
+			Id:         "360GPT_S2_V9.4",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "360",
+			Permission: permission,
+			Root:       "360GPT_S2_V9.4",
+			Parent:     nil,
+		},
 	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {

+ 220 - 0
controller/relay-aiproxy.go

@@ -0,0 +1,220 @@
+package controller
+
+import (
+	"bufio"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"strconv"
+	"strings"
+)
+
+// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
+
+type AIProxyLibraryRequest struct {
+	Model     string `json:"model"`
+	Query     string `json:"query"`
+	LibraryId string `json:"libraryId"`
+	Stream    bool   `json:"stream"`
+}
+
+type AIProxyLibraryError struct {
+	ErrCode int    `json:"errCode"`
+	Message string `json:"message"`
+}
+
+type AIProxyLibraryDocument struct {
+	Title string `json:"title"`
+	URL   string `json:"url"`
+}
+
+type AIProxyLibraryResponse struct {
+	Success   bool                     `json:"success"`
+	Answer    string                   `json:"answer"`
+	Documents []AIProxyLibraryDocument `json:"documents"`
+	AIProxyLibraryError
+}
+
+type AIProxyLibraryStreamResponse struct {
+	Content   string                   `json:"content"`
+	Finish    bool                     `json:"finish"`
+	Model     string                   `json:"model"`
+	Documents []AIProxyLibraryDocument `json:"documents"`
+}
+
+func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
+	query := ""
+	if len(request.Messages) != 0 {
+		query = request.Messages[len(request.Messages)-1].Content
+	}
+	return &AIProxyLibraryRequest{
+		Model:  request.Model,
+		Stream: request.Stream,
+		Query:  query,
+	}
+}
+
+func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
+	if len(documents) == 0 {
+		return ""
+	}
+	content := "\n\n参考文档:\n"
+	for i, document := range documents {
+		content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
+	}
+	return content
+}
+
+func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
+	content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
+	choice := OpenAITextResponseChoice{
+		Index: 0,
+		Message: Message{
+			Role:    "assistant",
+			Content: content,
+		},
+		FinishReason: "stop",
+	}
+	fullTextResponse := OpenAITextResponse{
+		Id:      common.GetUUID(),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Choices: []OpenAITextResponseChoice{choice},
+	}
+	return &fullTextResponse
+}
+
+func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = aiProxyDocuments2Markdown(documents)
+	choice.FinishReason = &stopFinishReason
+	return &ChatCompletionsStreamResponse{
+		Id:      common.GetUUID(),
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   "",
+		Choices: []ChatCompletionsStreamResponseChoice{choice},
+	}
+}
+
+func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = response.Content
+	return &ChatCompletionsStreamResponse{
+		Id:      common.GetUUID(),
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   response.Model,
+		Choices: []ChatCompletionsStreamResponseChoice{choice},
+	}
+}
+
+func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var usage Usage
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			if len(data) < 5 { // ignore blank line or wrong format
+				continue
+			}
+			if data[:5] != "data:" {
+				continue
+			}
+			data = data[5:]
+			dataChan <- data
+		}
+		stopChan <- true
+	}()
+	setEventStreamHeaders(c)
+	var documents []AIProxyLibraryDocument
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			var AIProxyLibraryResponse AIProxyLibraryStreamResponse
+			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true
+			}
+			if len(AIProxyLibraryResponse.Documents) != 0 {
+				documents = AIProxyLibraryResponse.Documents
+			}
+			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			response := documentsAIProxyLibrary(documents)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	err := resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	return nil, &usage
+}
+
+func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var AIProxyLibraryResponse AIProxyLibraryResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if AIProxyLibraryResponse.ErrCode != 0 {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: AIProxyLibraryResponse.Message,
+				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
+				Code:    AIProxyLibraryResponse.ErrCode,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}

+ 88 - 0
controller/relay-ali.go

@@ -35,6 +35,29 @@ type AliChatRequest struct {
 	Parameters AliParameters `json:"parameters,omitempty"`
 }
 
+type AliEmbeddingRequest struct {
+	Model string `json:"model"`
+	Input struct {
+		Texts []string `json:"texts"`
+	} `json:"input"`
+	Parameters *struct {
+		TextType string `json:"text_type,omitempty"`
+	} `json:"parameters,omitempty"`
+}
+
+type AliEmbedding struct {
+	Embedding []float64 `json:"embedding"`
+	TextIndex int       `json:"text_index"`
+}
+
+type AliEmbeddingResponse struct {
+	Output struct {
+		Embeddings []AliEmbedding `json:"embeddings"`
+	} `json:"output"`
+	Usage AliUsage `json:"usage"`
+	AliError
+}
+
 type AliError struct {
 	Code      string `json:"code"`
 	Message   string `json:"message"`
@@ -44,6 +67,7 @@ type AliError struct {
 type AliUsage struct {
 	InputTokens  int `json:"input_tokens"`
 	OutputTokens int `json:"output_tokens"`
+	TotalTokens  int `json:"total_tokens"`
 }
 
 type AliOutput struct {
@@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 	}
 }
 
+func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
+	return &AliEmbeddingRequest{
+		Model: "text-embedding-v1",
+		Input: struct {
+			Texts []string `json:"texts"`
+		}{
+			Texts: request.ParseInput(),
+		},
+	}
+}
+
+func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var aliResponse AliEmbeddingResponse
+	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if aliResponse.Code != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: aliResponse.Message,
+				Type:    aliResponse.Code,
+				Param:   aliResponse.RequestId,
+				Code:    aliResponse.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
+	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
+		Object: "list",
+		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
+		Model:  "text-embedding-v1",
+		Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
+	}
+
+	for _, item := range response.Output.Embeddings {
+		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
+			Object:    `embedding`,
+			Index:     item.TextIndex,
+			Embedding: item.Embedding,
+		})
+	}
+	return &openAIEmbeddingResponse
+}
+
 func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 	choice := OpenAITextResponseChoice{
 		Index: 0,

+ 147 - 0
controller/relay-audio.go

@@ -0,0 +1,147 @@
+package controller
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/model"
+
+	"github.com/gin-gonic/gin"
+)
+
+func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
+	audioModel := "whisper-1"
+
+	tokenId := c.GetInt("token_id")
+	channelType := c.GetInt("channel")
+	userId := c.GetInt("id")
+	group := c.GetString("group")
+
+	preConsumedTokens := common.PreConsumedQuota
+	modelRatio := common.GetModelRatio(audioModel)
+	groupRatio := common.GetGroupRatio(group)
+	ratio := modelRatio * groupRatio
+	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
+	userQuota, err := model.CacheGetUserQuota(userId)
+	if err != nil {
+		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+	}
+	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
+	if err != nil {
+		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
+	}
+	if userQuota > 100*preConsumedQuota {
+		// in this case, we do not pre-consume quota
+		// because the user has enough quota
+		preConsumedQuota = 0
+	}
+	if preConsumedQuota > 0 {
+		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
+		if err != nil {
+			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
+		}
+	}
+
+	// map model name
+	modelMapping := c.GetString("model_mapping")
+	if modelMapping != "" {
+		modelMap := make(map[string]string)
+		err := json.Unmarshal([]byte(modelMapping), &modelMap)
+		if err != nil {
+			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
+		}
+		if modelMap[audioModel] != "" {
+			audioModel = modelMap[audioModel]
+		}
+	}
+
+	baseURL := common.ChannelBaseURLs[channelType]
+	requestURL := c.Request.URL.String()
+
+	if c.GetString("base_url") != "" {
+		baseURL = c.GetString("base_url")
+	}
+
+	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+	requestBody := c.Request.Body
+
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
+	}
+	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+
+	resp, err := httpClient.Do(req)
+	if err != nil {
+		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+
+	err = req.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+	}
+	err = c.Request.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+	}
+	var audioResponse AudioResponse
+
+	defer func() {
+		go func() {
+			quota := countTokenText(audioResponse.Text, audioModel)
+			quotaDelta := quota - preConsumedQuota
+			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+			if err != nil {
+				common.SysError("error consuming token remain quota: " + err.Error())
+			}
+			err = model.CacheUpdateUserQuota(userId)
+			if err != nil {
+				common.SysError("error update user quota cache: " + err.Error())
+			}
+			if quota != 0 {
+				tokenName := c.GetString("token_name")
+				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+				model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
+				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+				channelId := c.GetInt("channel_id")
+				model.UpdateChannelUsedQuota(channelId, quota)
+			}
+		}()
+	}()
+
+	responseBody, err := io.ReadAll(resp.Body)
+
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+	err = json.Unmarshal(responseBody, &audioResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.WriteHeader(resp.StatusCode)
+
+	_, err = io.Copy(c.Writer, resp.Body)
+	if err != nil {
+		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+	return nil
+}

+ 2 - 13
controller/relay-baidu.go

@@ -144,20 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 }
 
 func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
-	baiduEmbeddingRequest := BaiduEmbeddingRequest{
-		Input: nil,
+	return &BaiduEmbeddingRequest{
+		Input: request.ParseInput(),
 	}
-	switch request.Input.(type) {
-	case string:
-		baiduEmbeddingRequest.Input = []string{request.Input.(string)}
-	case []any:
-		for _, item := range request.Input.([]any) {
-			if str, ok := item.(string); ok {
-				baiduEmbeddingRequest.Input = append(baiduEmbeddingRequest.Input, str)
-			}
-		}
-	}
-	return &baiduEmbeddingRequest
 }
 
 func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {

+ 60 - 4
controller/relay-text.go

@@ -22,6 +22,7 @@ const (
 	APITypeZhipu
 	APITypeAli
 	APITypeXunfei
+	APITypeAIProxyLibrary
 )
 
 var httpClient *http.Client
@@ -104,6 +105,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiType = APITypeAli
 	case common.ChannelTypeXunfei:
 		apiType = APITypeXunfei
+	case common.ChannelTypeAIProxyLibrary:
+		apiType = APITypeAIProxyLibrary
 	}
 
 	baseURL := common.ChannelBaseURLs[channelType]
@@ -172,6 +175,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 	case APITypeAli:
 		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
+		if relayMode == RelayModeEmbeddings {
+			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
+		}
+	case APITypeAIProxyLibrary:
+		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
 	}
 	var promptTokens int
 	var completionTokens int
@@ -258,8 +266,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
 	case APITypeAli:
-		aliRequest := requestOpenAI2Ali(textRequest)
-		jsonStr, err := json.Marshal(aliRequest)
+		var jsonStr []byte
+		var err error
+		switch relayMode {
+		case RelayModeEmbeddings:
+			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
+			jsonStr, err = json.Marshal(aliEmbeddingRequest)
+		default:
+			aliRequest := requestOpenAI2Ali(textRequest)
+			jsonStr, err = json.Marshal(aliRequest)
+		}
+		if err != nil {
+			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonStr)
+	case APITypeAIProxyLibrary:
+		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
+		aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
+		jsonStr, err := json.Marshal(aiProxyLibraryRequest)
 		if err != nil {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
@@ -287,6 +311,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 				req.Header.Set("api-key", apiKey)
 			} else {
 				req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+				if channelType == common.ChannelTypeOpenRouter {
+					req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
+					req.Header.Set("X-Title", "One API")
+				}
 			}
 		case APITypeClaude:
 			req.Header.Set("x-api-key", apiKey)
@@ -303,6 +331,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			if textRequest.Stream {
 				req.Header.Set("X-DashScope-SSE", "enable")
 			}
+		default:
+			req.Header.Set("Authorization", "Bearer "+apiKey)
 		}
 		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@@ -365,7 +395,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-
 					model.UpdateChannelUsedQuota(channelId, quota)
 				}
 			}
@@ -491,7 +520,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		} else {
-			err, usage := aliHandler(c, resp)
+			var err *OpenAIErrorWithStatusCode
+			var usage *Usage
+			switch relayMode {
+			case RelayModeEmbeddings:
+				err, usage = aliEmbeddingHandler(c, resp)
+			default:
+				err, usage = aliHandler(c, resp)
+			}
 			if err != nil {
 				return err
 			}
@@ -519,6 +555,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		} else {
 			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
 		}
+	case APITypeAIProxyLibrary:
+		if isStream {
+			err, usage := aiProxyLibraryStreamHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			if usage != nil {
+				textResponse.Usage = *usage
+			}
+			return nil
+		} else {
+			err, usage := aiProxyLibraryHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			if usage != nil {
+				textResponse.Usage = *usage
+			}
+			return nil
+		}
 	default:
 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 	}

+ 18 - 0
controller/relay-utils.go

@@ -15,6 +15,24 @@ var stopFinishReason = "stop"
 
 var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 
+func InitTokenEncoders() {
+	common.SysLog("initializing token encoders")
+	fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
+	if err != nil {
+		common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
+	}
+	for model, _ := range common.ModelRatio {
+		tokenEncoder, err := tiktoken.EncodingForModel(model)
+		if err != nil {
+			common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
+			tokenEncoderMap[model] = fallbackTokenEncoder
+			continue
+		}
+		tokenEncoderMap[model] = tokenEncoder
+	}
+	common.SysLog("token encoders initialized")
+}
+
 func getTokenEncoder(model string) *tiktoken.Tiktoken {
 	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
 		return tokenEncoder

+ 25 - 18
controller/relay.go

@@ -2,7 +2,6 @@ package controller
 
 import (
 	"fmt"
-	"log"
 	"net/http"
 	"one-api/common"
 	"strconv"
@@ -29,6 +28,7 @@ const (
 	RelayModeMidjourneyChange
 	RelayModeMidjourneyNotify
 	RelayModeMidjourneyTaskFetch
+	RelayModeAudio
 )
 
 // https://platform.openai.com/docs/api-reference/chat
@@ -45,6 +45,26 @@ type GeneralOpenAIRequest struct {
 	Input       any       `json:"input,omitempty"`
 	Instruction string    `json:"instruction,omitempty"`
 	Size        string    `json:"size,omitempty"`
+	Functions   any       `json:"functions,omitempty"`
+}
+
+func (r GeneralOpenAIRequest) ParseInput() []string {
+	if r.Input == nil {
+		return nil
+	}
+	var input []string
+	switch r.Input.(type) {
+	case string:
+		input = []string{r.Input.(string)}
+	case []any:
+		input = make([]string, 0, len(r.Input.([]any)))
+		for _, item := range r.Input.([]any) {
+			if str, ok := item.(string); ok {
+				input = append(input, str)
+			}
+		}
+	}
+	return input
 }
 
 type ChatRequest struct {
@@ -67,6 +87,10 @@ type ImageRequest struct {
 	Size   string `json:"size"`
 }
 
+type AudioResponse struct {
+	Text string `json:"text,omitempty"`
+}
+
 type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
@@ -147,23 +171,6 @@ type CompletionsStreamResponse struct {
 	} `json:"choices"`
 }
 
-type MidjourneyRequest struct {
-	Prompt      string   `json:"prompt"`
-	NotifyHook  string   `json:"notifyHook"`
-	Action      string   `json:"action"`
-	Index       int      `json:"index"`
-	State       string   `json:"state"`
-	TaskId      string   `json:"taskId"`
-	Base64Array []string `json:"base64Array"`
-}
-
-type MidjourneyResponse struct {
-	Code        int         `json:"code"`
-	Description string      `json:"description"`
-	Properties  interface{} `json:"properties"`
-	Result      string      `json:"result"`
-}
-
 func Relay(c *gin.Context) {
 	relayMode := RelayModeUnknown
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {

+ 2 - 1
i18n/en.json

@@ -523,5 +523,6 @@
   "按照如下格式输入:": "Enter in the following format:",
   "模型版本": "Model version",
   "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
-  "点击查看": "click to view"
+  "点击查看": "click to view",
+  "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
 }

+ 6 - 0
main.go

@@ -78,6 +78,12 @@ func main() {
 		go controller.AutomaticallyTestChannels(frequency)
 	}
 	go controller.UpdateMidjourneyTask()
+	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
+		common.BatchUpdateEnabled = true
+		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
+		model.InitBatchUpdater()
+	}
+	controller.InitTokenEncoders()
 
 	// Initialize HTTP server
 	server := gin.Default()

+ 12 - 1
middleware/auth.go

@@ -109,7 +109,18 @@ func TokenAuth() func(c *gin.Context) {
 			c.Abort()
 			return
 		}
-		if !model.CacheIsUserEnabled(token.UserId) {
+		userEnabled, err := model.IsUserEnabled(token.UserId)
+		if err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{
+				"error": gin.H{
+					"message": err.Error(),
+					"type":    "one_api_error",
+				},
+			})
+			c.Abort()
+			return
+		}
+		if !userEnabled {
 			c.JSON(http.StatusForbidden, gin.H{
 				"error": gin.H{
 					"message": "用户已被封禁",

+ 11 - 5
middleware/distributor.go

@@ -2,7 +2,6 @@ package middleware
 
 import (
 	"fmt"
-	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
@@ -22,7 +21,6 @@ func Distribute() func(c *gin.Context) {
 		userGroup, _ := model.CacheGetUserGroup(userId)
 		c.Set("group", userGroup)
 		var channel *model.Channel
-		var err error
 		channelId, ok := c.Get("channelId")
 		if ok {
 			id, err := strconv.Atoi(channelId.(string))
@@ -58,7 +56,6 @@ func Distribute() func(c *gin.Context) {
 				return
 			}
 		} else {
-
 			// Select a channel for the user
 			var modelRequest ModelRequest
 			if strings.HasPrefix(c.Request.URL.Path, "/mj") {
@@ -79,7 +76,6 @@ func Distribute() func(c *gin.Context) {
 					return
 				}
 			}
-
 			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 				if modelRequest.Model == "" {
 					modelRequest.Model = "text-moderation-stable"
@@ -95,6 +91,11 @@ func Distribute() func(c *gin.Context) {
 					modelRequest.Model = "dall-e"
 				}
 			}
+			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
+				if modelRequest.Model == "" {
+					modelRequest.Model = "whisper-1"
+				}
+			}
 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 			if err != nil {
 				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
@@ -118,8 +119,13 @@ func Distribute() func(c *gin.Context) {
 		c.Set("model_mapping", channel.ModelMapping)
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 		c.Set("base_url", channel.BaseURL)
-		if channel.Type == common.ChannelTypeAzure || channel.Type == common.ChannelTypeXunfei {
+		switch channel.Type {
+		case common.ChannelTypeAzure:
+			c.Set("api_version", channel.Other)
+		case common.ChannelTypeXunfei:
 			c.Set("api_version", channel.Other)
+		case common.ChannelTypeAIProxyLibrary:
+			c.Set("library_id", channel.Other)
 		}
 		c.Next()
 	}

+ 16 - 11
model/cache.go

@@ -103,23 +103,28 @@ func CacheDecreaseUserQuota(id int, quota int) error {
 	return err
 }
 
-func CacheIsUserEnabled(userId int) bool {
+func CacheIsUserEnabled(userId int) (bool, error) {
 	if !common.RedisEnabled {
 		return IsUserEnabled(userId)
 	}
 	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
+	if err == nil {
+		return enabled == "1", nil
+	}
+
+	userEnabled, err := IsUserEnabled(userId)
 	if err != nil {
-		status := common.UserStatusDisabled
-		if IsUserEnabled(userId) {
-			status = common.UserStatusEnabled
-		}
-		enabled = fmt.Sprintf("%d", status)
-		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
-		if err != nil {
-			common.SysError("Redis set user enabled error: " + err.Error())
-		}
+		return false, err
+	}
+	enabled = "0"
+	if userEnabled {
+		enabled = "1"
+	}
+	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
+	if err != nil {
+		common.SysError("Redis set user enabled error: " + err.Error())
 	}
-	return enabled == "1"
+	return userEnabled, err
 }
 
 var group2model2channels map[string]map[string][]*Channel

+ 8 - 0
model/channel.go

@@ -141,6 +141,14 @@ func UpdateChannelStatusById(id int, status int) {
 }
 
 func UpdateChannelUsedQuota(id int, quota int) {
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
+		return
+	}
+	updateChannelUsedQuota(id, quota)
+}
+
+func updateChannelUsedQuota(id int, quota int) {
 	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
 	if err != nil {
 		common.SysError("failed to update channel used quota: " + err.Error())

+ 40 - 19
model/token.go

@@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) {
 	}
 	token, err = CacheGetTokenByKey(key)
 	if err == nil {
+		if token.Status == common.TokenStatusExhausted {
+			return nil, errors.New("该令牌额度已用尽")
+		} else if token.Status == common.TokenStatusExpired {
+			return nil, errors.New("该令牌已过期")
+		}
 		if token.Status != common.TokenStatusEnabled {
 			return nil, errors.New("该令牌状态不可用")
 		}
 		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
-			token.Status = common.TokenStatusExpired
-			err := token.SelectUpdate()
-			if err != nil {
-				common.SysError("failed to update token status" + err.Error())
+			if !common.RedisEnabled {
+				token.Status = common.TokenStatusExpired
+				err := token.SelectUpdate()
+				if err != nil {
+					common.SysError("failed to update token status" + err.Error())
+				}
 			}
 			return nil, errors.New("该令牌已过期")
 		}
 		if !token.UnlimitedQuota && token.RemainQuota <= 0 {
-			token.Status = common.TokenStatusExhausted
-			err := token.SelectUpdate()
-			if err != nil {
-				common.SysError("failed to update token status" + err.Error())
+			if !common.RedisEnabled {
+				// in this case, we can make sure the token is exhausted
+				token.Status = common.TokenStatusExhausted
+				err := token.SelectUpdate()
+				if err != nil {
+					common.SysError("failed to update token status" + err.Error())
+				}
 			}
 			return nil, errors.New("该令牌额度已用尽")
 		}
-		go func() {
-			token.AccessedTime = common.GetTimestamp()
-			err := token.SelectUpdate()
-			if err != nil {
-				common.SysError("failed to update token" + err.Error())
-			}
-		}()
 		return token, nil
 	}
 	return nil, errors.New("无效的令牌")
@@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
+		return nil
+	}
+	return increaseTokenQuota(id, quota)
+}
+
+func increaseTokenQuota(id int, quota int) (err error) {
 	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 		map[string]interface{}{
-			"remain_quota": gorm.Expr("remain_quota + ?", quota),
-			"used_quota":   gorm.Expr("used_quota - ?", quota),
+			"remain_quota":  gorm.Expr("remain_quota + ?", quota),
+			"used_quota":    gorm.Expr("used_quota - ?", quota),
+			"accessed_time": common.GetTimestamp(),
 		},
 	).Error
 	return err
@@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
+		return nil
+	}
+	return decreaseTokenQuota(id, quota)
+}
+
+func decreaseTokenQuota(id int, quota int) (err error) {
 	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 		map[string]interface{}{
-			"remain_quota": gorm.Expr("remain_quota - ?", quota),
-			"used_quota":   gorm.Expr("used_quota + ?", quota),
+			"remain_quota":  gorm.Expr("remain_quota - ?", quota),
+			"used_quota":    gorm.Expr("used_quota + ?", quota),
+			"accessed_time": common.GetTimestamp(),
 		},
 	).Error
 	return err

+ 29 - 6
model/user.go

@@ -235,17 +235,16 @@ func IsAdmin(userId int) bool {
 	return user.Role >= common.RoleAdminUser
 }
 
-func IsUserEnabled(userId int) bool {
+func IsUserEnabled(userId int) (bool, error) {
 	if userId == 0 {
-		return false
+		return false, errors.New("user id is empty")
 	}
 	var user User
 	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
 	if err != nil {
-		common.SysError("no such user " + err.Error())
-		return false
+		return false, err
 	}
-	return user.Status == common.UserStatusEnabled
+	return user.Status == common.UserStatusEnabled, nil
 }
 
 func ValidateAccessToken(token string) (user *User) {
@@ -284,6 +283,14 @@ func IncreaseUserQuota(id int, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeUserQuota, id, quota)
+		return nil
+	}
+	return increaseUserQuota(id, quota)
+}
+
+func increaseUserQuota(id int, quota int) (err error) {
 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
 	return err
 }
@@ -292,6 +299,14 @@ func DecreaseUserQuota(id int, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")
 	}
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
+		return nil
+	}
+	return decreaseUserQuota(id, quota)
+}
+
+func decreaseUserQuota(id int, quota int) (err error) {
 	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
 	return err
 }
@@ -302,10 +317,18 @@ func GetRootUserEmail() (email string) {
 }
 
 func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
+	if common.BatchUpdateEnabled {
+		addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
+		return
+	}
+	updateUserUsedQuotaAndRequestCount(id, quota, 1)
+}
+
+func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
 	err := DB.Model(&User{}).Where("id = ?", id).Updates(
 		map[string]interface{}{
 			"used_quota":    gorm.Expr("used_quota + ?", quota),
-			"request_count": gorm.Expr("request_count + ?", 1),
+			"request_count": gorm.Expr("request_count + ?", count),
 		},
 	).Error
 	if err != nil {

+ 75 - 0
model/utils.go

@@ -0,0 +1,75 @@
+package model
+
+import (
+	"one-api/common"
+	"sync"
+	"time"
+)
+
+const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
+
+const (
+	BatchUpdateTypeUserQuota = iota
+	BatchUpdateTypeTokenQuota
+	BatchUpdateTypeUsedQuotaAndRequestCount
+	BatchUpdateTypeChannelUsedQuota
+)
+
+var batchUpdateStores []map[int]int
+var batchUpdateLocks []sync.Mutex
+
+func init() {
+	for i := 0; i < BatchUpdateTypeCount; i++ {
+		batchUpdateStores = append(batchUpdateStores, make(map[int]int))
+		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
+	}
+}
+
+func InitBatchUpdater() {
+	go func() {
+		for {
+			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
+			batchUpdate()
+		}
+	}()
+}
+
+func addNewRecord(type_ int, id int, value int) {
+	batchUpdateLocks[type_].Lock()
+	defer batchUpdateLocks[type_].Unlock()
+	if _, ok := batchUpdateStores[type_][id]; !ok {
+		batchUpdateStores[type_][id] = value
+	} else {
+		batchUpdateStores[type_][id] += value
+	}
+}
+
+func batchUpdate() {
+	common.SysLog("batch update started")
+	for i := 0; i < BatchUpdateTypeCount; i++ {
+		batchUpdateLocks[i].Lock()
+		store := batchUpdateStores[i]
+		batchUpdateStores[i] = make(map[int]int)
+		batchUpdateLocks[i].Unlock()
+
+		for key, value := range store {
+			switch i {
+			case BatchUpdateTypeUserQuota:
+				err := increaseUserQuota(key, value)
+				if err != nil {
+					common.SysError("failed to batch update user quota: " + err.Error())
+				}
+			case BatchUpdateTypeTokenQuota:
+				err := increaseTokenQuota(key, value)
+				if err != nil {
+					common.SysError("failed to batch update token quota: " + err.Error())
+				}
+			case BatchUpdateTypeUsedQuotaAndRequestCount:
+				updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
+			case BatchUpdateTypeChannelUsedQuota:
+				updateChannelUsedQuota(key, value)
+			}
+		}
+	}
+	common.SysLog("batch update finished")
+}

+ 2 - 2
router/relay-router.go

@@ -26,8 +26,8 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
 		relayV1Router.POST("/embeddings", controller.Relay)
 		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
-		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
-		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
+		relayV1Router.POST("/audio/transcriptions", controller.Relay)
+		relayV1Router.POST("/audio/translations", controller.Relay)
 		relayV1Router.GET("/files", controller.RelayNotImplemented)
 		relayV1Router.POST("/files", controller.RelayNotImplemented)
 		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)

+ 1 - 1
web/src/components/LogsTable.js

@@ -324,7 +324,7 @@ const LogsTable = () => {
               .map((log, idx) => {
                 if (log.deleted) return <></>;
                 return (
-                  <Table.Row key={log.created_at}>
+                  <Table.Row key={log.id}>
                     <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
                     {
                       isAdminUser && (

+ 4 - 0
web/src/constants/channel.constants.js

@@ -7,7 +7,11 @@ export const CHANNEL_OPTIONS = [
   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
+  { key: 19, text: '360 智脑', value: 19, color: 'blue' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
+  { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
+  { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
+  { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
   { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
   { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
   { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },

+ 75 - 22
web/src/pages/Channel/EditChannel.js

@@ -1,6 +1,6 @@
 import React, { useEffect, useState } from 'react';
 import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
-import { useParams, useNavigate } from 'react-router-dom';
+import { useNavigate, useParams } from 'react-router-dom';
 import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
 import { CHANNEL_OPTIONS } from '../../constants';
 
@@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = {
   'gpt-4-32k-0314': 'gpt-4-32k'
 };
 
+function type2secretPrompt(type) {
+  // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
+  switch (type) {
+    case 15:
+      return '按照如下格式输入:APIKey|SecretKey';
+    case 18:
+      return '按照如下格式输入:APPID|APISecret|APIKey';
+    case 22:
+      return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
+    default:
+      return '请输入渠道对应的鉴权密钥';
+  }
+}
+
 const EditChannel = () => {
   const params = useParams();
   const navigate = useNavigate();
@@ -19,7 +33,7 @@ const EditChannel = () => {
   const handleCancel = () => {
     navigate('/channel');
   };
-  
+
   const originInputs = {
     name: '',
     type: 1,
@@ -53,7 +67,7 @@ const EditChannel = () => {
           localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
           break;
         case 17:
-          localModels = ['qwen-v1', 'qwen-plus-v1'];
+          localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
           break;
         case 16:
           localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
@@ -61,6 +75,9 @@ const EditChannel = () => {
         case 18:
           localModels = ['SparkDesk'];
           break;
+        case 19:
+          localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'];
+          break;
       }
       setInputs((inputs) => ({ ...inputs, models: localModels }));
     }
@@ -190,6 +207,24 @@ const EditChannel = () => {
     }
   };
 
+  const addCustomModel = () => {
+    if (customModel.trim() === '') return;
+    if (inputs.models.includes(customModel)) return;
+    let localModels = [...inputs.models];
+    localModels.push(customModel);
+    let localModelOptions = [];
+    localModelOptions.push({
+      key: customModel,
+      text: customModel,
+      value: customModel
+    });
+    setModelOptions(modelOptions => {
+      return [...modelOptions, ...localModelOptions];
+    });
+    setCustomModel('');
+    handleInputChange(null, { name: 'models', value: localModels });
+  };
+
   return (
     <>
       <Segment loading={loading}>
@@ -292,6 +327,20 @@ const EditChannel = () => {
               </Form.Field>
             )
           }
+          {
+            inputs.type === 21 && (
+              <Form.Field>
+                <Form.Input
+                  label='知识库 ID'
+                  name='other'
+                  placeholder={'请输入知识库 ID,例如:123456'}
+                  onChange={handleInputChange}
+                  value={inputs.other}
+                  autoComplete='new-password'
+                />
+              </Form.Field>
+            )
+          }
           <Form.Field>
             <Form.Dropdown
               label='模型'
@@ -319,29 +368,19 @@ const EditChannel = () => {
             }}>清除所有模型</Button>
             <Input
               action={
-                <Button type={'button'} onClick={() => {
-                  if (customModel.trim() === '') return;
-                  if (inputs.models.includes(customModel)) return;
-                  let localModels = [...inputs.models];
-                  localModels.push(customModel);
-                  let localModelOptions = [];
-                  localModelOptions.push({
-                    key: customModel,
-                    text: customModel,
-                    value: customModel
-                  });
-                  setModelOptions(modelOptions => {
-                    return [...modelOptions, ...localModelOptions];
-                  });
-                  setCustomModel('');
-                  handleInputChange(null, { name: 'models', value: localModels });
-                }}>填入</Button>
+                <Button type={'button'} onClick={addCustomModel}>填入</Button>
               }
               placeholder='输入自定义模型名称'
               value={customModel}
               onChange={(e, { value }) => {
                 setCustomModel(value);
               }}
+              onKeyDown={(e) => {
+                if (e.key === 'Enter') {
+                  addCustomModel();
+                  e.preventDefault();
+                }
+              }}
             />
           </div>
           <Form.Field>
@@ -372,7 +411,7 @@ const EditChannel = () => {
                 label='密钥'
                 name='key'
                 required
-                placeholder={inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
+                placeholder={type2secretPrompt(inputs.type)}
                 onChange={handleInputChange}
                 value={inputs.key}
                 autoComplete='new-password'
@@ -390,7 +429,7 @@ const EditChannel = () => {
             )
           }
           {
-            inputs.type !== 3 && inputs.type !== 8 && (
+            inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
               <Form.Field>
                 <Form.Input
                   label='代理'
@@ -403,6 +442,20 @@ const EditChannel = () => {
               </Form.Field>
             )
           }
+          {
+            inputs.type === 22 && (
+              <Form.Field>
+                <Form.Input
+                  label='私有部署地址'
+                  name='base_url'
+                  placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
+                  onChange={handleInputChange}
+                  value={inputs.base_url}
+                  autoComplete='new-password'
+                />
+              </Form.Field>
+            )
+          }
           <Button onClick={handleCancel}>取消</Button>
           <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
         </Form>