Browse Source

feat: support baidu's models now (close #286)

JustSong 2 years ago
parent
commit
9a1db61675

+ 2 - 1
README.md

@@ -60,8 +60,9 @@ _✨ All in one 的 OpenAI 接口,整合各种 API 访问方式,开箱即用
 ## 功能
 1. 支持多种 API 访问渠道:
    + [x] OpenAI 官方通道(支持配置镜像)
-   + [x] [Anthropic Claude 系列模型](https://anthropic.com)
    + [x] **Azure OpenAI API**
+   + [x] [Anthropic Claude 系列模型](https://anthropic.com)
+   + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
    + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
    + [x] [OpenAI-SB](https://openai-sb.com)
    + [x] [API2D](https://api2d.com/r/197971)

+ 2 - 0
common/constants.go

@@ -152,6 +152,7 @@ const (
 	ChannelTypeAPI2GPT   = 12
 	ChannelTypeAIGC2D    = 13
 	ChannelTypeAnthropic = 14
+	ChannelTypeBaidu     = 15
 )
 
 var ChannelBaseURLs = []string{
@@ -170,4 +171,5 @@ var ChannelBaseURLs = []string{
 	"https://api.api2gpt.com",       // 12
 	"https://api.aigc2d.com",        // 13
 	"https://api.anthropic.com",     // 14
+	"https://aip.baidubce.com",      // 15
 }

+ 3 - 0
common/model-ratio.go

@@ -4,6 +4,7 @@ import "encoding/json"
 
 // ModelRatio
 // https://platform.openai.com/docs/models/model-endpoint-compatibility
+// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
 // https://openai.com/pricing
 // TODO: when a new api is enabled, check the pricing here
 // 1 === $0.002 / 1K tokens
@@ -38,6 +39,8 @@ var ModelRatio = map[string]float64{
 	"dall-e":                  8,
 	"claude-instant-1":        0.75,
 	"claude-2":                30,
+	"ERNIE-Bot":               1,    // 0.012元/千tokens
+	"ERNIE-Bot-turbo":         0.67, // 0.008元/千tokens
 }
 
 func ModelRatio2JSONString() string {

+ 18 - 0
controller/model.go

@@ -288,6 +288,24 @@ func init() {
 			Root:       "claude-2",
 			Parent:     nil,
 		},
+		{
+			Id:         "ERNIE-Bot",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "baidu",
+			Permission: permission,
+			Root:       "ERNIE-Bot",
+			Parent:     nil,
+		},
+		{
+			Id:         "ERNIE-Bot-turbo",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "baidu",
+			Permission: permission,
+			Root:       "ERNIE-Bot-turbo",
+			Parent:     nil,
+		},
 	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {

+ 203 - 0
controller/relay-baidu.go

@@ -0,0 +1,203 @@
+package controller
+
+import (
+	"bufio"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"strings"
+)
+
+// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
+
// BaiduTokenResponse is the payload of Baidu's OAuth2 token endpoint;
// AccessToken is the credential the chat endpoints consume.
type BaiduTokenResponse struct {
	RefreshToken  string `json:"refresh_token"`
	ExpiresIn     int    `json:"expires_in"`
	SessionKey    string `json:"session_key"`
	AccessToken   string `json:"access_token"`
	Scope         string `json:"scope"`
	SessionSecret string `json:"session_secret"`
}

// BaiduMessage is one chat turn in Baidu's wire format.
type BaiduMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// BaiduChatRequest is the request body sent to Baidu's chat APIs.
type BaiduChatRequest struct {
	Messages []BaiduMessage `json:"messages"`
	Stream   bool           `json:"stream"`
	UserId   string         `json:"user_id,omitempty"`
}

// BaiduError carries Baidu's API-level error fields; a non-empty
// ErrorMsg (or non-zero ErrorCode) signals a failed call.
type BaiduError struct {
	ErrorCode int    `json:"error_code"`
	ErrorMsg  string `json:"error_msg"`
}

// BaiduChatResponse is a non-streaming chat completion reply.
// BaiduError is embedded so error replies decode into the same struct.
type BaiduChatResponse struct {
	Id               string `json:"id"`
	Object           string `json:"object"`
	Created          int64  `json:"created"`
	Result           string `json:"result"`
	IsTruncated      bool   `json:"is_truncated"`
	NeedClearHistory bool   `json:"need_clear_history"`
	Usage            Usage  `json:"usage"`
	BaiduError
}

// BaiduChatStreamResponse is one SSE chunk of a streaming reply;
// IsEnd marks the terminal chunk.
type BaiduChatStreamResponse struct {
	BaiduChatResponse
	SentenceId int  `json:"sentence_id"`
	IsEnd      bool `json:"is_end"`
}
+
+func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
+	messages := make([]BaiduMessage, 0, len(request.Messages))
+	for _, message := range request.Messages {
+		messages = append(messages, BaiduMessage{
+			Role:    message.Role,
+			Content: message.Content,
+		})
+	}
+	return &BaiduChatRequest{
+		Messages: messages,
+		Stream:   request.Stream,
+	}
+}
+
+func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
+	choice := OpenAITextResponseChoice{
+		Index: 0,
+		Message: Message{
+			Role:    "assistant",
+			Content: response.Result,
+		},
+		FinishReason: "stop",
+	}
+	fullTextResponse := OpenAITextResponse{
+		Id:      response.Id,
+		Object:  "chat.completion",
+		Created: response.Created,
+		Choices: []OpenAITextResponseChoice{choice},
+		Usage:   response.Usage,
+	}
+	return &fullTextResponse
+}
+
+func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = baiduResponse.Result
+	choice.FinishReason = "stop"
+	response := ChatCompletionsStreamResponse{
+		Id:      baiduResponse.Id,
+		Object:  "chat.completion.chunk",
+		Created: baiduResponse.Created,
+		Model:   "ernie-bot",
+		Choices: []ChatCompletionsStreamResponseChoice{choice},
+	}
+	return &response
+}
+
+func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var usage Usage
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			if len(data) < 6 { // ignore blank line or wrong format
+				continue
+			}
+			data = data[6:]
+			dataChan <- data
+		}
+		stopChan <- true
+	}()
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			var baiduResponse BaiduChatStreamResponse
+			err := json.Unmarshal([]byte(data), &baiduResponse)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true
+			}
+			usage.PromptTokens += baiduResponse.Usage.PromptTokens
+			usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
+			usage.TotalTokens += baiduResponse.Usage.TotalTokens
+			response := streamResponseBaidu2OpenAI(&baiduResponse)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	err := resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	return nil, &usage
+}
+
+func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var baiduResponse BaiduChatResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &baiduResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if baiduResponse.ErrorMsg != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: baiduResponse.ErrorMsg,
+				Type:    "baidu_error",
+				Param:   "",
+				Code:    baiduResponse.ErrorCode,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}

