Parcourir la source

refactor: 重构流模式逻辑

CalciumIon il y a 1 an
Parent
commit
7029065892

+ 11 - 4
dto/text_response.go

@@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct {
 	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
 }
 
-func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool {
-	return c.Content == nil && len(c.ToolCalls) == 0
-}
-
 func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
 	c.Content = &s
 }
@@ -105,6 +101,17 @@ type ChatCompletionsStreamResponse struct {
 	Usage             *Usage                                `json:"usage"`
 }
 
+func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string {
+	if c.SystemFingerprint == nil {
+		return ""
+	}
+	return *c.SystemFingerprint
+}
+
+func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) {
+	c.SystemFingerprint = &s
+}
+
 type ChatCompletionsStreamResponseSimple struct {
 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
 	Usage   *Usage                                `json:"usage"`

+ 1 - 8
relay/channel/openai/adaptor.go

@@ -14,7 +14,6 @@ import (
 	"one-api/relay/channel/minimax"
 	"one-api/relay/channel/moonshot"
 	relaycommon "one-api/relay/common"
-	"one-api/service"
 	"strings"
 )
 
@@ -90,13 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		var responseText string
-		var toolCount int
-		err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info)
-		if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 {
-			usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-			usage.CompletionTokens += toolCount * 7
-		}
+		err, usage, _, _ = OpenaiStreamHandler(c, resp, info)
 	} else {
 		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

+ 98 - 111
relay/channel/openai/relay-openai.go

@@ -14,38 +14,33 @@ import (
 	relayconstant "one-api/relay/constant"
 	"one-api/service"
 	"strings"
-	"sync"
 	"time"
 )
 
 func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) {
-	//checkSensitive := constant.ShouldCheckCompletionSensitive()
+	hasStreamUsage := false
+	responseId := ""
+	var createAt int64 = 0
+	var systemFingerprint string
+
 	var responseTextBuilder strings.Builder
-	var usage dto.Usage
+	var usage = &dto.Usage{}
 	toolCount := 0
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
-	dataChan := make(chan string, 5)
+	scanner.Split(bufio.ScanLines)
+	var streamItems []string // store stream items
+
+	service.SetEventStreamHeaders(c)
+
+	ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
+	defer ticker.Stop()
+
 	stopChan := make(chan bool, 2)
 	defer close(stopChan)
-	defer close(dataChan)
-	var wg sync.WaitGroup
+
 	go func() {
-		wg.Add(1)
-		defer wg.Done()
-		var streamItems []string // store stream items
 		for scanner.Scan() {
+			ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
 			data := scanner.Text()
 			if len(data) < 6 { // ignore blank line or wrong format
 				continue
@@ -53,54 +48,42 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 			if data[:6] != "data: " && data[:6] != "[DONE]" {
 				continue
 			}
-			if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
-				// send data timeout, stop the stream
-				common.LogError(c, "send data timeout, stop the stream")
-				break
-			}
 			data = data[6:]
 			if !strings.HasPrefix(data, "[DONE]") {
+				service.StringData(c, data)
 				streamItems = append(streamItems, data)
 			}
 		}
-		// 计算token
-		streamResp := "[" + strings.Join(streamItems, ",") + "]"
-		switch info.RelayMode {
-		case relayconstant.RelayModeChatCompletions:
-			var streamResponses []dto.ChatCompletionsStreamResponseSimple
-			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
-			if err != nil {
-				// 一次性解析失败,逐个解析
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				for _, item := range streamItems {
-					var streamResponse dto.ChatCompletionsStreamResponseSimple
-					err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
-					if err == nil {
-						if streamResponse.Usage != nil {
-							if streamResponse.Usage.TotalTokens != 0 {
-								usage = *streamResponse.Usage
-							}
-						}
-						for _, choice := range streamResponse.Choices {
-							responseTextBuilder.WriteString(choice.Delta.GetContentString())
-							if choice.Delta.ToolCalls != nil {
-								if len(choice.Delta.ToolCalls) > toolCount {
-									toolCount = len(choice.Delta.ToolCalls)
-								}
-								for _, tool := range choice.Delta.ToolCalls {
-									responseTextBuilder.WriteString(tool.Function.Name)
-									responseTextBuilder.WriteString(tool.Function.Arguments)
-								}
-							}
-						}
-					}
-				}
-			} else {
-				for _, streamResponse := range streamResponses {
-					if streamResponse.Usage != nil {
-						if streamResponse.Usage.TotalTokens != 0 {
-							usage = *streamResponse.Usage
-						}
+		stopChan <- true
+	}()
+
+	select {
+	case <-ticker.C:
+		// 超时处理逻辑
+		common.LogError(c, "streaming timeout")
+	case <-stopChan:
+		// 正常结束
+	}
+
+	// 计算token
+	streamResp := "[" + strings.Join(streamItems, ",") + "]"
+	switch info.RelayMode {
+	case relayconstant.RelayModeChatCompletions:
+		var streamResponses []dto.ChatCompletionsStreamResponse
+		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
+		if err != nil {
+			// 一次性解析失败,逐个解析
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			for _, item := range streamItems {
+				var streamResponse dto.ChatCompletionsStreamResponse
+				err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
+				if err == nil {
+					responseId = streamResponse.Id
+					createAt = streamResponse.Created
+					systemFingerprint = streamResponse.GetSystemFingerprint()
+					if service.ValidUsage(streamResponse.Usage) {
+						usage = streamResponse.Usage
+						hasStreamUsage = true
 					}
 					for _, choice := range streamResponse.Choices {
 						responseTextBuilder.WriteString(choice.Delta.GetContentString())
@@ -116,67 +99,71 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 					}
 				}
 			}
-		case relayconstant.RelayModeCompletions:
-			var streamResponses []dto.CompletionsStreamResponse
-			err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
-			if err != nil {
-				// 一次性解析失败,逐个解析
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				for _, item := range streamItems {
-					var streamResponse dto.CompletionsStreamResponse
-					err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
-					if err == nil {
-						for _, choice := range streamResponse.Choices {
-							responseTextBuilder.WriteString(choice.Text)
+		} else {
+			for _, streamResponse := range streamResponses {
+				responseId = streamResponse.Id
+				createAt = streamResponse.Created
+				systemFingerprint = streamResponse.GetSystemFingerprint()
+				if service.ValidUsage(streamResponse.Usage) {
+					usage = streamResponse.Usage
+					hasStreamUsage = true
+				}
+				for _, choice := range streamResponse.Choices {
+					responseTextBuilder.WriteString(choice.Delta.GetContentString())
+					if choice.Delta.ToolCalls != nil {
+						if len(choice.Delta.ToolCalls) > toolCount {
+							toolCount = len(choice.Delta.ToolCalls)
+						}
+						for _, tool := range choice.Delta.ToolCalls {
+							responseTextBuilder.WriteString(tool.Function.Name)
+							responseTextBuilder.WriteString(tool.Function.Arguments)
 						}
 					}
 				}
-			} else {
-				for _, streamResponse := range streamResponses {
+			}
+		}
+	case relayconstant.RelayModeCompletions:
+		var streamResponses []dto.CompletionsStreamResponse
+		err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses)
+		if err != nil {
+			// 一次性解析失败,逐个解析
+			common.SysError("error unmarshalling stream response: " + err.Error())
+			for _, item := range streamItems {
+				var streamResponse dto.CompletionsStreamResponse
+				err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
+				if err == nil {
 					for _, choice := range streamResponse.Choices {
 						responseTextBuilder.WriteString(choice.Text)
 					}
 				}
 			}
-		}
-		if len(dataChan) > 0 {
-			// wait data out
-			time.Sleep(2 * time.Second)
-		}
-		common.SafeSendBool(stopChan, true)
-	}()
-	service.SetEventStreamHeaders(c)
-	isFirst := true
-	ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second)
-	defer ticker.Stop()
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case <-ticker.C:
-			common.LogError(c, "reading data from upstream timeout")
-			return false
-		case data := <-dataChan:
-			if isFirst {
-				isFirst = false
-				info.FirstResponseTime = time.Now()
-			}
-			ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second)
-			if strings.HasPrefix(data, "data: [DONE]") {
-				data = data[:12]
+		} else {
+			for _, streamResponse := range streamResponses {
+				for _, choice := range streamResponse.Choices {
+					responseTextBuilder.WriteString(choice.Text)
+				}
 			}
-			// some implementations may add \r at the end of data
-			data = strings.TrimSuffix(data, "\r")
-			c.Render(-1, common.CustomEvent{Data: data})
-			return true
-		case <-stopChan:
-			return false
 		}
-	})
+	}
+
+	if !hasStreamUsage {
+		usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+		usage.CompletionTokens += toolCount * 7
+	}
+
+	if info.ShouldIncludeUsage && !hasStreamUsage {
+		response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage)
+		response.SetSystemFingerprint(systemFingerprint)
+		service.ObjectData(c, response)
+	}
+
+	service.Done(c)
+
 	err := resp.Body.Close()
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount
 	}
-	wg.Wait()
-	return nil, &usage, responseTextBuilder.String(), toolCount
+	return nil, usage, responseTextBuilder.String(), toolCount
 }
 
 func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {

+ 4 - 0
service/usage_helpr.go

@@ -36,3 +36,7 @@ func GenerateFinalUsageResponse(id string, createAt int64, model string, usage d
 		Usage:             &usage,
 	}
 }
+
+func ValidUsage(usage *dto.Usage) bool {
+	return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
+}