@@ -26,14 +26,10 @@ func OaiResponsesToChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp
 	defer service.CloseResponseBodyGracefully(resp)
 
 	var responsesResp dto.OpenAIResponsesResponse
-	const maxResponseBodyBytes = 10 << 20 // 10MB
-	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes+1))
+	body, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
 	}
-	if int64(len(body)) > maxResponseBodyBytes {
-		return nil, types.NewOpenAIError(fmt.Errorf("response body exceeds %d bytes", maxResponseBodyBytes), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
-	}
 	if err := common.Unmarshal(body, &responsesResp); err != nil {
 		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
 	}
@@ -77,12 +73,99 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
 	var (
 		usage       = &dto.Usage{}
-		textBuilder strings.Builder
+		outputText  strings.Builder
+		usageText   strings.Builder
 		sentStart   bool
 		sentStop    bool
+		sawToolCall bool
 		streamErr   *types.NewAPIError
 	)
 
+	toolCallIndexByID := make(map[string]int)
+	toolCallNameByID := make(map[string]string)
+	toolCallArgsByID := make(map[string]string)
+	toolCallNameSent := make(map[string]bool)
+	toolCallCanonicalIDByItemID := make(map[string]string)
+
+	sendStartIfNeeded := func() bool {
+		if sentStart {
+			return true
+		}
+		if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
+			streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+			return false
+		}
+		sentStart = true
+		return true
+	}
+
+	sendToolCallDelta := func(callID string, name string, argsDelta string) bool {
+		if callID == "" {
+			return true
+		}
+		if outputText.Len() > 0 {
+			// Prefer streaming assistant text over tool calls to match non-stream behavior.
+			return true
+		}
+		if !sendStartIfNeeded() {
+			return false
+		}
+
+		idx, ok := toolCallIndexByID[callID]
+		if !ok {
+			idx = len(toolCallIndexByID)
+			toolCallIndexByID[callID] = idx
+		}
+		if name != "" {
+			toolCallNameByID[callID] = name
+		}
+		if toolCallNameByID[callID] != "" {
+			name = toolCallNameByID[callID]
+		}
+
+		tool := dto.ToolCallResponse{
+			ID:   callID,
+			Type: "function",
+			Function: dto.FunctionResponse{
+				Arguments: argsDelta,
+			},
+		}
+		tool.SetIndex(idx)
+		if name != "" && !toolCallNameSent[callID] {
+			tool.Function.Name = name
+			toolCallNameSent[callID] = true
+		}
+
+		chunk := &dto.ChatCompletionsStreamResponse{
+			Id:      responseId,
+			Object:  "chat.completion.chunk",
+			Created: createAt,
+			Model:   model,
+			Choices: []dto.ChatCompletionsStreamResponseChoice{
+				{
+					Index: 0,
+					Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
+						ToolCalls: []dto.ToolCallResponse{tool},
+					},
+				},
+			},
+		}
+		if err := helper.ObjectData(c, chunk); err != nil {
+			streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
+			return false
+		}
+		sawToolCall = true
+
+		// Include tool call data in the local builder for fallback token estimation.
+		if tool.Function.Name != "" {
+			usageText.WriteString(tool.Function.Name)
+		}
+		if argsDelta != "" {
+			usageText.WriteString(argsDelta)
+		}
+		return true
+	}
+
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		if streamErr != nil {
 			return false
 		}
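
Note on the index bookkeeping in sendToolCallDelta above: chat-completions clients correlate streamed tool-call fragments through a stable per-call index, so the closure hands out the next free slot the first time a callID appears and reuses it on every later fragment. A minimal standalone sketch of that pattern; indexFor is an illustrative helper, not part of this patch:

	package main

	import "fmt"

	// indexFor returns a stable, zero-based index for each distinct id,
	// assigning the next free slot on first sight.
	func indexFor(seen map[string]int, id string) int {
		if idx, ok := seen[id]; ok {
			return idx
		}
		idx := len(seen)
		seen[id] = idx
		return idx
	}

	func main() {
		seen := make(map[string]int)
		fmt.Println(indexFor(seen, "call_a")) // 0
		fmt.Println(indexFor(seen, "call_b")) // 1
		fmt.Println(indexFor(seen, "call_a")) // 0 again: the index is stable per call
	}
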
@@ -106,16 +189,13 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
 			}
 
 		case "response.output_text.delta":
-			if !sentStart {
-				if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
-					streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
-					return false
-				}
-				sentStart = true
+			if !sendStartIfNeeded() {
+				return false
 			}
 
 			if streamResp.Delta != "" {
-				textBuilder.WriteString(streamResp.Delta)
+				outputText.WriteString(streamResp.Delta)
+				usageText.WriteString(streamResp.Delta)
 				delta := streamResp.Delta
 				chunk := &dto.ChatCompletionsStreamResponse{
 					Id:      responseId,
@@ -137,6 +217,59 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
 				}
 			}
 
+		case "response.output_item.added", "response.output_item.done":
+			if streamResp.Item == nil {
+				break
+			}
+			if streamResp.Item.Type != "function_call" {
+				break
+			}
+
+			itemID := strings.TrimSpace(streamResp.Item.ID)
+			callID := strings.TrimSpace(streamResp.Item.CallId)
+			if callID == "" {
+				callID = itemID
+			}
+			if itemID != "" && callID != "" {
+				toolCallCanonicalIDByItemID[itemID] = callID
+			}
+			name := strings.TrimSpace(streamResp.Item.Name)
+			if name != "" {
+				toolCallNameByID[callID] = name
+			}
+
+			newArgs := streamResp.Item.Arguments
+			prevArgs := toolCallArgsByID[callID]
+			argsDelta := ""
+			if newArgs != "" {
+				if strings.HasPrefix(newArgs, prevArgs) {
+					argsDelta = newArgs[len(prevArgs):]
+				} else {
+					argsDelta = newArgs
+				}
+				toolCallArgsByID[callID] = newArgs
+			}
+
+			if !sendToolCallDelta(callID, name, argsDelta) {
+				return false
+			}
+
+		case "response.function_call_arguments.delta":
+			itemID := strings.TrimSpace(streamResp.ItemID)
+			callID := toolCallCanonicalIDByItemID[itemID]
+			if callID == "" {
+				callID = itemID
+			}
+			if callID == "" {
+				break
+			}
+			toolCallArgsByID[callID] += streamResp.Delta
+			if !sendToolCallDelta(callID, "", streamResp.Delta) {
+				return false
+			}
+
+		case "response.function_call_arguments.done":
+			// Arguments were already accumulated and forwarded by the delta events above; nothing further to emit.
+
 		case "response.completed":
 			if streamResp.Response != nil {
 				if streamResp.Response.Model != "" {
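
Note on the argument handling above: response.output_item.done carries the full accumulated arguments string, while response.function_call_arguments.delta events may already have streamed a prefix of it. The handler therefore forwards only the unseen suffix when the new value extends the previous one, and falls back to resending the whole string when it does not. A standalone sketch of that prefix diff; argsDelta is an illustrative helper, not part of this patch:

	package main

	import (
		"fmt"
		"strings"
	)

	// argsDelta returns the portion of newArgs that has not been emitted yet,
	// given the previously seen value. If newArgs does not extend prevArgs,
	// the whole new string is returned.
	func argsDelta(prevArgs, newArgs string) string {
		if strings.HasPrefix(newArgs, prevArgs) {
			return newArgs[len(prevArgs):]
		}
		return newArgs
	}

	func main() {
		fmt.Printf("%q\n", argsDelta(`{"city":`, `{"city":"Paris"}`)) // `"Paris"}`
		fmt.Printf("%q\n", argsDelta(`{"a":1}`, `{"b":2}`))           // whole new string is resent
	}
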
@@ -170,15 +303,15 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
 				}
 			}
 
-			if !sentStart {
-				if err := helper.ObjectData(c, helper.GenerateStartEmptyResponse(responseId, createAt, model, nil)); err != nil {
-					streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
-					return false
-				}
-				sentStart = true
+			if !sendStartIfNeeded() {
+				return false
 			}
 			if !sentStop {
-				stop := helper.GenerateStopResponse(responseId, createAt, model, "stop")
+				finishReason := "stop"
+				if sawToolCall && outputText.Len() == 0 {
+					finishReason = "tool_calls"
+				}
+				stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
 				if err := helper.ObjectData(c, stop); err != nil {
 					streamErr = types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
 					return false
@@ -196,8 +329,6 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
 			streamErr = types.NewOpenAIError(fmt.Errorf("responses stream error: %s", streamResp.Type), types.ErrorCodeBadResponse, http.StatusInternalServerError)
 			return false
 
-		case "response.output_item.added", "response.output_item.done":
-
 		default:
 		}
 
@@ -209,7 +340,7 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
 	}
 
 	if usage.TotalTokens == 0 {
-		usage = service.ResponseText2Usage(c, textBuilder.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
+		usage = service.ResponseText2Usage(c, usageText.String(), info.UpstreamModelName, info.GetEstimatePromptTokens())
	}
 
 	if !sentStart {
@@ -218,7 +349,11 @@ func OaiResponsesToChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo
 		}
 	}
 	if !sentStop {
-		stop := helper.GenerateStopResponse(responseId, createAt, model, "stop")
+		finishReason := "stop"
+		if sawToolCall && outputText.Len() == 0 {
+			finishReason = "tool_calls"
+		}
+		stop := helper.GenerateStopResponse(responseId, createAt, model, finishReason)
 		if err := helper.ObjectData(c, stop); err != nil {
 			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError)
 		}
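
For reference, the finish-reason selection that this patch applies in both flush paths reduces to the rule sketched below; finishReasonFor is an illustrative helper, not code from this change:

	package main

	import "fmt"

	// finishReasonFor mirrors the selection at the end of the stream handler:
	// "tool_calls" is reported only when tool-call deltas were emitted and no
	// assistant text was streamed; otherwise the default "stop" is kept.
	func finishReasonFor(sawToolCall bool, outputTextLen int) string {
		if sawToolCall && outputTextLen == 0 {
			return "tool_calls"
		}
		return "stop"
	}

	func main() {
		fmt.Println(finishReasonFor(true, 0))  // tool_calls
		fmt.Println(finishReasonFor(true, 42)) // stop: streamed text takes precedence
		fmt.Println(finishReasonFor(false, 0)) // stop
	}

Since the same rule appears twice (once in the response.completed case and once in the post-stream flush), hoisting it into a small helper like this would keep the two copies from drifting apart.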