ソースを参照

feat: Add Baidu Qianfan V2 channel support #725

- Update channel constants to include Baidu V2 channel
- Create new Baidu V2 adaptor for relay
- Add Baidu V2 models and channel configuration
- Update relay adaptor to support Baidu V2 channel
- Modify web channel constants to include Baidu V2 option
[email protected] 10 ヶ月 前
コミット
eceb6afcdd

+ 6 - 4
common/constants.go

@@ -231,8 +231,9 @@ const (
 	ChannelTypeVertexAi       = 41
 	ChannelTypeMistral        = 42
 	ChannelTypeDeepSeek       = 43
-	ChannelTypeMokaAI         = 47
-	ChannelTypeVolcEngine     = 48
+	ChannelTypeMokaAI         = 44
+	ChannelTypeVolcEngine     = 45
+	ChannelTypeBaiduV2        = 46
 	ChannelTypeDummy          // this one is only for count, do not add any channel after this
 
 )
@@ -282,6 +283,7 @@ var ChannelBaseURLs = []string{
 	"",                                          //41
 	"https://api.mistral.ai",                    //42
 	"https://api.deepseek.com",                  //43
-	"https://api.moka.ai",                       //43
-	"https://ark.cn-beijing.volces.com",         //44
+	"https://api.moka.ai",                       //44
+	"https://ark.cn-beijing.volces.com",         //45
+	"https://qianfan.baidubce.com",              //46
 }

+ 76 - 0
relay/channel/baidu_v2/adaptor.go

@@ -0,0 +1,76 @@
+package baidu_v2
+
+import (
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/dto"
+	"one-api/relay/channel"
+	"one-api/relay/channel/openai"
+	relaycommon "one-api/relay/common"
+)
+
+type Adaptor struct {
+}
+
+func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
+}
+
+func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
+}
+
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, req)
+	req.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
+	if request == nil {
+		return nil, errors.New("request is nil")
+	}
+	return request, nil
+}
+
+func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
+	return nil, nil
+}
+
+func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
+	//TODO implement me
+	return nil, errors.New("not implemented")
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
+	return channel.DoApiRequest(a, c, info, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	if info.IsStream {
+		err, usage = openai.OaiStreamHandler(c, resp, info)
+	} else {
+		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+	}
+	return
+}
+
+func (a *Adaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *Adaptor) GetChannelName() string {
+	return ChannelName
+}

+ 29 - 0
relay/channel/baidu_v2/constants.go

@@ -0,0 +1,29 @@
+package baidu_v2
+
+var ModelList = []string{
+	"ernie-4.0-8k-latest",
+	"ernie-4.0-8k-preview",
+	"ernie-4.0-8k",
+	"ernie-4.0-turbo-8k-latest",
+	"ernie-4.0-turbo-8k-preview",
+	"ernie-4.0-turbo-8k",
+	"ernie-4.0-turbo-128k",
+	"ernie-3.5-8k-preview",
+	"ernie-3.5-8k",
+	"ernie-3.5-128k",
+	"ernie-speed-8k",
+	"ernie-speed-128k",
+	"ernie-speed-pro-128k",
+	"ernie-lite-8k",
+	"ernie-lite-pro-128k",
+	"ernie-tiny-8k",
+	"ernie-char-8k",
+	"ernie-char-fiction-8k",
+	"ernie-novel-8k",
+	"deepseek-v3",
+	"deepseek-r1",
+	"deepseek-r1-distill-qwen-32b",
+	"deepseek-r1-distill-qwen-14b",
+}
+
+var ChannelName = "volcengine"

+ 1 - 2
relay/channel/deepseek/adaptor.go

@@ -29,7 +29,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
+	return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -54,7 +54,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }
 
-
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 18 - 6
relay/channel/volcengine/adaptor.go

