Просмотр исходного кода

Merge remote-tracking branch 'origin/alpha' into alpha

t0ng7u 6 месяцев назад
Родитель
Сommit
d34e4f1f28
4 измененных файлов с 19 добавлено и 34 удалено
  1. 4 11
      relay/channel/openai/helper.go
  2. 13 21
      relay/helper/common.go
  3. 1 1
      relay/helper/valid_request.go
  4. 1 1
      relay/image_handler.go

+ 4 - 11
relay/channel/openai/helper.go

@@ -2,9 +2,6 @@ package openai
 
 import (
 	"encoding/json"
-	"errors"
-	"github.com/samber/lo"
-	"net/http"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/logger"
@@ -15,6 +12,8 @@ import (
 	"one-api/types"
 	"strings"
 
+	"github.com/samber/lo"
+
 	"github.com/gin-gonic/gin"
 )
 
@@ -71,11 +70,7 @@ func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
 
 	// send gemini format response
 	c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	} else {
-		return errors.New("streaming error: flusher not found")
-	}
+	_ = helper.FlushWriter(c)
 	return nil
 }
 
@@ -253,9 +248,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
 
 		// 发送最终的 Gemini 响应
 		c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
-		if flusher, ok := c.Writer.(http.Flusher); ok {
-			flusher.Flush()
-		}
+		_ = helper.FlushWriter(c)
 	}
 }
 

+ 13 - 21
relay/helper/common.go

@@ -14,6 +14,14 @@ import (
 	"github.com/gorilla/websocket"
 )
 
+func FlushWriter(c *gin.Context) error {
+	if flusher, ok := c.Writer.(http.Flusher); ok {
+		flusher.Flush()
+		return nil
+	}
+	return errors.New("streaming error: flusher not found")
+}
+
 func SetEventStreamHeaders(c *gin.Context) {
 	// 检查是否已经设置过头部
 	if _, exists := c.Get("event_stream_headers_set"); exists {
@@ -38,49 +46,33 @@ func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {
 		c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
 		c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonData)})
 	}
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	} else {
-		return errors.New("streaming error: flusher not found")
-	}
+	_ = FlushWriter(c)
 	return nil
 }
 
 func ClaudeChunkData(c *gin.Context, resp dto.ClaudeResponse, data string) {
 	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
 	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s\n", data)})
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	}
+	_ = FlushWriter(c)
 }
 
 func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data string) {
 	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("event: %s\n", resp.Type)})
 	c.Render(-1, common.CustomEvent{Data: fmt.Sprintf("data: %s", data)})
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	}
+	_ = FlushWriter(c)
 }
 
 func StringData(c *gin.Context, str string) error {
 	//str = strings.TrimPrefix(str, "data: ")
 	//str = strings.TrimSuffix(str, "\r")
 	c.Render(-1, common.CustomEvent{Data: "data: " + str})
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	} else {
-		return errors.New("streaming error: flusher not found")
-	}
+	_ = FlushWriter(c)
 	return nil
 }
 
 func PingData(c *gin.Context) error {
 	c.Writer.Write([]byte(": PING\n\n"))
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-	} else {
-		return errors.New("streaming error: flusher not found")
-	}
+	_ = FlushWriter(c)
 	return nil
 }
 

+ 1 - 1
relay/helper/valid_request.go

@@ -134,7 +134,7 @@ func GetAndValidOpenAIImageRequest(c *gin.Context, relayMode int) (*dto.ImageReq
 	case relayconstant.RelayModeImagesEdits:
 		_, err := c.MultipartForm()
 		if err != nil {
-			return nil, err
+			return nil, fmt.Errorf("failed to parse image edit form request: %w", err)
 		}
 		formData := c.Request.PostForm
 		imageRequest.Prompt = formData.Get("prompt")

+ 1 - 1
relay/image_handler.go

@@ -51,7 +51,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 	} else {
 		convertedRequest, err := adaptor.ConvertImageRequest(c, info, *imageRequest)
 		if err != nil {
-			return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
 		}
 		if info.RelayMode == relayconstant.RelayModeImagesEdits {
 			requestBody = convertedRequest.(io.Reader)