Просмотр исходного кода

feat: enhance OaiResponsesStreamHandler to handle output text and improve response streaming

CaIon 7 месяцев назад
Родитель
Сommit
fe3232bf23
3 измененных файлов с 40 добавлено и 32 удалено
  1. 11 0
      dto/openai_response.go
  2. 21 32
      relay/channel/openai/relay-openai.go
  3. 8 0
      relay/helper/common.go

+ 11 - 0
dto/openai_response.go

@@ -237,8 +237,19 @@ type ResponsesOutputContent struct {
 	Annotations []interface{} `json:"annotations"`
 }
 
+const (
+	BuildInTools_WebSearch  = "web_search_preview"
+	BuildInTools_FileSearch = "file_search"
+)
+
+const (
+	ResponsesOutputTypeItemAdded = "response.output_item.added"
+	ResponsesOutputTypeItemDone  = "response.output_item.done"
+)
+
 // ResponsesStreamResponse 用于处理 /v1/responses 流式响应
 type ResponsesStreamResponse struct {
 	Type     string                   `json:"type"`
 	Response *OpenAIResponsesResponse `json:"response"`
+	Delta    string                   `json:"delta,omitempty"`
 }

+ 21 - 32
relay/channel/openai/relay-openai.go

@@ -702,57 +702,46 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
 	}
 
 	var usage = &dto.Usage{}
-	var streamItems []string // 存储流式数据项
-	// var responseTextBuilder strings.Builder
-	// var toolCount int
-	var forceFormat bool
-
-	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
-		forceFormat = forceFmt
-	}
-
-	var lastStreamData string
+	var responseTextBuilder strings.Builder
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
-		if lastStreamData != "" {
-			// 处理上一条数据
-			sendResponsesStreamData(c, lastStreamData, forceFormat)
-		}
-		lastStreamData = data
-		streamItems = append(streamItems, data)
 
 		// 检查当前数据是否包含 completed 状态和 usage 信息
 		var streamResponse dto.ResponsesStreamResponse
 		if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
-			if streamResponse.Type == "response.completed" {
-				// 处理 completed 状态
+			sendResponsesStreamData(c, streamResponse, data)
+			switch streamResponse.Type {
+			case "response.completed":
 				usage.PromptTokens = streamResponse.Response.Usage.InputTokens
 				usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
 				usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+			case "response.output_text.delta":
+				// 处理输出文本
+				responseTextBuilder.WriteString(streamResponse.Delta)
+
 			}
 		}
 		return true
 	})
 
-	// 处理最后一条数据
-	sendResponsesStreamData(c, lastStreamData, forceFormat)
+	helper.Done(c)
 
-	// 处理token计算
-	// if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
-	// 	common.SysError("error processing tokens: " + err.Error())
-	// }
+	if usage.CompletionTokens == 0 {
+		// 计算输出文本的 token 数量
+		tempStr := responseTextBuilder.String()
+		if len(tempStr) > 0 {
+			// 非正常结束,使用输出文本的 token 数量
+			completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+			usage.CompletionTokens = completionTokens
+		}
+	}
 
 	return nil, usage
 }
 
-func sendResponsesStreamData(c *gin.Context, data string, forceFormat bool) error {
+func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
 	if data == "" {
-		return nil
-	}
-
-	if forceFormat {
-		return helper.ObjectData(c, data)
-	} else {
-		return helper.StringData(c, data)
+		return
 	}
+	helper.ResponseChunkData(c, streamResponse, data)
 }

+ 8 - 0
relay/helper/common.go

@@ -43,6 +43,14 @@ func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
 	}
 }
 
+func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
+	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
+	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
+	if flusher, ok := c.Writer.(http.Flusher); ok {
+		flusher.Flush()
+	}
+}
+
 func StringData(c *gin.Context, str string) error {
 	//str = strings.TrimPrefix(str, "data: ")
 	//str = strings.TrimSuffix(str, "\r")