Browse Source

feat: support ali's llm (close #326)

JustSong 2 years ago
parent
commit
e92da7928b

+ 1 - 0
README.md

@@ -63,6 +63,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
    + [x] [Anthropic Claude 系列模型](https://anthropic.com)
    + [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+   + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
 2. 支持配置镜像以及众多第三方代理服务:
    + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)

+ 19 - 17
common/constants.go

@@ -156,24 +156,26 @@ const (
 	ChannelTypeAnthropic = 14
 	ChannelTypeBaidu     = 15
 	ChannelTypeZhipu     = 16
+	ChannelTypeAli       = 17
 )
 
 var ChannelBaseURLs = []string{
-	"",                              // 0
-	"https://api.openai.com",        // 1
-	"https://oa.api2d.net",          // 2
-	"",                              // 3
-	"https://api.closeai-proxy.xyz", // 4
-	"https://api.openai-sb.com",     // 5
-	"https://api.openaimax.com",     // 6
-	"https://api.ohmygpt.com",       // 7
-	"",                              // 8
-	"https://api.caipacity.com",     // 9
-	"https://api.aiproxy.io",        // 10
-	"",                              // 11
-	"https://api.api2gpt.com",       // 12
-	"https://api.aigc2d.com",        // 13
-	"https://api.anthropic.com",     // 14
-	"https://aip.baidubce.com",      // 15
-	"https://open.bigmodel.cn",      // 16
+	"",                               // 0
+	"https://api.openai.com",         // 1
+	"https://oa.api2d.net",           // 2
+	"",                               // 3
+	"https://api.closeai-proxy.xyz",  // 4
+	"https://api.openai-sb.com",      // 5
+	"https://api.openaimax.com",      // 6
+	"https://api.ohmygpt.com",        // 7
+	"",                               // 8
+	"https://api.caipacity.com",      // 9
+	"https://api.aiproxy.io",         // 10
+	"",                               // 11
+	"https://api.api2gpt.com",        // 12
+	"https://api.aigc2d.com",         // 13
+	"https://api.anthropic.com",      // 14
+	"https://aip.baidubce.com",       // 15
+	"https://open.bigmodel.cn",       // 16
+	"https://dashscope.aliyuncs.com", // 17
 }

+ 2 - 0
common/model-ratio.go

@@ -46,6 +46,8 @@ var ModelRatio = map[string]float64{
 	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens
 	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens
+	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
+	"qwen-plus-v1":            0.5715, // Same as above
 }
 
 func ModelRatio2JSONString() string {

+ 18 - 0
controller/model.go

@@ -324,6 +324,24 @@ func init() {
 			Root:       "chatglm_lite",
 			Parent:     nil,
 		},
+		{
+			Id:         "qwen-v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "ali",
+			Permission: permission,
+			Root:       "qwen-v1",
+			Parent:     nil,
+		},
+		{
+			Id:         "qwen-plus-v1",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "ali",
+			Permission: permission,
+			Root:       "qwen-plus-v1",
+			Parent:     nil,
+		},
 	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {

+ 240 - 0
controller/relay-ali.go

@@ -0,0 +1,240 @@
+package controller
+
+import (
+	"bufio"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"strings"
+)
+
+// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
+
// AliMessage is a single user/bot exchange in the DashScope chat history.
type AliMessage struct {
	User string `json:"user"`
	Bot  string `json:"bot"`
}

// AliInput carries the current prompt plus the prior conversation turns.
type AliInput struct {
	Prompt  string       `json:"prompt"`
	History []AliMessage `json:"history"`
}

// AliParameters holds DashScope sampling options. All fields are optional;
// omitted fields fall back to the service defaults.
type AliParameters struct {
	TopP         float64 `json:"top_p,omitempty"`
	TopK         int     `json:"top_k,omitempty"`
	Seed         uint64  `json:"seed,omitempty"`
	EnableSearch bool    `json:"enable_search,omitempty"`
}

// AliChatRequest is the request body for the DashScope text-generation API.
// See https://help.aliyun.com/document_detail/613695.html
type AliChatRequest struct {
	Model      string        `json:"model"`
	Input      AliInput      `json:"input"`
	Parameters AliParameters `json:"parameters,omitempty"`
}

// AliError is the error envelope returned by DashScope; Code is empty on
// success (see the check in aliHandler).
type AliError struct {
	Code      string `json:"code"`
	Message   string `json:"message"`
	RequestId string `json:"request_id"`
}

// AliUsage reports token consumption for a single response.
type AliUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

// AliOutput is the generated text plus the reason generation stopped.
type AliOutput struct {
	Text         string `json:"text"`
	FinishReason string `json:"finish_reason"`
}

// AliChatResponse is the full DashScope response; the embedded AliError is
// populated (Code non-empty) when the call failed.
type AliChatResponse struct {
	Output AliOutput `json:"output"`
	Usage  AliUsage  `json:"usage"`
	AliError
}
+
+func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
+	messages := make([]AliMessage, 0, len(request.Messages))
+	prompt := ""
+	for i := 0; i < len(request.Messages); i++ {
+		message := request.Messages[i]
+		if message.Role == "system" {
+			messages = append(messages, AliMessage{
+				User: message.Content,
+				Bot:  "Okay",
+			})
+			continue
+		} else {
+			if i == len(request.Messages)-1 {
+				prompt = message.Content
+				break
+			}
+			messages = append(messages, AliMessage{
+				User: message.Content,
+				Bot:  request.Messages[i+1].Content,
+			})
+			i++
+		}
+	}
+	return &AliChatRequest{
+		Model: request.Model,
+		Input: AliInput{
+			Prompt:  prompt,
+			History: messages,
+		},
+		//Parameters: AliParameters{  // ChatGPT's parameters are not compatible with Ali's
+		//	TopP: request.TopP,
+		//	TopK: 50,
+		//	//Seed:         0,
+		//	//EnableSearch: false,
+		//},
+	}
+}
+
+func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
+	choice := OpenAITextResponseChoice{
+		Index: 0,
+		Message: Message{
+			Role:    "assistant",
+			Content: response.Output.Text,
+		},
+		FinishReason: response.Output.FinishReason,
+	}
+	fullTextResponse := OpenAITextResponse{
+		Id:      response.RequestId,
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Choices: []OpenAITextResponseChoice{choice},
+		Usage: Usage{
+			PromptTokens:     response.Usage.InputTokens,
+			CompletionTokens: response.Usage.OutputTokens,
+			TotalTokens:      response.Usage.InputTokens + response.Usage.OutputTokens,
+		},
+	}
+	return &fullTextResponse
+}
+
+func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = aliResponse.Output.Text
+	choice.FinishReason = aliResponse.Output.FinishReason
+	response := ChatCompletionsStreamResponse{
+		Id:      aliResponse.RequestId,
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   "ernie-bot",
+		Choices: []ChatCompletionsStreamResponseChoice{choice},
+	}
+	return &response
+}
+
+func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var usage Usage
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\n"); i >= 0 {
+			return i + 1, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			if len(data) < 5 { // ignore blank line or wrong format
+				continue
+			}
+			if data[:5] != "data:" {
+				continue
+			}
+			data = data[5:]
+			dataChan <- data
+		}
+		stopChan <- true
+	}()
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	lastResponseText := ""
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			var aliResponse AliChatResponse
+			err := json.Unmarshal([]byte(data), &aliResponse)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true
+			}
+			usage.PromptTokens += aliResponse.Usage.InputTokens
+			usage.CompletionTokens += aliResponse.Usage.OutputTokens
+			usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
+			response := streamResponseAli2OpenAI(&aliResponse)
+			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
+			lastResponseText = aliResponse.Output.Text
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	err := resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	return nil, &usage
+}
+
+func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	var aliResponse AliChatResponse
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = json.Unmarshal(responseBody, &aliResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if aliResponse.Code != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: aliResponse.Message,
+				Type:    aliResponse.Code,
+				Param:   aliResponse.RequestId,
+				Code:    aliResponse.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseAli2OpenAI(&aliResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}

+ 39 - 1
controller/relay-text.go

@@ -20,6 +20,7 @@ const (
 	APITypePaLM
 	APITypeBaidu
 	APITypeZhipu
+	APITypeAli
 )
 
 var httpClient *http.Client
@@ -94,6 +95,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiType = APITypePaLM
 	case common.ChannelTypeZhipu:
 		apiType = APITypeZhipu
+	case common.ChannelTypeAli:
+		apiType = APITypeAli
+
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
@@ -153,6 +157,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			method = "sse-invoke"
 		}
 		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
