|
|
@@ -14,38 +14,33 @@ import (
|
|
|
relayconstant "one-api/relay/constant"
|
|
|
"one-api/service"
|
|
|
"strings"
|
|
|
- "sync"
|
|
|
"time"
|
|
|
)
|
|
|
|
|
|
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
|
|
|
- //checkSensitive := constant.ShouldCheckCompletionSensitive()
|
|
|
+ hasStreamUsage := false
|
|
|
+ responseId := ""
|
|
|
+ var createAt int64 = 0
|
|
|
+ var systemFingerprint string
|
|
|
+
|
|
|
var responseTextBuilder strings.Builder
|
|
|
- var usage dto.Usage
|
|
|
+ var usage = &dto.Usage{}
|
|
|
toolCount := 0
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
|
- scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
|
- if atEOF && len(data) == 0 {
|
|
|
- return 0, nil, nil
|
|
|
- }
|
|
|
- if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
|
- return i + 1, data[0:i], nil
|
|
|
- }
|
|
|
- if atEOF {
|
|
|
- return len(data), data, nil
|
|
|
- }
|
|
|
- return 0, nil, nil
|
|
|
- })
|
|
|
- dataChan := make(chan string, 5)
|
|
|
+ scanner.Split(bufio.ScanLines)
|
|
|
+ var streamItems []string // store stream items
|
|
|
+
|
|
|
+ service.SetEventStreamHeaders(c)
|
|
|
+
|
|
|
+ ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
|
|
|
+ defer ticker.Stop()
|
|
|
+
|
|
|
stopChan := make(chan bool, 2)
|
|
|
defer close(stopChan)
|
|
|
- defer close(dataChan)
|
|
|
- var wg sync.WaitGroup
|
|
|
+
|
|
|
go func() {
|
|
|
- wg.Add(1)
|
|
|
- defer wg.Done()
|
|
|
- var streamItems []string // store stream items
|
|
|
for scanner.Scan() {
|
|
|
+ ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
|
|
|
data := scanner.Text()
|
|
|
if len(data) < 6 { // ignore blank line or wrong format
|
|
|
continue
|
|
|
@@ -53,54 +48,42 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|
|
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
|
|
continue
|
|
|
}
|
|
|
- if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
|
|
|
- // send data timeout, stop the stream
|
|
|
- common.LogError(c, "send data timeout, stop the stream")
|
|
|
- break
|
|
|
- }
|
|
|
data = data[6:]
|
|
|
if !strings.HasPrefix(data, "[DONE]") {
|
|
|
+ service.StringData(c, data)
|
|
|
streamItems = append(streamItems, data)
|
|
|
}
|
|
|
}
|
|
|
- // 计算token
|
|
|
- streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
|
|
- switch info.RelayMode {
|
|
|
- case relayconstant.RelayModeChatCompletions:
|
|
|
- var streamResponses []dto.ChatCompletionsStreamResponseSimple
|
|
|
- err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
|
|
- if err != nil {
|
|
|
- // 一次性解析失败,逐个解析
|
|
|
- common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
- for _, item := range streamItems {
|
|
|
- var streamResponse dto.ChatCompletionsStreamResponseSimple
|
|
|
- err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
|
|
- if err == nil {
|
|
|
- if streamResponse.Usage != nil {
|
|
|
- if streamResponse.Usage.TotalTokens != 0 {
|
|
|
- usage = *streamResponse.Usage
|
|
|
- }
|
|
|
- }
|
|
|
- for _, choice := range streamResponse.Choices {
|
|
|
- responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
|
|
- if choice.Delta.ToolCalls != nil {
|
|
|
- if len(choice.Delta.ToolCalls) > toolCount {
|
|
|
- toolCount = len(choice.Delta.ToolCalls)
|
|
|
- }
|
|
|
- for _, tool := range choice.Delta.ToolCalls {
|
|
|
- responseTextBuilder.WriteString(tool.Function.Name)
|
|
|
- responseTextBuilder.WriteString(tool.Function.Arguments)
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- } else {
|
|
|
- for _, streamResponse := range streamResponses {
|
|
|
- if streamResponse.Usage != nil {
|
|
|
- if streamResponse.Usage.TotalTokens != 0 {
|
|
|
- usage = *streamResponse.Usage
|
|
|
- }
|
|
|
+ stopChan <- true
|
|
|
+ }()
|
|
|
+
|
|
|
+ select {
|
|
|
+ case <-ticker.C:
|
|
|
+ // 超时处理逻辑
|
|
|
+ common.LogError(c, "streaming timeout")
|
|
|
+ case <-stopChan:
|
|
|
+ // 正常结束
|
|
|
+ }
|
|
|
+
|
|
|
+ // 计算token
|
|
|
+ streamResp := "[" + strings.Join(streamItems, ",") + "]"
|
|
|
+ switch info.RelayMode {
|
|
|
+ case relayconstant.RelayModeChatCompletions:
|
|
|
+ var streamResponses []dto.ChatCompletionsStreamResponse
|
|
|
+ err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
|
|
+ if err != nil {
|
|
|
+ // 一次性解析失败,逐个解析
|
|
|
+ common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
+ for _, item := range streamItems {
|
|
|
+ var streamResponse dto.ChatCompletionsStreamResponse
|
|
|
+ err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
|
|
+ if err == nil {
|
|
|
+ responseId = streamResponse.Id
|
|
|
+ createAt = streamResponse.Created
|
|
|
+ systemFingerprint = streamResponse.GetSystemFingerprint()
|
|
|
+ if service.ValidUsage(streamResponse.Usage) {
|
|
|
+ usage = streamResponse.Usage
|
|
|
+ hasStreamUsage = true
|
|
|
}
|
|
|
for _, choice := range streamResponse.Choices {
|
|
|
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
|
|
@@ -116,67 +99,71 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- case relayconstant.RelayModeCompletions:
|
|
|
- var streamResponses []dto.CompletionsStreamResponse
|
|
|
- err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
|
|
- if err != nil {
|
|
|
- // 一次性解析失败,逐个解析
|
|
|
- common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
- for _, item := range streamItems {
|
|
|
- var streamResponse dto.CompletionsStreamResponse
|
|
|
- err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
|
|
- if err == nil {
|
|
|
- for _, choice := range streamResponse.Choices {
|
|
|
- responseTextBuilder.WriteString(choice.Text)
|
|
|
+ } else {
|
|
|
+ for _, streamResponse := range streamResponses {
|
|
|
+ responseId = streamResponse.Id
|
|
|
+ createAt = streamResponse.Created
|
|
|
+ systemFingerprint = streamResponse.GetSystemFingerprint()
|
|
|
+ if service.ValidUsage(streamResponse.Usage) {
|
|
|
+ usage = streamResponse.Usage
|
|
|
+ hasStreamUsage = true
|
|
|
+ }
|
|
|
+ for _, choice := range streamResponse.Choices {
|
|
|
+ responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
|
|
+ if choice.Delta.ToolCalls != nil {
|
|
|
+ if len(choice.Delta.ToolCalls) > toolCount {
|
|
|
+ toolCount = len(choice.Delta.ToolCalls)
|
|
|
+ }
|
|
|
+ for _, tool := range choice.Delta.ToolCalls {
|
|
|
+ responseTextBuilder.WriteString(tool.Function.Name)
|
|
|
+ responseTextBuilder.WriteString(tool.Function.Arguments)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- } else {
|
|
|
- for _, streamResponse := range streamResponses {
|
|
|
+ }
|
|
|
+ }
|
|
|
+ case relayconstant.RelayModeCompletions:
|
|
|
+ var streamResponses []dto.CompletionsStreamResponse
|
|
|
+ err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
|
|
|
+ if err != nil {
|
|
|
+ // 一次性解析失败,逐个解析
|
|
|
+ common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
+ for _, item := range streamItems {
|
|
|
+ var streamResponse dto.CompletionsStreamResponse
|
|
|
+ err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
|
|
+ if err == nil {
|
|
|
for _, choice := range streamResponse.Choices {
|
|
|
responseTextBuilder.WriteString(choice.Text)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
- }
|
|
|
- if len(dataChan) > 0 {
|
|
|
- // wait data out
|
|
|
- time.Sleep(2 * time.Second)
|
|
|
- }
|
|
|
- common.SafeSendBool(stopChan, true)
|
|
|
- }()
|
|
|
- service.SetEventStreamHeaders(c)
|
|
|
- isFirst := true
|
|
|
- ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
|
|
|
- defer ticker.Stop()
|
|
|
- c.Stream(func(w io.Writer) bool {
|
|
|
- select {
|
|
|
- case <-ticker.C:
|
|
|
- common.LogError(c, "reading data from upstream timeout")
|
|
|
- return false
|
|
|
- case data := <-dataChan:
|
|
|
- if isFirst {
|
|
|
- isFirst = false
|
|
|
- info.FirstResponseTime = time.Now()
|
|
|
- }
|
|
|
- ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
|
|
|
- if strings.HasPrefix(data, "data: [DONE]") {
|
|
|
- data = data[:12]
|
|
|
+ } else {
|
|
|
+ for _, streamResponse := range streamResponses {
|
|
|
+ for _, choice := range streamResponse.Choices {
|
|
|
+ responseTextBuilder.WriteString(choice.Text)
|
|
|
+ }
|
|
|
}
|
|
|
- // some implementations may add \r at the end of data
|
|
|
- data = strings.TrimSuffix(data, "\r")
|
|
|
- c.Render(-1, common.CustomEvent{Data: data})
|
|
|
- return true
|
|
|
- case <-stopChan:
|
|
|
- return false
|
|
|
}
|
|
|
- })
|
|
|
+ }
|
|
|
+
|
|
|
+ if !hasStreamUsage {
|
|
|
+ usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
|
|
+ usage.CompletionTokens += toolCount * 7
|
|
|
+ }
|
|
|
+
|
|
|
+ if info.ShouldIncludeUsage && !hasStreamUsage {
|
|
|
+ response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage)
|
|
|
+ response.SetSystemFingerprint(systemFingerprint)
|
|
|
+ service.ObjectData(c, response)
|
|
|
+ }
|
|
|
+
|
|
|
+ service.Done(c)
|
|
|
+
|
|
|
err := resp.Body.Close()
|
|
|
if err != nil {
|
|
|
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
|
|
|
}
|
|
|
- wg.Wait()
|
|
|
- return nil, &usage, responseTextBuilder.String(), toolCount
|
|
|
+ return nil, usage, responseTextBuilder.String(), toolCount
|
|
|
}
|
|
|
|
|
|
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|