+ 39 - 1
controller/relay-text.go

@@ -18,6 +18,7 @@ const (
 	APITypeOpenAI = iota
 	APITypeClaude
 	APITypePaLM
+	APITypeBaidu
 )
 
 func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -79,6 +80,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	apiType := APITypeOpenAI
 	if strings.HasPrefix(textRequest.Model, "claude") {
 		apiType = APITypeClaude
+	} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
+		apiType = APITypeBaidu
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
@@ -112,6 +115,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		if baseURL != "" {
 			fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
 		}
+	case APITypeBaidu:
+		switch textRequest.Model {
+		case "ERNIE-Bot":
+			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
+		case "ERNIE-Bot-turbo":
+			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
+		case "BLOOMZ-7B":
+			fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
+		}
+		apiKey := c.Request.Header.Get("Authorization")
+		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+		fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
 	}
 	var promptTokens int
 	var completionTokens int
@@ -164,6 +179,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
+	case APITypeBaidu:
+		baiduRequest := requestOpenAI2Baidu(textRequest)
+		jsonStr, err := json.Marshal(baiduRequest)
+		if err != nil {
+			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonStr)
 	}
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
@@ -216,7 +238,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			if strings.HasPrefix(textRequest.Model, "gpt-4") {
 				completionRatio = 2
 			}
-			if isStream {
+			if isStream && apiType != APITypeBaidu {
 				completionTokens = countTokenText(streamResponseText, textRequest.Model)
 			} else {
 				promptTokens = textResponse.Usage.PromptTokens
@@ -285,6 +307,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			textResponse.Usage = *usage
 			return nil
 		}
+	case APITypeBaidu:
+		if isStream {
+			err, usage := baiduStreamHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			textResponse.Usage = *usage
+			return nil
+		} else {
+			err, usage := baiduHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			textResponse.Usage = *usage
+			return nil
+		}
 	default:
 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 	}

+ 1 - 0
web/src/constants/channel.constants.js

@@ -3,6 +3,7 @@ export const CHANNEL_OPTIONS = [
   { key: 14, text: 'Anthropic', value: 14, color: 'black' },
   { key: 8, text: '自定义', value: 8, color: 'pink' },
   { key: 3, text: 'Azure', value: 3, color: 'olive' },
+  { key: 15, text: 'Baidu', value: 15, color: 'blue' },
   { key: 2, text: 'API2D', value: 2, color: 'blue' },
   { key: 4, text: 'CloseAI', value: 4, color: 'teal' },
   { key: 5, text: 'OpenAI-SB', value: 5, color: 'brown' },