2
0
Эх сурвалжийг харах

Merge branch '9Ninety-fix/sse_ping'

creamlike1024 7 сар өмнө
parent
commit
4a59b3ccd6

+ 64 - 2
relay/channel/api_request.go

@@ -1,16 +1,23 @@
 package channel
 
 import (
+	"context"
 	"errors"
 	"fmt"
-	"github.com/gin-gonic/gin"
-	"github.com/gorilla/websocket"
 	"io"
 	"net/http"
 	common2 "one-api/common"
 	"one-api/relay/common"
 	"one-api/relay/constant"
+	"one-api/relay/helper"
 	"one-api/service"
+	"one-api/setting/operation_setting"
+	"sync"
+	"time"
+
+	"github.com/bytedance/gopkg/util/gopool"
+	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 )
 
 func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
@@ -105,7 +112,62 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
 	} else {
 		client = service.GetHttpClient()
 	}
+	// 流式请求 ping 保活
+	var stopPinger func()
+	generalSettings := operation_setting.GetGeneralSetting()
+	pingEnabled := generalSettings.PingIntervalEnabled
+	var pingerWg sync.WaitGroup
+	if info.IsStream {
+		helper.SetEventStreamHeaders(c)
+		pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
+		var pingerCtx context.Context
+		pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
+
+		if pingEnabled {
+			pingerWg.Add(1)
+			gopool.Go(func() {
+				defer pingerWg.Done()
+				if pingInterval <= 0 {
+					pingInterval = helper.DefaultPingInterval
+				}
+
+				ticker := time.NewTicker(pingInterval)
+				defer ticker.Stop()
+				var pingMutex sync.Mutex
+				if common2.DebugEnabled {
+					println("SSE ping goroutine started")
+				}
+
+				for {
+					select {
+					case <-ticker.C:
+						pingMutex.Lock()
+						err2 := helper.PingData(c)
+						pingMutex.Unlock()
+						if err2 != nil {
+							common2.LogError(c, "SSE ping error: "+err.Error())
+							return
+						}
+						if common2.DebugEnabled {
+							println("SSE ping data sent.")
+						}
+					case <-pingerCtx.Done():
+						if common2.DebugEnabled {
+							println("SSE ping goroutine stopped.")
+						}
+						return
+					}
+				}
+			})
+		}
+	}
+
 	resp, err := client.Do(req)
+	// request结束后停止ping
+	if info.IsStream && pingEnabled {
+		stopPinger()
+		pingerWg.Wait()
+	}
 	if err != nil {
 		return nil, err
 	}

+ 13 - 5
relay/helper/common.go

@@ -12,11 +12,19 @@ import (
 )
 
 func SetEventStreamHeaders(c *gin.Context) {
-	c.Writer.Header().Set("Content-Type", "text/event-stream")
-	c.Writer.Header().Set("Cache-Control", "no-cache")
-	c.Writer.Header().Set("Connection", "keep-alive")
-	c.Writer.Header().Set("Transfer-Encoding", "chunked")
-	c.Writer.Header().Set("X-Accel-Buffering", "no")
+    // 检查是否已经设置过头部
+    if _, exists := c.Get("event_stream_headers_set"); exists {
+        return
+    }
+    
+    c.Writer.Header().Set("Content-Type", "text/event-stream")
+    c.Writer.Header().Set("Cache-Control", "no-cache")
+    c.Writer.Header().Set("Connection", "keep-alive")
+    c.Writer.Header().Set("Transfer-Encoding", "chunked")
+    c.Writer.Header().Set("X-Accel-Buffering", "no")
+    
+    // 设置标志,表示头部已经设置过
+    c.Set("event_stream_headers_set", true)
 }
 
 func ClaudeData(c *gin.Context, resp dto.ClaudeResponse) error {

+ 2 - 1
relay/helper/stream_scanner.go

@@ -3,7 +3,6 @@ package helper
 import (
 	"bufio"
 	"context"
-	"github.com/bytedance/gopkg/util/gopool"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -14,6 +13,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/bytedance/gopkg/util/gopool"
+
 	"github.com/gin-gonic/gin"
 )
 

+ 1 - 0
relay/relay-text.go

@@ -194,6 +194,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 
 	var httpResp *http.Response
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}