Browse Source

Merge pull request #20 from Calcium-Ion/optimize/hign--cpu

fix: 修复客户端中断请求,计算补全阻塞问题 (fix: blocking in completion-token calculation when the client aborts the request)
Calcium-Ion 2 years ago
parent
commit
e095900d88
1 changed file with 45 additions and 25 deletions
  1. 45 25
      controller/relay-openai.go

+ 45 - 25
controller/relay-openai.go

@@ -9,10 +9,12 @@ import (
 	"net/http"
 	"one-api/common"
 	"strings"
+	"sync"
+	"time"
 )
 
 func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
-	responseText := ""
+	var responseTextBuilder strings.Builder
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 		if atEOF && len(data) == 0 {
@@ -26,9 +28,16 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 		}
 		return 0, nil, nil
 	})
-	dataChan := make(chan string)
-	stopChan := make(chan bool)
+	dataChan := make(chan string, 5)
+	stopChan := make(chan bool, 2)
+	defer close(stopChan)
+	defer close(dataChan)
+	var wg sync.WaitGroup
+
 	go func() {
+		wg.Add(1)
+		defer wg.Done()
+		var streamItems []string
 		for scanner.Scan() {
 			data := scanner.Text()
 			if len(data) < 6 { // ignore blank line or wrong format
@@ -40,29 +49,39 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 			dataChan <- data
 			data = data[6:]
 			if !strings.HasPrefix(data, "[DONE]") {
-				switch relayMode {
-				case RelayModeChatCompletions:
-					var streamResponse ChatCompletionsStreamResponseSimple
-					err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse)
-					if err != nil {
-						common.SysError("error unmarshalling stream response: " + err.Error())
-						continue // just ignore the error
-					}
-					for _, choice := range streamResponse.Choices {
-						responseText += choice.Delta.Content
-					}
-				case RelayModeCompletions:
-					var streamResponse CompletionsStreamResponse
-					err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse)
-					if err != nil {
-						common.SysError("error unmarshalling stream response: " + err.Error())
-						continue
-					}
-					for _, choice := range streamResponse.Choices {
-						responseText += choice.Text
-					}
+				streamItems = append(streamItems, data)
+			}
+		}
+		streamResp := "[" + strings.Join(streamItems, ",") + "]"
+		switch relayMode {
+		case RelayModeChatCompletions:
+			var streamResponses []ChatCompletionsStreamResponseSimple
+			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return // just ignore the error
+			}
+			for _, streamResponse := range streamResponses {
+				for _, choice := range streamResponse.Choices {
+					responseTextBuilder.WriteString(choice.Delta.Content)
 				}
 			}
+		case RelayModeCompletions:
+			var streamResponses []CompletionsStreamResponse
+			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return // just ignore the error
+			}
+			for _, streamResponse := range streamResponses {
+				for _, choice := range streamResponse.Choices {
+					responseTextBuilder.WriteString(choice.Text)
+				}
+			}
+		}
+		if len(dataChan) > 0 {
+			// wait data out
+			time.Sleep(2 * time.Second)
 		}
 		stopChan <- true
 	}()
@@ -85,7 +104,8 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 	if err != nil {
 		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
-	return nil, responseText
+	wg.Wait()
+	return nil, responseTextBuilder.String()
 }
 
 func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {