Browse Source

fix(helper): improve error handling in FlushWriter and related functions

CaIon 2 weeks ago
parent
commit
b58fa3debc
2 changed files with 42 additions and 15 deletions
  1. 2 2
      relay/channel/gemini/relay-gemini-native.go
  2. 40 13
      relay/helper/common.go

+ 2 - 2
relay/channel/gemini/relay-gemini-native.go

@@ -94,10 +94,10 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
 	helper.SetEventStreamHeaders(c)
 	helper.SetEventStreamHeaders(c)
 
 
 	return geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
 	return geminiStreamHandler(c, info, resp, func(data string, geminiResponse *dto.GeminiChatResponse) bool {
-		// 直接发送 GeminiChatResponse 响应
 		err := helper.StringData(c, data)
 		err := helper.StringData(c, data)
 		if err != nil {
 		if err != nil {
-			logger.LogError(c, err.Error())
+			logger.LogError(c, "failed to write stream data: "+err.Error())
+			return false
 		}
 		}
 		info.SendResponseCount++
 		info.SendResponseCount++
 		return true
 		return true

+ 40 - 13
relay/helper/common.go

@@ -14,15 +14,28 @@ import (
 	"github.com/gorilla/websocket"
 	"github.com/gorilla/websocket"
 )
 )
 
 
-func FlushWriter(c *gin.Context) error {
-	if c.Writer == nil {
+func FlushWriter(c *gin.Context) (err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = fmt.Errorf("flush panic recovered: %v", r)
+		}
+	}()
+
+	if c == nil || c.Writer == nil {
 		return nil
 		return nil
 	}
 	}
-	if flusher, ok := c.Writer.(http.Flusher); ok {
-		flusher.Flush()
-		return nil
+
+	if c.Request != nil && c.Request.Context().Err() != nil {
+		return fmt.Errorf("request context done: %w", c.Request.Context().Err())
 	}
 	}
-	return errors.New("streaming error: flusher not found")
+
+	flusher, ok := c.Writer.(http.Flusher)
+	if !ok {
+		return errors.New("streaming error: flusher not found")
+	}
+
+	flusher.Flush()
+	return nil
 }
 }
 
 
 func SetEventStreamHeaders(c *gin.Context) {
 func SetEventStreamHeaders(c *gin.Context) {
@@ -66,17 +79,31 @@ func ResponseChunkData(c *gin.Context, resp dto.ResponsesStreamResponse, data st
 }
 }
 
 
 func StringData(c *gin.Context, str string) error {
 func StringData(c *gin.Context, str string) error {
-	//str = strings.TrimPrefix(str, "data: ")
-	//str = strings.TrimSuffix(str, "\r")
+	if c == nil || c.Writer == nil {
+		return errors.New("context or writer is nil")
+	}
+
+	if c.Request != nil && c.Request.Context().Err() != nil {
+		return fmt.Errorf("request context done: %w", c.Request.Context().Err())
+	}
+
 	c.Render(-1, common.CustomEvent{Data: "data: " + str})
 	c.Render(-1, common.CustomEvent{Data: "data: " + str})
-	_ = FlushWriter(c)
-	return nil
+	return FlushWriter(c)
 }
 }
 
 
 func PingData(c *gin.Context) error {
 func PingData(c *gin.Context) error {
-	c.Writer.Write([]byte(": PING\n\n"))
-	_ = FlushWriter(c)
-	return nil
+	if c == nil || c.Writer == nil {
+		return errors.New("context or writer is nil")
+	}
+
+	if c.Request != nil && c.Request.Context().Err() != nil {
+		return fmt.Errorf("request context done: %w", c.Request.Context().Err())
+	}
+
+	if _, err := c.Writer.Write([]byte(": PING\n\n")); err != nil {
+		return fmt.Errorf("write ping data failed: %w", err)
+	}
+	return FlushWriter(c)
 }
 }
 
 
 func ObjectData(c *gin.Context, object interface{}) error {
 func ObjectData(c *gin.Context, object interface{}) error {