Browse Source

add OaiResponsesStreamHandler

creamlike1024 8 months ago
parent
commit
1236fa8fe4
3 changed files with 70 additions and 2 deletions
  1. 7 1
      dto/openai_response.go
  2. 1 1
      relay/channel/openai/adaptor.go
  3. 62 0
      relay/channel/openai/relay-openai.go

+ 7 - 1
dto/openai_response.go

@@ -214,7 +214,7 @@ type OpenAIResponsesResponse struct {
 	Tools              []interface{}      `json:"tools"`
 	TopP               float64            `json:"top_p"`
 	Truncation         string             `json:"truncation"`
-	Usage              Usage              `json:"usage"`
+	Usage              *Usage             `json:"usage"`
 	User               json.RawMessage    `json:"user"`
 	Metadata           json.RawMessage    `json:"metadata"`
 }
@@ -236,3 +236,9 @@ type ResponsesOutputContent struct {
 	Text        string        `json:"text"`
 	Annotations []interface{} `json:"annotations"`
 }
+
// ResponsesStreamResponse is one decoded SSE event from the streaming
// /v1/responses endpoint. Only Type is always present; Response is a
// pointer and may be nil for events that carry no response snapshot.
type ResponsesStreamResponse struct {
	Type     string                   `json:"type"`     // event type, e.g. "response.completed"
	Response *OpenAIResponsesResponse `json:"response"` // response snapshot; may be nil
}

+ 1 - 1
relay/channel/openai/adaptor.go

@@ -427,7 +427,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, usage = common_handler.RerankHandler(c, info, resp)
 	case constant.RelayModeResponses:
 		if info.IsStream {
-			err, usage = OaiStreamHandler(c, resp, info)
+			err, usage = OaiResponsesStreamHandler(c, resp, info)
 		} else {
 			err, usage = OpenaiResponsesHandler(c, resp, info)
 		}

+ 62 - 0
relay/channel/openai/relay-openai.go

@@ -694,3 +694,65 @@ func OpenaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycomm
 	usage.TotalTokens = responsesResponse.Usage.TotalTokens
 	return nil, &usage
 }
+
+func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	if resp == nil || resp.Body == nil {
+		common.LogError(c, "invalid response or response body")
+		return service.OpenAIErrorWrapper(fmt.Errorf("invalid response"), "invalid_response", http.StatusInternalServerError), nil
+	}
+
+	var usage = &dto.Usage{}
+	var streamItems []string // 存储流式数据项
+	// var responseTextBuilder strings.Builder
+	// var toolCount int
+	var forceFormat bool
+
+	if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
+		forceFormat = forceFmt
+	}
+
+	var lastStreamData string
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		if lastStreamData != "" {
+			// 处理上一条数据
+			sendResponsesStreamData(c, lastStreamData, forceFormat)
+		}
+		lastStreamData = data
+		streamItems = append(streamItems, data)
+
+		// 检查当前数据是否包含 completed 状态和 usage 信息
+		var streamResponse dto.ResponsesStreamResponse
+		if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
+			if streamResponse.Type == "response.completed" {
+				// 处理 completed 状态
+				usage.PromptTokens = streamResponse.Response.Usage.InputTokens
+				usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
+				usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
+			}
+		}
+		return true
+	})
+
+	// 处理最后一条数据
+	sendResponsesStreamData(c, lastStreamData, forceFormat)
+
+	// 处理token计算
+	// if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
+	// 	common.SysError("error processing tokens: " + err.Error())
+	// }
+
+	return nil, usage
+}
+
+func sendResponsesStreamData(c *gin.Context, data string, forceFormat bool) error {
+	if data == "" {
+		return nil
+	}
+
+	if forceFormat {
+		return helper.ObjectData(c, data)
+	} else {
+		return helper.StringData(c, data)
+	}
+}