|
|
@@ -1,16 +1,18 @@
|
|
|
package coze
|
|
|
|
|
|
import (
|
|
|
+ "bufio"
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
+ "one-api/common"
|
|
|
"one-api/dto"
|
|
|
- "one-api/relay/common"
|
|
|
relaycommon "one-api/relay/common"
|
|
|
"one-api/relay/helper"
|
|
|
"one-api/service"
|
|
|
+ "strings"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
)
|
|
|
@@ -95,6 +97,124 @@ func cozeChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|
|
return nil, &usage
|
|
|
}
|
|
|
|
|
|
+func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
+ scanner := bufio.NewScanner(resp.Body)
|
|
|
+ scanner.Split(bufio.ScanLines)
|
|
|
+ helper.SetEventStreamHeaders(c)
|
|
|
+ id := helper.GetResponseID(c)
|
|
|
+ var responseText string
|
|
|
+
|
|
|
+ var currentEvent string
|
|
|
+ var currentData string
|
|
|
+ var usage dto.Usage
|
|
|
+
|
|
|
+ for scanner.Scan() {
|
|
|
+ line := scanner.Text()
|
|
|
+
|
|
|
+ if line == "" {
|
|
|
+ if currentEvent != "" && currentData != "" {
|
|
|
+ // handle last event
|
|
|
+ handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
|
|
+ currentEvent = ""
|
|
|
+ currentData = ""
|
|
|
+ }
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if strings.HasPrefix(line, "event:") {
|
|
|
+ currentEvent = strings.TrimSpace(line[6:])
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if strings.HasPrefix(line, "data:") {
|
|
|
+ currentData = strings.TrimSpace(line[5:])
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Last event
|
|
|
+ if currentEvent != "" && currentData != "" {
|
|
|
+ handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := scanner.Err(); err != nil {
|
|
|
+ return service.OpenAIErrorWrapper(err, "stream_scanner_error", http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+ helper.Done(c)
|
|
|
+
|
|
|
+ if usage.TotalTokens == 0 {
|
|
|
+ usage.PromptTokens = info.PromptTokens
|
|
|
+ usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
|
|
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil, &usage
|
|
|
+}
|
|
|
+
|
|
|
+func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
|
|
+ switch event {
|
|
|
+ case "conversation.chat.completed":
|
|
|
+ // 将 data 解析为 CozeChatResponseData
|
|
|
+ var chatData CozeChatResponseData
|
|
|
+ err := json.Unmarshal([]byte(data), &chatData)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ usage.PromptTokens = chatData.Usage.InputCount
|
|
|
+ usage.CompletionTokens = chatData.Usage.OutputCount
|
|
|
+ usage.TotalTokens = chatData.Usage.TokenCount
|
|
|
+
|
|
|
+ finishReason := "stop"
|
|
|
+ stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
|
|
|
+ helper.ObjectData(c, stopResponse)
|
|
|
+
|
|
|
+ case "conversation.message.delta":
|
|
|
+ // 将 data 解析为 CozeChatV3MessageDetail
|
|
|
+ var messageData CozeChatV3MessageDetail
|
|
|
+ err := json.Unmarshal([]byte(data), &messageData)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ var content string
|
|
|
+ err = json.Unmarshal(messageData.Content, &content)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ *responseText += content
|
|
|
+
|
|
|
+ openaiResponse := dto.ChatCompletionsStreamResponse{
|
|
|
+ Id: id,
|
|
|
+ Object: "chat.completion.chunk",
|
|
|
+ Created: common.GetTimestamp(),
|
|
|
+ Model: info.UpstreamModelName,
|
|
|
+ }
|
|
|
+
|
|
|
+ choice := dto.ChatCompletionsStreamResponseChoice{
|
|
|
+ Index: 0,
|
|
|
+ }
|
|
|
+ choice.Delta.SetContentString(content)
|
|
|
+ openaiResponse.Choices = append(openaiResponse.Choices, choice)
|
|
|
+
|
|
|
+ helper.ObjectData(c, openaiResponse)
|
|
|
+
|
|
|
+ case "error":
|
|
|
+ var errorData CozeError
|
|
|
+ err := json.Unmarshal([]byte(data), &errorData)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error_unmarshalling_stream_response: " + err.Error())
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
|
|
|
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
|
|
|
|
|
|
@@ -160,7 +280,7 @@ func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*ht
|
|
|
return resp, nil
|
|
|
}
|
|
|
|
|
|
-func doRequest(req *http.Request, info *common.RelayInfo) (*http.Response, error) {
|
|
|
+func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
|
|
|
var client *http.Client
|
|
|
var err error // 声明 err 变量
|
|
|
if proxyURL, ok := info.ChannelSetting["proxy"]; ok {
|