+	case APITypeAli:
+		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
 	}
 	var promptTokens int
 	var completionTokens int
@@ -226,6 +232,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
+	case APITypeAli:
+		aliRequest := requestOpenAI2Ali(textRequest)
+		jsonStr, err := json.Marshal(aliRequest)
+		if err != nil {
+			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonStr)
 	}
 	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
@@ -250,6 +263,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	case APITypeZhipu:
 		token := getZhipuToken(apiKey)
 		req.Header.Set("Authorization", token)
+	case APITypeAli:
+		req.Header.Set("Authorization", "Bearer "+apiKey)
+		if textRequest.Stream {
+			req.Header.Set("X-DashScope-SSE", "enable")
+		}
 	}
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@@ -280,7 +298,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			if strings.HasPrefix(textRequest.Model, "gpt-4") {
 				completionRatio = 2
 			}
-			if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu {
+			if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu && apiType != APITypeAli {
 				completionTokens = countTokenText(streamResponseText, textRequest.Model)
 			} else {
 				promptTokens = textResponse.Usage.PromptTokens
@@ -415,6 +433,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		}
+	case APITypeAli:
+		if isStream {
+			err, usage := aliStreamHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			if usage != nil {
+				textResponse.Usage = *usage
+			}
+			return nil
+		} else {
+			err, usage := aliHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			if usage != nil {
+				textResponse.Usage = *usage
+			}
+			return nil
+		}
 	default:
 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 	}

+ 2 - 1
web/src/constants/channel.constants.js

@@ -4,6 +4,7 @@ export const CHANNEL_OPTIONS = [
   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
+  { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
   { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
@@ -14,5 +15,5 @@ export const CHANNEL_OPTIONS = [
   { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' },
   { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' },
   { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' },
-  { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' }
+  { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' },
 ];