فهرست منبع

fix(realtime): 修复ws 握手失败、计费问题

(cherry picked from commit 618dffc43fd5a5f4065944db87761f9ee18e44d3)
Xyfacai 1 سال پیش
والد
کامیت
be64408a25
4فایلهای تغییر یافته به همراه28 افزوده شده و 10 حذف شده
  1. 1 0
      controller/relay.go
  2. 21 6
      relay/channel/openai/adaptor.go
  3. 4 0
      relay/channel/openai/relay-openai.go
  4. 2 4
      service/quota.go

+ 1 - 0
controller/relay.go

@@ -145,6 +145,7 @@ func Relay(c *gin.Context) {
 }
 
 var upgrader = websocket.Upgrader{
+	Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol
 	CheckOrigin: func(r *http.Request) bool {
 		return true // 允许跨域
 	},

+ 21 - 6
relay/channel/openai/adaptor.go

@@ -63,18 +63,33 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	}
 }
 
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
-	channel.SetupApiRequestHeader(info, c, req)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
+	channel.SetupApiRequestHeader(info, c, header)
 	if info.ChannelType == common.ChannelTypeAzure {
-		req.Set("api-key", info.ApiKey)
+		header.Set("api-key", info.ApiKey)
 		return nil
 	}
 	if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
-		req.Set("OpenAI-Organization", info.Organization)
+		header.Set("OpenAI-Organization", info.Organization)
 	}
-	req.Set("Authorization", "Bearer "+info.ApiKey)
 	if info.RelayMode == constant.RelayModeRealtime {
-		req.Set("openai-beta", "realtime=v1")
+		swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
+		if swp != "" {
+			items := []string{
+				"realtime",
+				"openai-insecure-api-key." + info.ApiKey,
+				"openai-beta.realtime-v1",
+			}
+			header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
+			//req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
+			//req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
+			//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
+		} else {
+			header.Set("openai-beta", "realtime=v1")
+			header.Set("Authorization", "Bearer "+info.ApiKey)
+		}
+	} else {
+		header.Set("Authorization", "Bearer "+info.ApiKey)
 	}
 	//if info.ChannelType == common.ChannelTypeOpenRouter {
 	//	req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")

+ 4 - 0
relay/channel/openai/relay-openai.go

@@ -483,7 +483,10 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 							errChan <- fmt.Errorf("error consume usage: %v", err)
 							return
 						}
+						// 本次计费完成,清除
 						usage = &dto.RealtimeUsage{}
+
+						localUsage = &dto.RealtimeUsage{}
 					} else {
 						textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
 						if err != nil {
@@ -501,6 +504,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 							errChan <- fmt.Errorf("error consume usage: %v", err)
 							return
 						}
+						// 本次计费完成,清除
 						localUsage = &dto.RealtimeUsage{}
 						// print now usage
 					}

+ 2 - 4
service/quota.go

@@ -78,10 +78,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 
 	quota := 0
 	if !usePrice {
-		quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
-		quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
-
-		quota = int(math.Round(float64(quota) * ratio))
+		quota = int(math.Round(float64(textInputTokens)*ratio + float64(textOutTokens)*ratio*completionRatio))
+		quota += int(math.Round(float64(audioInputTokens)*ratio*audioRatio + float64(audioOutTokens)*ratio*audioRatio*audioCompletionRatio))
 		if ratio != 0 && quota <= 0 {
 			quota = 1
 		}