@@ -10,6 +10,7 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 )
 
 type Adaptor struct {
@@ -29,7 +30,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
+	switch info.RelayMode {
+	case constant.RelayModeChatCompletions:
+		return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
+	case constant.RelayModeEmbeddings:
+		return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
+	default:
+	}
+	return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -50,8 +58,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	return request, nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
@@ -59,9 +66,14 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
-	if info.IsStream {
-		err, usage = openai.OaiStreamHandler(c, resp, info)
-	} else {
+	switch info.RelayMode {
+	case constant.RelayModeChatCompletions:
+		if info.IsStream {
+			err, usage = openai.OaiStreamHandler(c, resp, info)
+		} else {
+			err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
+		}
+	case constant.RelayModeEmbeddings:
 		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}
 	return

+ 3 - 0
relay/constant/api_type.go

@@ -29,6 +29,7 @@ const (
 	APITypeDeepSeek
 	APITypeMokaAI
 	APITypeVolcEngine
+	APITypeBaiduV2
 	APITypeDummy // this one is only for count, do not add any channel after this
 )
 
@@ -83,6 +84,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
 		apiType = APITypeMokaAI
 	case common.ChannelTypeVolcEngine:
 		apiType = APITypeVolcEngine
+	case common.ChannelTypeBaiduV2:
+		apiType = APITypeBaiduV2
 	}
 	if apiType == -1 {
 		return APITypeOpenAI, false

+ 3 - 0
relay/relay_adaptor.go

@@ -6,6 +6,7 @@ import (
 	"one-api/relay/channel/ali"
 	"one-api/relay/channel/aws"
 	"one-api/relay/channel/baidu"
+	"one-api/relay/channel/baidu_v2"
 	"one-api/relay/channel/claude"
 	"one-api/relay/channel/cloudflare"
 	"one-api/relay/channel/cohere"
@@ -80,6 +81,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
 		return &mokaai.Adaptor{}
 	case constant.APITypeVolcEngine:
 		return &volcengine.Adaptor{}
+	case constant.APITypeBaiduV2:
+		return &baidu_v2.Adaptor{}
 	}
 	return nil
 }

+ 23 - 36
web/src/constants/channel.constants.js

@@ -1,124 +1,111 @@
 export const CHANNEL_OPTIONS = [
-  { key: 1, value: 1, color: 'green', label: 'OpenAI' },
+  { value: 1, color: 'green', label: 'OpenAI' },
   {
-    key: 2,
     value: 2,
     color: 'light-blue',
     label: 'Midjourney Proxy'
   },
   {
-    key: 5,
     value: 5,
     color: 'blue',
     label: 'Midjourney Proxy Plus'
   },
   {
-    key: 36,
     value: 36,
     color: 'purple',
     label: 'Suno API'
   },
-  { key: 4, value: 4, color: 'grey', label: 'Ollama' },
+  { value: 4, color: 'grey', label: 'Ollama' },
   {
-    key: 14,
     value: 14,
     color: 'indigo',
     label: 'Anthropic Claude'
   },
   {
-    key: 33,
     value: 33,
     color: 'indigo',
     label: 'AWS Claude'
   },
-  { key: 41, value: 41, color: 'blue', label: 'Vertex AI' },
+  { value: 41, color: 'blue', label: 'Vertex AI' },
   {
-    key: 3,
     value: 3,
     color: 'teal',
     label: 'Azure OpenAI'
   },
   {
-    key: 34,
     value: 34,
     color: 'purple',
     label: 'Cohere'
   },
-  { key: 39, value: 39, color: 'grey', label: 'Cloudflare' },
-  { key: 43, value: 43, color: 'blue', label: 'DeepSeek' },
+  { value: 39, color: 'grey', label: 'Cloudflare' },
+  { value: 43, color: 'blue', label: 'DeepSeek' },
   {
-    key: 15,
     value: 15,
     color: 'blue',
     label: '百度文心千帆'
   },
   {
-    key: 17,
+    value: 46,
+    color: 'blue',
+    label: '百度文心千帆V2'
+  },
+  {
     value: 17,
     color: 'orange',
     label: '阿里通义千问'
   },
   {
-    key: 18,
     value: 18,
     color: 'blue',
     label: '讯飞星火认知'
   },
   {
-    key: 16,
     value: 16,
     color: 'violet',
     label: '智谱 ChatGLM'
   },
   {
-    key: 26,
     value: 26,
     color: 'purple',
     label: '智谱 GLM-4V'
   },
   {
-    key: 24,
     value: 24,
     color: 'orange',
     label: 'Google Gemini'
   },
   {
-    key: 11,
     value: 11,
     color: 'orange',
     label: 'Google PaLM2'
   },
   {
-    key: 48,
-    value: 48,
+    value: 45,
     color: 'blue',
     label: '火山方舟(豆包)'
   },
-  { key: 25, value: 25, color: 'green', label: 'Moonshot' },
-  { key: 19, value: 19, color: 'blue', label: '360 智脑' },
-  { key: 23, value: 23, color: 'teal', label: '腾讯混元' },
-  { key: 31, value: 31, color: 'green', label: '零一万物' },
-  { key: 35, value: 35, color: 'green', label: 'MiniMax' },
-  { key: 37, value: 37, color: 'teal', label: 'Dify' },
-  { key: 38, value: 38, color: 'blue', label: 'Jina' },
-  { key: 40, value: 40, color: 'purple', label: 'SiliconCloud' },
-  { key: 42, value: 42, color: 'blue', label: 'Mistral AI' },
-  { key: 8, value: 8, color: 'pink', label: '自定义渠道' },
-  {
-    key: 22,
+  { value: 25, color: 'green', label: 'Moonshot' },
+  { value: 19, color: 'blue', label: '360 智脑' },
+  { value: 23, color: 'teal', label: '腾讯混元' },
+  { value: 31, color: 'green', label: '零一万物' },
+  { value: 35, color: 'green', label: 'MiniMax' },
+  { value: 37, color: 'teal', label: 'Dify' },
+  { value: 38, color: 'blue', label: 'Jina' },
+  { value: 40, color: 'purple', label: 'SiliconCloud' },
+  { value: 42, color: 'blue', label: 'Mistral AI' },
+  { value: 8, color: 'pink', label: '自定义渠道' },
+  {
     value: 22,
     color: 'blue',
     label: '知识库:FastGPT'
   },
   {
-    key: 21,
     value: 21,
     color: 'purple',
     label: '知识库:AI Proxy'
   },
   {
-    key: 47,
-    value: 47,
+    value: 44,
     color: 'purple',
     label: '嵌入模型:MokaAI M3E'
   }