Преглед изворни кода

feat: send SSE ping before get response

creamlike1024 пре 9 месеци
родитељ
комит
ec61534256
4 измењених фајлова са 122 додато и 87 уклоњено
  1. 64 2
      relay/channel/api_request.go
  2. 13 5
      relay/helper/common.go
  3. 44 71
      relay/helper/stream_scanner.go
  4. 1 9
      relay/relay-text.go

+ 64 - 2
relay/channel/api_request.go

@@ -1,16 +1,23 @@
 package channel
 
 import (
+	"context"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
-	"github.com/gorilla/websocket"
 	"io"
 	"net/http"
 	common2 "one-api/common"
 	"one-api/relay/common"
 	"one-api/relay/constant"
+	"one-api/relay/helper"
 	"one-api/service"
+	"one-api/setting/operation_setting"
+	"sync"
+	"time"
+
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 )
 
 func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
@@ -105,7 +112,62 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	} else {
 		client = service.GetHttpClient()
 	}
+	// 流式请求 ping 保活
+	var stopPinger func()
+	generalSettings := operation_setting.GetGeneralSetting()
+	pingEnabled := generalSettings.PingIntervalEnabled
+	var pingerWg sync.WaitGroup
+	if info.IsStream {
+		helper.SetEventStreamHeaders(c)
+		pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+		var pingerCtx context.Context
+		pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
+
+		if pingEnabled {
+			pingerWg.Add(1)
+			gopool.Go(func() {
+				defer pingerWg.Done()
+				if pingInterval <= 0 {
+					pingInterval = helper.DefaultPingInterval
+				}
+
+				ticker := time.NewTicker(pingInterval)
+				defer ticker.Stop()
+				var pingMutex sync.Mutex
+				if common2.DebugEnabled {
+					println("SSE ping goroutine started")
+				}
+
+				for {
+					select {
+					case <-ticker.C:
+						pingMutex.Lock()
+						err2 := helper.PingData(c)
+						pingMutex.Unlock()
+						if err2 != nil {
+							common2.LogError(c, "SSE ping error: "+err.Error())
+							return
+						}
+						if common2.DebugEnabled {
+							println("SSE ping data sent.")
+						}
+					case <-pingerCtx.Done():
+						if common2.DebugEnabled {
+							println("SSE ping goroutine stopped.")
+						}
+						return
+					}
+				}
+			})
+		}
+	}
+
 	resp, err := client.Do(req)
+	// request结束后停止ping
+	if info.IsStream && pingEnabled {
+		stopPinger()
+		pingerWg.Wait()
+	}
 	if err != nil {
 		return nil, err
 	}

+ 13 - 5
relay/helper/common.go

