|
|
@@ -1,16 +1,23 @@
|
|
|
package channel
|
|
|
|
|
|
import (
|
|
|
+ "context"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
- "github.com/gin-gonic/gin"
|
|
|
- "github.com/gorilla/websocket"
|
|
|
"io"
|
|
|
"net/http"
|
|
|
common2 "one-api/common"
|
|
|
"one-api/relay/common"
|
|
|
"one-api/relay/constant"
|
|
|
+ "one-api/relay/helper"
|
|
|
"one-api/service"
|
|
|
+ "one-api/setting/operation_setting"
|
|
|
+ "sync"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/bytedance/gopkg/util/gopool"
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
+ "github.com/gorilla/websocket"
|
|
|
)
|
|
|
|
|
|
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
|
|
|
@@ -105,7 +112,62 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
|
|
} else {
|
|
|
client = service.GetHttpClient()
|
|
|
}
|
|
|
+ // 流式请求 ping 保活
|
|
|
+ var stopPinger func()
|
|
|
+ generalSettings := operation_setting.GetGeneralSetting()
|
|
|
+ pingEnabled := generalSettings.PingIntervalEnabled
|
|
|
+ var pingerWg sync.WaitGroup
|
|
|
+ if info.IsStream {
|
|
|
+ helper.SetEventStreamHeaders(c)
|
|
|
+ pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
|
|
|
+ var pingerCtx context.Context
|
|
|
+ pingerCtx, stopPinger = context.WithCancel(c.Request.Context())
|
|
|
+
|
|
|
+ if pingEnabled {
|
|
|
+ pingerWg.Add(1)
|
|
|
+ gopool.Go(func() {
|
|
|
+ defer pingerWg.Done()
|
|
|
+ if pingInterval <= 0 {
|
|
|
+ pingInterval = helper.DefaultPingInterval
|
|
|
+ }
|
|
|
+
|
|
|
+ ticker := time.NewTicker(pingInterval)
|
|
|
+ defer ticker.Stop()
|
|
|
+ var pingMutex sync.Mutex
|
|
|
+ if common2.DebugEnabled {
|
|
|
+ println("SSE ping goroutine started")
|
|
|
+ }
|
|
|
+
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-ticker.C:
|
|
|
+ pingMutex.Lock()
|
|
|
+ err2 := helper.PingData(c)
|
|
|
+ pingMutex.Unlock()
|
|
|
+ if err2 != nil {
|
|
|
+ common2.LogError(c, "SSE ping error: "+err.Error())
|
|
|
+ return
|
|
|
+ }
|
|
|
+ if common2.DebugEnabled {
|
|
|
+ println("SSE ping data sent.")
|
|
|
+ }
|
|
|
+ case <-pingerCtx.Done():
|
|
|
+ if common2.DebugEnabled {
|
|
|
+ println("SSE ping goroutine stopped.")
|
|
|
+ }
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
resp, err := client.Do(req)
|
|
|
+ // request结束后停止ping
|
|
|
+ if info.IsStream && pingEnabled {
|
|
|
+ stopPinger()
|
|
|
+ pingerWg.Wait()
|
|
|
+ }
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|