Browse Source

coze stream

creamlike1024 8 months ago
parent
commit
e379ee8f66
2 changed files with 130 additions and 3 deletions
  1. 8 1
      relay/channel/coze/adaptor.go
  2. 122 2
      relay/channel/coze/relay-coze.go

+ 8 - 1
relay/channel/coze/adaptor.go

@@ -57,6 +57,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 
 // DoRequest implements channel.Adaptor.
 func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
+	if info.IsStream {
+		return channel.DoApiRequest(a, c, info, requestBody)
+	}
 	// 首先发送创建消息请求,成功后再发送获取消息请求
 	// 发送创建消息请求
 	resp, err := channel.DoApiRequest(a, c, info, requestBody)
@@ -93,7 +96,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
 
 // DoResponse implements channel.Adaptor.
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
-	err, usage = cozeChatHandler(c, resp, info)
+	if info.IsStream {
+		err, usage = cozeChatStreamHandler(c, resp, info)
+	} else {
+		err, usage = cozeChatHandler(c, resp, info)
+	}
 	return
 }
 

+ 122 - 2
relay/channel/coze/relay-coze.go

@@ -1,16 +1,18 @@
 package coze
 
 import (
+	"bufio"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"net/http"
+	"one-api/common"
 	"one-api/dto"
-	"one-api/relay/common"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/service"
+	"strings"
 
 	"github.com/gin-gonic/gin"
 )
@@ -95,6 +97,124 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
 	return nil, &usage
 }
 
+func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(bufio.ScanLines)
+	helper.SetEventStreamHeaders(c)
+	id := helper.GetResponseID(c)
+	var responseText string
+
+	var currentEvent string
+	var currentData string
+	var usage dto.Usage
+
+	for scanner.Scan() {
+		line := scanner.Text()
+
+		if line == "" {
+			if currentEvent != "" && currentData != "" {
+				// handle last event
+				handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+				currentEvent = ""
+				currentData = ""
+			}
+			continue
+		}
+
+		if strings.HasPrefix(line, "event:") {
+			currentEvent = strings.TrimSpace(line[6:])
+			continue
+		}
+
+		if strings.HasPrefix(line, "data:") {
+			currentData = strings.TrimSpace(line[5:])
+			continue
+		}
+	}
+
+	// Last event
+	if currentEvent != "" && currentData != "" {
+		handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
+	}
+
+	if err := scanner.Err(); err != nil {
+		return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
+	}
+	helper.Done(c)
+
+	if usage.TotalTokens == 0 {
+		usage.PromptTokens = info.PromptTokens
+		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+	}
+
+	return nil, &usage
+}
+
+func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
+	switch event {
+	case "conversation.chat.completed":
+		// 将 data 解析为 CozeChatResponseData
+		var chatData CozeChatResponseData
+		err := json.Unmarshal([]byte(data), &chatData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		usage.PromptTokens = chatData.Usage.InputCount
+		usage.CompletionTokens = chatData.Usage.OutputCount
+		usage.TotalTokens = chatData.Usage.TokenCount
+
+		finishReason := "stop"
+		stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
+		helper.ObjectData(c, stopResponse)
+
+	case "conversation.message.delta":
+		// 将 data 解析为 CozeChatV3MessageDetail
+		var messageData CozeChatV3MessageDetail
+		err := json.Unmarshal([]byte(data), &messageData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		var content string
+		err = json.Unmarshal(messageData.Content, &content)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		*responseText += content
+
+		openaiResponse := dto.ChatCompletionsStreamResponse{
+			Id:      id,
+			Object:  "chat.completion.chunk",
+			Created: common.GetTimestamp(),
+			Model:   info.UpstreamModelName,
+		}
+
+		choice := dto.ChatCompletionsStreamResponseChoice{
+			Index: 0,
+		}
+		choice.Delta.SetContentString(content)
+		openaiResponse.Choices = append(openaiResponse.Choices, choice)
+
+		helper.ObjectData(c, openaiResponse)
+
+	case "error":
+		var errorData CozeError
+		err := json.Unmarshal([]byte(data), &errorData)
+		if err != nil {
+			common.SysError("error_unmarshalling_stream_response: " + err.Error())
+			return
+		}
+
+		common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
+	}
+}
+
 func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
 	requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
 
@@ -160,7 +280,7 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht
 	return resp, nil
 }
 
-func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) {
+func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
 	var client *http.Client
 	var err error // 声明 err 变量
 	if proxyURL, ok := info.ChannelSetting["proxy"]; ok {