@@ -12,11 +12,19 @@ import (
 )
 
 func SetEventStreamHeaders(c *gin.Context) {
-	c.Writer.Header().Set("Content-Type", "text/event-stream")
-	c.Writer.Header().Set("Cache-Control", "no-cache")
-	c.Writer.Header().Set("Connection", "keep-alive")
-	c.Writer.Header().Set("Transfer-Encoding", "chunked")
-	c.Writer.Header().Set("X-Accel-Buffering", "no")
+    // 检查是否已经设置过头部
+    if _, exists := c.Get("event_stream_headers_set"); exists {
+        return
+    }
+    
+    c.Writer.Header().Set("Content-Type", "text/event-stream")
+    c.Writer.Header().Set("Cache-Control", "no-cache")
+    c.Writer.Header().Set("Connection", "keep-alive")
+    c.Writer.Header().Set("Transfer-Encoding", "chunked")
+    c.Writer.Header().Set("X-Accel-Buffering", "no")
+    
+    // 设置标志,表示头部已经设置过
+    c.Set("event_stream_headers_set", true)
 }
 
 func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {

+ 44 - 71
relay/helper/stream_scanner.go

@@ -3,7 +3,6 @@ package helper
 import (
 	"bufio"
 	"context"
-	"github.com/bytedance/gopkg/util/gopool"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -14,6 +13,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/bytedance/gopkg/util/gopool"
+
 	"github.com/gin-gonic/gin"
 )
 
@@ -23,76 +24,6 @@ const (
 	DefaultPingInterval      = 10 * time.Second
 )
 
-type DoRequestFunc func(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
-
-// Optional SSE Ping keep-alive mechanism
-//
-// Used to solve the problem of the connection with the client timing out due to no data being sent when the upstream
-// channel response time is long (e.g., thinking model).
-// When enabled, it will send ping data packets to the client via SSE at the specified interval to maintain the connection.
-func DoStreamRequestWithPinger(doRequest DoRequestFunc, c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
-	SetEventStreamHeaders(c)
-
-	generalSettings := operation_setting.GetGeneralSetting()
-	pingEnabled := generalSettings.PingIntervalEnabled
-	pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
-
-	pingerCtx, stopPinger := context.WithCancel(c.Request.Context())
-	var pingerWg sync.WaitGroup
-	var doRequestErr error
-	var resp any
-
-	if pingEnabled {
-		pingerWg.Add(1)
-
-		gopool.Go(func() {
-			defer pingerWg.Done()
-
-			if pingInterval <= 0 {
-				pingInterval = DefaultPingInterval
-			}
-
-			ticker := time.NewTicker(pingInterval)
-			defer ticker.Stop()
-			var pingMutex sync.Mutex
-
-			if common.DebugEnabled {
-				println("SSE ping goroutine started.")
-			}
-
-			for {
-				select {
-				case <-ticker.C:
-					pingMutex.Lock()
-					err := PingData(c)
-					pingMutex.Unlock()
-					if err != nil {
-						common.LogError(c, "SSE ping error: "+err.Error())
-						return
-					}
-					if common.DebugEnabled {
-						println("SSE ping data sent.")
-					}
-				case <-pingerCtx.Done():
-					if common.DebugEnabled {
-						println("SSE ping goroutine stopped.")
-					}
-					return
-				}
-			}
-		})
-	}
-
-	resp, doRequestErr = doRequest(c, info, requestBody)
-
-	stopPinger()
-	if pingEnabled {
-		pingerWg.Wait()
-	}
-
-	return resp, doRequestErr
-}
-
 func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) {
 
 	if resp == nil || dataHandler == nil {
@@ -111,11 +42,26 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 		stopChan   = make(chan bool, 2)
 		scanner    = bufio.NewScanner(resp.Body)
 		ticker     = time.NewTicker(streamingTimeout)
+		pingTicker *time.Ticker
 		writeMutex sync.Mutex // Mutex to protect concurrent writes
 	)
 
+	generalSettings := operation_setting.GetGeneralSetting()
+	pingEnabled := generalSettings.PingIntervalEnabled
+	pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+	if pingInterval <= 0 {
+		pingInterval = DefaultPingInterval
+	}
+
+	if pingEnabled {
+		pingTicker = time.NewTicker(pingInterval)
+	}
+
 	defer func() {
 		ticker.Stop()
+		if pingTicker != nil {
+			pingTicker.Stop()
+		}
 		close(stopChan)
 	}()
 	scanner.Buffer(make([]byte, InitialScannerBufferSize), MaxScannerBufferSize)
@@ -127,6 +73,33 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
 
 	ctx = context.WithValue(ctx, "stop_chan", stopChan)
 
+	// Handle ping data sending
+	if pingEnabled && pingTicker != nil {
+		gopool.Go(func() {
+			for {
+				select {
+				case <-pingTicker.C:
+					writeMutex.Lock() // Lock before writing
+					err := PingData(c)
+					writeMutex.Unlock() // Unlock after writing
+					if err != nil {
+						common.LogError(c, "ping data error: "+err.Error())
+						common.SafeSendBool(stopChan, true)
+						return
+					}
+					if common.DebugEnabled {
+						println("ping data sent")
+					}
+				case <-ctx.Done():
+					if common.DebugEnabled {
+						println("ping data goroutine stopped")
+					}
+					return
+				}
+			}
+		})
+	}
+
 	common.RelayCtxGo(ctx, func() {
 		for scanner.Scan() {
 			ticker.Reset(streamingTimeout)

+ 1 - 9
relay/relay-text.go

@@ -193,15 +193,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	}
 
 	var httpResp *http.Response
-	var resp any
-
-	if relayInfo.IsStream {
-		// Streaming requests can use SSE ping to keep alive and avoid connection timeout
-		// The judgment of whether ping is enabled will be made within the function
-		resp, err = helper.DoStreamRequestWithPinger(adaptor.DoRequest, c, relayInfo, requestBody)
-	} else {
-		resp, err = adaptor.DoRequest(c, relayInfo, requestBody)
-	}
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)