فهرست منبع

feat: support xunfei's llm (close #206)

JustSong 2 سال پیش
والد
کامیت
8a866078b2
9 فایل‌های تغییر یافته به همراه 363 افزوده شده و 43 حذف شده
  1. 1 0
      README.md
  2. 2 0
      common/constants.go
  3. 1 0
      common/model-ratio.go
  4. 9 0
      controller/model.go
  5. 71 42
      controller/relay-text.go
  6. 274 0
      controller/relay-xunfei.go
  7. 1 0
      go.mod
  8. 2 0
      go.sum
  9. 2 1
      web/src/constants/channel.constants.js

+ 1 - 0
README.md

@@ -64,6 +64,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
    + [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
+   + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
    + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
 2. 支持配置镜像以及众多第三方代理服务:
    + [x] [OpenAI-SB](https://openai-sb.com)

+ 2 - 0
common/constants.go

@@ -157,6 +157,7 @@ const (
 	ChannelTypeBaidu     = 15
 	ChannelTypeZhipu     = 16
 	ChannelTypeAli       = 17
+	ChannelTypeXunfei    = 18
 )
 
 var ChannelBaseURLs = []string{
@@ -178,4 +179,5 @@ var ChannelBaseURLs = []string{
 	"https://aip.baidubce.com",       // 15
 	"https://open.bigmodel.cn",       // 16
 	"https://dashscope.aliyuncs.com", // 17
+	"",                               // 18
 }

+ 1 - 0
common/model-ratio.go

@@ -49,6 +49,7 @@ var ModelRatio = map[string]float64{
 	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens
 	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
 	"qwen-plus-v1":            0.5715, // Same as above
+	"SparkDesk":               0.8572, // TBD
 }
 
 func ModelRatio2JSONString() string {

+ 9 - 0
controller/model.go

@@ -351,6 +351,15 @@ func init() {
 			Root:       "qwen-plus-v1",
 			Parent:     nil,
 		},
+		{
+			Id:         "SparkDesk",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "xunfei",
+			Permission: permission,
+			Root:       "SparkDesk",
+			Parent:     nil,
+		},
 	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {

+ 71 - 42
controller/relay-text.go

@@ -21,6 +21,7 @@ const (
 	APITypeBaidu
 	APITypeZhipu
 	APITypeAli
+	APITypeXunfei
 )
 
 var httpClient *http.Client
@@ -97,7 +98,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiType = APITypeZhipu
 	case common.ChannelTypeAli:
 		apiType = APITypeAli
-
+	case common.ChannelTypeXunfei:
+		apiType = APITypeXunfei
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
@@ -250,52 +252,60 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
 	}
-	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
-	if err != nil {
-		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
-	}
-	apiKey := c.Request.Header.Get("Authorization")
-	apiKey = strings.TrimPrefix(apiKey, "Bearer ")
-	switch apiType {
-	case APITypeOpenAI:
-		if channelType == common.ChannelTypeAzure {
-			req.Header.Set("api-key", apiKey)
-		} else {
-			req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+
+	var req *http.Request
+	var resp *http.Response
+	isStream := textRequest.Stream
+
+	if apiType != APITypeXunfei { // cause xunfei use websocket
+		req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+		if err != nil {
+			return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 		}
-	case APITypeClaude:
-		req.Header.Set("x-api-key", apiKey)
-		anthropicVersion := c.Request.Header.Get("anthropic-version")
-		if anthropicVersion == "" {
-			anthropicVersion = "2023-06-01"
+		apiKey := c.Request.Header.Get("Authorization")
+		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+		switch apiType {
+		case APITypeOpenAI:
+			if channelType == common.ChannelTypeAzure {
+				req.Header.Set("api-key", apiKey)
+			} else {
+				req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
+			}
+		case APITypeClaude:
+			req.Header.Set("x-api-key", apiKey)
+			anthropicVersion := c.Request.Header.Get("anthropic-version")
+			if anthropicVersion == "" {
+				anthropicVersion = "2023-06-01"
+			}
+			req.Header.Set("anthropic-version", anthropicVersion)
+		case APITypeZhipu:
+			token := getZhipuToken(apiKey)
+			req.Header.Set("Authorization", token)
+		case APITypeAli:
+			req.Header.Set("Authorization", "Bearer "+apiKey)
+			if textRequest.Stream {
+				req.Header.Set("X-DashScope-SSE", "enable")
+			}
 		}
-		req.Header.Set("anthropic-version", anthropicVersion)
-	case APITypeZhipu:
-		token := getZhipuToken(apiKey)
-		req.Header.Set("Authorization", token)
-	case APITypeAli:
-		req.Header.Set("Authorization", "Bearer "+apiKey)
-		if textRequest.Stream {
-			req.Header.Set("X-DashScope-SSE", "enable")
+		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+		//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
+		resp, err = httpClient.Do(req)
+		if err != nil {
+			return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 		}
+		err = req.Body.Close()
+		if err != nil {
+			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+		}
+		err = c.Request.Body.Close()
+		if err != nil {
+			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
+		}
+		isStream = strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 	}
-	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
-	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-	//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
-	resp, err := httpClient.Do(req)
-	if err != nil {
-		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
-	}
-	err = req.Body.Close()
-	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
-	err = c.Request.Body.Close()
-	if err != nil {
-		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
-	}
+
 	var textResponse TextResponse
-	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 	var streamResponseText string
 
 	defer func() {
@@ -470,6 +480,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		}
+	case APITypeXunfei:
+		if isStream {
+			auth := c.Request.Header.Get("Authorization")
+			auth = strings.TrimPrefix(auth, "Bearer ")
+			splits := strings.Split(auth, "|")
+			if len(splits) != 3 {
+				return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
+			}
+			err, usage := xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
+			if err != nil {
+				return err
+			}
+			if usage != nil {
+				textResponse.Usage = *usage
+			}
+			return nil
+		} else {
+			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
+		}
 	default:
 		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 	}

+ 274 - 0
controller/relay-xunfei.go

@@ -0,0 +1,274 @@
+package controller
+
+import (
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
+	"io"
+	"net/http"
+	"net/url"
+	"one-api/common"
+	"strings"
+	"time"
+)
+
+// https://console.xfyun.cn/services/cbm
+// https://www.xfyun.cn/doc/spark/Web.html
+
+// XunfeiMessage is one role/content pair of a conversation in the Xunfei chat
+// wire format. Values of this type are the elements of
+// XunfeiChatRequest.Payload.Message.Text.
+type XunfeiMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+// XunfeiChatRequest is the JSON document that xunfeiStreamHandler writes to the
+// Xunfei websocket (via conn.WriteJSON).
+//
+// It has three sections:
+//   - Header carries the application id ("app_id").
+//   - Parameter.Chat carries the generation settings. Every field there is
+//     tagged omitempty, so a zero value is left out of the encoded JSON.
+//   - Payload.Message.Text carries the conversation as a list of XunfeiMessage.
+type XunfeiChatRequest struct {
+	Header struct {
+		AppId string `json:"app_id"`
+	} `json:"header"`
+	Parameter struct {
+		Chat struct {
+			Domain      string  `json:"domain,omitempty"`
+			Temperature float64 `json:"temperature,omitempty"`
+			TopK        int     `json:"top_k,omitempty"`
+			MaxTokens   int     `json:"max_tokens,omitempty"`
+			Auditing    bool    `json:"auditing,omitempty"`
+		} `json:"chat"`
+	} `json:"parameter"`
+	Payload struct {
+		Message struct {
+			Text []XunfeiMessage `json:"text"`
+		} `json:"message"`
+	} `json:"payload"`
+}
+
+// XunfeiChatResponseTextItem is one element of the "text" array found under
+// payload.choices in a Xunfei response frame. The converters in this file read
+// only the Content of the first element.
+type XunfeiChatResponseTextItem struct {
+	Content string `json:"content"`
+	Role    string `json:"role"`
+	Index   int    `json:"index"`
+}
+
+// XunfeiChatResponse is one JSON frame received from the Xunfei chat service.
+//
+// The Spark API reports token usage under payload.usage.text, not at the top
+// level of the frame. The rest of this package reads the top-level Usage field,
+// so that field is kept and UnmarshalJSON fills it from Payload.Usage whenever
+// the frame carries usage information.
+type XunfeiChatResponse struct {
+	Header struct {
+		Code    int    `json:"code"`
+		Message string `json:"message"`
+		Sid     string `json:"sid"`
+		Status  int    `json:"status"`
+	} `json:"header"`
+	Payload struct {
+		Choices struct {
+			// Status 2 marks the last frame of a reply; the stream reader stops on it.
+			Status int                          `json:"status"`
+			Seq    int                          `json:"seq"`
+			Text   []XunfeiChatResponseTextItem `json:"text"`
+		} `json:"choices"`
+		// Usage is where the service actually places the token counts. It is a
+		// pointer so that "absent" can be told apart from "all zero".
+		Usage *struct {
+			Text Usage `json:"text"`
+		} `json:"usage,omitempty"`
+	} `json:"payload"`
+	// Usage is the location read by the handlers in this package. It is decoded
+	// from a top-level "usage" key if one exists and is overwritten with
+	// Payload.Usage when that is present.
+	Usage struct {
+		Text Usage `json:"text"`
+	} `json:"usage"`
+}
+
+// UnmarshalJSON decodes a response frame and mirrors payload.usage into the
+// top-level Usage field, so callers that read Usage.Text see the token counts
+// reported by the service.
+func (r *XunfeiChatResponse) UnmarshalJSON(data []byte) error {
+	// plain has the same fields but no methods, which prevents infinite recursion.
+	type plain XunfeiChatResponse
+	if err := json.Unmarshal(data, (*plain)(r)); err != nil {
+		return err
+	}
+	if r.Payload.Usage != nil {
+		r.Usage.Text = r.Payload.Usage.Text
+	}
+	return nil
+}
+
+// requestOpenAI2Xunfei converts an OpenAI-style request into the document sent
+// to Xunfei, tagged with the given application id.
+//
+// Each "system" message is replaced by a "user" message with the same content
+// followed by an "assistant" message saying "Okay"; every other message is
+// copied with its role unchanged. The chat domain is always "general".
+func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest {
+	history := make([]XunfeiMessage, 0, len(request.Messages))
+	for _, msg := range request.Messages {
+		if msg.Role != "system" {
+			history = append(history, XunfeiMessage{Role: msg.Role, Content: msg.Content})
+			continue
+		}
+		history = append(history,
+			XunfeiMessage{Role: "user", Content: msg.Content},
+			XunfeiMessage{Role: "assistant", Content: "Okay"},
+		)
+	}
+
+	converted := new(XunfeiChatRequest)
+	converted.Header.AppId = xunfeiAppId
+	converted.Payload.Message.Text = history
+
+	chat := &converted.Parameter.Chat
+	chat.Domain = "general"
+	chat.Temperature = request.Temperature
+	// NOTE(review): top_k is filled from the request's N field — confirm this mapping is intended.
+	chat.TopK = request.N
+	chat.MaxTokens = request.MaxTokens
+	return converted
+}
+
+// responseXunfei2OpenAI wraps a Xunfei response in an OpenAI "chat.completion"
+// object holding a single assistant choice and the usage taken from
+// response.Usage.Text.
+//
+// When the response has no text items, one empty item is stored into
+// response itself, so the caller's value is modified.
+func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
+	texts := &response.Payload.Choices.Text
+	if len(*texts) == 0 {
+		// Guarantee one element so the index below is always valid.
+		*texts = []XunfeiChatResponseTextItem{{Content: ""}}
+	}
+	reply := Message{
+		Role:    "assistant",
+		Content: (*texts)[0].Content,
+	}
+	return &OpenAITextResponse{
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Choices: []OpenAITextResponseChoice{{Index: 0, Message: reply}},
+		Usage:   response.Usage.Text,
+	}
+}
+
+// streamResponseXunfei2OpenAI turns one Xunfei frame into an OpenAI
+// "chat.completion.chunk" whose single choice carries the first text item's
+// content as its delta. The model name is fixed to "SparkDesk".
+//
+// When the frame has no text items, one empty item is stored into
+// xunfeiResponse itself, so the caller's value is modified.
+func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
+	texts := &xunfeiResponse.Payload.Choices.Text
+	if len(*texts) == 0 {
+		// Guarantee one element so the index below is always valid.
+		*texts = []XunfeiChatResponseTextItem{{Content: ""}}
+	}
+	var delta ChatCompletionsStreamResponseChoice
+	delta.Delta.Content = (*texts)[0].Content
+	return &ChatCompletionsStreamResponse{
+		Object:  "chat.completion.chunk",
+		Created: common.GetTimestamp(),
+		Model:   "SparkDesk",
+		Choices: []ChatCompletionsStreamResponseChoice{delta},
+	}
+}
+
+// buildXunfeiAuthUrl returns hostUrl extended with the query parameters "host",
+// "date" and "authorization" that carry the request signature.
+//
+// The signature is an HMAC-SHA256, keyed with apiSecret, over three lines: the
+// host, the current UTC time formatted as RFC 1123, and the GET request line
+// for hostUrl's path. The authorization value is the base64 encoding of an
+// "hmac username=..., algorithm=..., headers=..., signature=..." string that
+// names apiKey as the user.
+//
+// If hostUrl cannot be parsed the error is printed and an empty string is
+// returned, so that dialing the result fails with an ordinary error instead of
+// this function dereferencing a nil URL.
+func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
+	ul, err := url.Parse(hostUrl)
+	if err != nil {
+		fmt.Println(err)
+		return ""
+	}
+	date := time.Now().UTC().Format(time.RFC1123)
+	// The lines are signed in the order announced by the "headers" attribute below.
+	sign := strings.Join([]string{
+		"host: " + ul.Host,
+		"date: " + date,
+		"GET " + ul.Path + " HTTP/1.1",
+	}, "\n")
+	mac := hmac.New(sha256.New, []byte(apiSecret))
+	mac.Write([]byte(sign)) // hash.Hash.Write never returns an error
+	signature := base64.StdEncoding.EncodeToString(mac.Sum(nil))
+	authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
+		"hmac-sha256", "host date request-line", signature)
+	v := url.Values{}
+	v.Add("host", ul.Host)
+	v.Add("date", date)
+	v.Add("authorization", base64.StdEncoding.EncodeToString([]byte(authUrl)))
+	return hostUrl + "?" + v.Encode()
+}
+
+// xunfeiStreamHandler sends textRequest to the Xunfei websocket endpoint and
+// relays the reply to the client as server-sent events in OpenAI stream format,
+// ending with "data: [DONE]".
+//
+// appId, apiKey and apiSecret are the three Xunfei credentials. The returned
+// usage is the sum of the usage reported by all received frames. A non-nil
+// error is returned only for failures that happen before streaming starts
+// (dialing, or writing the request).
+//
+// The websocket connection is closed on every return path, and the reader
+// goroutine is released even when streaming stops early, e.g. because the
+// client went away.
+func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiKey string, apiSecret string) (*OpenAIErrorWithStatusCode, *Usage) {
+	var usage Usage
+	d := websocket.Dialer{
+		HandshakeTimeout: 5 * time.Second,
+	}
+	hostUrl := "wss://aichat.xf-yun.com/v1/chat"
+	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
+	if err != nil {
+		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
+	}
+	// Closing the connection also unblocks a reader that is waiting in ReadMessage.
+	defer conn.Close()
+	if resp.StatusCode != http.StatusSwitchingProtocols {
+		// Checked separately so errorWrapper never receives a nil error.
+		err = fmt.Errorf("unexpected handshake status code %d", resp.StatusCode)
+		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
+	}
+	data := requestOpenAI2Xunfei(textRequest, appId)
+	err = conn.WriteJSON(data)
+	if err != nil {
+		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
+	}
+	dataChan := make(chan XunfeiChatResponse)
+	// Buffered so the reader can always post its final signal and exit, even
+	// when nobody is receiving any more.
+	stopChan := make(chan bool, 1)
+	// done is closed when this handler returns; the reader uses it to give up
+	// on a pending send. It is closed before the deferred conn.Close runs.
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		defer func() { stopChan <- true }()
+		for {
+			_, msg, err := conn.ReadMessage()
+			if err != nil {
+				select {
+				case <-done:
+					// The handler already returned and closed the connection; not an upstream failure.
+				default:
+					common.SysError("error reading stream response: " + err.Error())
+				}
+				return
+			}
+			var response XunfeiChatResponse
+			err = json.Unmarshal(msg, &response)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return
+			}
+			select {
+			case dataChan <- response:
+			case <-done:
+				return
+			}
+			if response.Payload.Choices.Status == 2 {
+				return
+			}
+		}
+	}()
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case xunfeiResponse := <-dataChan:
+			usage.PromptTokens += xunfeiResponse.Usage.Text.PromptTokens
+			usage.CompletionTokens += xunfeiResponse.Usage.Text.CompletionTokens
+			usage.TotalTokens += xunfeiResponse.Usage.Text.TotalTokens
+			response := streamResponseXunfei2OpenAI(&xunfeiResponse)
+			jsonResponse, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			return true
+		case <-stopChan:
+			// The stop signal is posted only after the last frame was handed over, so no data is skipped.
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	return nil, &usage
+}
+
+// xunfeiHandler relays a complete (non-streaming) Xunfei HTTP response to the
+// client as an OpenAI-style JSON body, reusing the upstream status code.
+//
+// It returns an error wrapper when the body cannot be read, closed, decoded or
+// re-encoded, and an "xunfei_error" when the response header code is non-zero.
+// On success it returns the usage contained in the converted response.
+func xunfeiHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if err = resp.Body.Close(); err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	var parsed XunfeiChatResponse
+	if err = json.Unmarshal(body, &parsed); err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if parsed.Header.Code != 0 {
+		// Upstream reported a failure: pass its message and code through.
+		upstream := OpenAIError{
+			Message: parsed.Header.Message,
+			Type:    "xunfei_error",
+			Param:   "",
+			Code:    parsed.Header.Code,
+		}
+		return &OpenAIErrorWithStatusCode{OpenAIError: upstream, StatusCode: resp.StatusCode}, nil
+	}
+
+	converted := responseXunfei2OpenAI(&parsed)
+	payload, err := json.Marshal(converted)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	// The result of the write is not acted upon.
+	_, _ = c.Writer.Write(payload)
+	return nil, &converted.Usage
+}

+ 1 - 0
go.mod

@@ -13,6 +13,7 @@ require (
 	github.com/go-redis/redis/v8 v8.11.5
 	github.com/golang-jwt/jwt v3.2.2+incompatible
 	github.com/google/uuid v1.3.0
+	github.com/gorilla/websocket v1.5.0
 	github.com/pkoukk/tiktoken-go v0.1.1
 	golang.org/x/crypto v0.9.0
 	gorm.io/driver/mysql v1.4.3

+ 2 - 0
go.sum

@@ -67,6 +67,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
 github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
 github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
 github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
+github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
+github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
 github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
 github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=

+ 2 - 1
web/src/constants/channel.constants.js

@@ -5,6 +5,7 @@ export const CHANNEL_OPTIONS = [
   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
+  { key: 18, text: '讯飞星火认知大模型', value: 18, color: 'blue' },
   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
   { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
@@ -15,5 +16,5 @@ export const CHANNEL_OPTIONS = [
   { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' },
   { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' },
   { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' },
-  { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' },
+  { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' }
 ];