浏览代码

feat: realtime

(cherry picked from commit d4966246e68dbdcdab45ec5c5141362834d74425)
[email protected] 1 年之前
父节点
当前提交
74f9006b40

+ 15 - 1
common/model-ratio.go

@@ -432,9 +432,23 @@ func GetAudioCompletionRatio(name string) float64 {
 	if strings.HasPrefix(name, "gpt-4o-realtime") {
 		return 10
 	}
-	return 10
+	return 2
 }
 
+//func GetAudioPricePerMinute(name string) float64 {
+//	if strings.HasPrefix(name, "gpt-4o-realtime") {
+//		return 0.06
+//	}
+//	return 0.06
+//}
+//
+//func GetAudioCompletionPricePerMinute(name string) float64 {
+//	if strings.HasPrefix(name, "gpt-4o-realtime") {
+//		return 0.24
+//	}
+//	return 0.24
+//}
+
 func GetCompletionRatioMap() map[string]float64 {
 	if CompletionRatio == nil {
 		CompletionRatio = defaultCompletionRatio

+ 11 - 1
dto/realtime.go

@@ -5,10 +5,18 @@ const (
 	RealtimeEventTypeSessionUpdate      = "session.update"
 	RealtimeEventTypeConversationCreate = "conversation.item.create"
 	RealtimeEventTypeResponseCreate     = "response.create"
+	RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append"
 )
 
 const (
-	RealtimeEventTypeResponseDone = "response.done"
+	RealtimeEventTypeResponseDone                   = "response.done"
+	RealtimeEventTypeSessionUpdated                 = "session.updated"
+	RealtimeEventTypeSessionCreated                 = "session.created"
+	RealtimeEventResponseAudioDelta                 = "response.audio.delta"
+	RealtimeEventResponseAudioTranscriptionDelta    = "response.audio_transcript.delta"
+	RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta"
+	RealtimeEventResponseFunctionCallArgumentsDone  = "response.function_call_arguments.done"
+	RealtimeEventConversationItemCreated            = "conversation.item.created"
 )
 
 type RealtimeEvent struct {
@@ -19,6 +27,8 @@ type RealtimeEvent struct {
 	Item     *RealtimeItem     `json:"item,omitempty"`
 	Error    *OpenAIError      `json:"error,omitempty"`
 	Response *RealtimeResponse `json:"response,omitempty"`
+	Delta    string            `json:"delta,omitempty"`
+	Audio    string            `json:"audio,omitempty"`
 }
 
 type RealtimeResponse struct {

+ 1 - 1
relay/channel/claude/relay-claude.go

@@ -509,7 +509,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
 		}, nil
 	}
 	fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
-	completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName)
+	completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
 	}

+ 1 - 1
relay/channel/cloudflare/relay_cloudflare.go

@@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
 
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName)
+	usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 
 	return nil, usage

+ 1 - 1
relay/channel/dify/relay-dify.go

@@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 	}
 	if usage.TotalTokens == 0 {
 		usage.PromptTokens = info.PromptTokens
-		usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText)
+		usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
 	return nil, usage

+ 3 - 1
relay/channel/openai/adaptor.go

@@ -47,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		model_ := info.UpstreamModelName
 		model_ = strings.Replace(model_, ".", "", -1)
 		// https://github.com/songquanpeng/one-api/issues/67
-
 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
+		if info.RelayMode == constant.RelayModeRealtime {
+			requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
+		}
 		return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
 	case common.ChannelTypeMiniMax:
 		return minimax.GetRequestURL(info)

+ 61 - 2
relay/channel/openai/relay-openai.go

@@ -9,6 +9,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gorilla/websocket"
 	"io"
+	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
@@ -232,7 +233,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		for _, choice := range simpleResponse.Choices {
-			ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
+			ctkm, _ := service.CountTextToken(string(choice.Message.Content), model)
 			completionTokens += ctkm
 		}
 		simpleResponse.Usage = dto.Usage{
@@ -325,7 +326,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 
 	usage := &dto.Usage{}
 	usage.PromptTokens = info.PromptTokens
-	usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName)
+	usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName)
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return nil, usage
 }
@@ -387,6 +388,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 	errChan := make(chan error, 2)
 
 	usage := &dto.RealtimeUsage{}
+	localUsage := &dto.RealtimeUsage{}
 
 	go func() {
 		for {
@@ -403,6 +405,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 					return
 				}
 
+				realtimeEvent := &dto.RealtimeEvent{}
+				err = json.Unmarshal(message, realtimeEvent)
+				if err != nil {
+					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
+					return
+				}
+
+				if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
+					if realtimeEvent.Session != nil {
+						if realtimeEvent.Session.Tools != nil {
+							info.RealtimeTools = realtimeEvent.Session.Tools
+						}
+					}
+				}
+
+				textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+				if err != nil {
+					errChan <- fmt.Errorf("error counting text token: %v", err)
+					return
+				}
+				log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
+				localUsage.TotalTokens += textToken + audioToken
+				localUsage.InputTokens += textToken
+				localUsage.InputTokenDetails.TextTokens += textToken
+				localUsage.InputTokenDetails.AudioTokens += audioToken
+
 				err = service.WssString(c, targetConn, string(message))
 				if err != nil {
 					errChan <- fmt.Errorf("error writing to target: %v", err)
@@ -451,6 +479,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 						usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
 						usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
 					}
+				} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
+					realtimeSession := realtimeEvent.Session
+					if realtimeSession != nil {
+						// update audio format
+						info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
+						info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
+					}
+				} else {
+					textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
+					if err != nil {
+						errChan <- fmt.Errorf("error counting text token: %v", err)
+						return
+					}
+					log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)
+					localUsage.TotalTokens += textToken + audioToken
+
+					if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
+						info.IsFirstRequest = false
+						localUsage.InputTokens += textToken + audioToken
+						localUsage.InputTokenDetails.TextTokens += textToken
+						localUsage.InputTokenDetails.AudioTokens += audioToken
+					} else {
+						localUsage.OutputTokens += textToken + audioToken
+						localUsage.OutputTokenDetails.TextTokens += textToken
+						localUsage.OutputTokenDetails.AudioTokens += audioToken
+					}
 				}
 
 				err = service.WssString(c, clientConn, string(message))
@@ -475,5 +529,10 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 	case <-c.Done():
 	}
 
+	// check usage total tokens, if 0, use local usage
+
+	if usage.TotalTokens == 0 {
+		usage = localUsage
+	}
 	return nil, usage
 }

+ 1 - 1
relay/channel/palm/relay-palm.go

@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 		}, nil
 	}
 	fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-	completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
+	completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
 	usage := dto.Usage{
 		PromptTokens:     promptTokens,
 		CompletionTokens: completionTokens,

+ 8 - 0
relay/common/relay_info.go

@@ -4,6 +4,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/gorilla/websocket"
 	"one-api/common"
+	"one-api/dto"
 	"one-api/relay/constant"
 	"strings"
 	"time"
@@ -35,11 +36,18 @@ type RelayInfo struct {
 	ShouldIncludeUsage   bool
 	ClientWs             *websocket.Conn
 	TargetWs             *websocket.Conn
+	InputAudioFormat     string
+	OutputAudioFormat    string
+	RealtimeTools        []dto.RealTimeTool
+	IsFirstRequest       bool
 }
 
 func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo {
 	info := GenRelayInfo(c)
 	info.ClientWs = ws
+	info.InputAudioFormat = "pcm16"
+	info.OutputAudioFormat = "pcm16"
+	info.IsFirstRequest = true
 	return info
 }
 

+ 1 - 1
relay/relay-audio.go

@@ -58,7 +58,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	promptTokens := 0
 	preConsumedTokens := common.PreConsumedQuota
 	if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
-		promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
+		promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
 		}

+ 11 - 11
relay/websocket.go

@@ -150,7 +150,7 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	quota := 0
 	if !usePrice {
 		quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio))
-		quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*completionRatio*audioCompletionRatio))
+		quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio))
 
 		quota = int(math.Round(float64(quota) * ratio))
 		if ratio != 0 && quota <= 0 {
@@ -215,16 +215,16 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
 	//}
 }
 
-func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
-	var promptTokens int
-	var err error
-	switch info.RelayMode {
-	default:
-		promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
-	}
-	info.PromptTokens = promptTokens
-	return promptTokens, err
-}
+//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) {
+//	var promptTokens int
+//	var err error
+//	switch info.RelayMode {
+//	default:
+//		promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName)
+//	}
+//	info.PromptTokens = promptTokens
+//	return promptTokens, err
+//}
 
 //func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
 //	var err error

+ 31 - 0
service/audio.go

@@ -0,0 +1,31 @@
+package service
+
+import (
+	"encoding/base64"
+	"fmt"
+)
+
+func parseAudio(audioBase64 string, format string) (duration float64, err error) {
+	audioData, err := base64.StdEncoding.DecodeString(audioBase64)
+	if err != nil {
+		return 0, fmt.Errorf("base64 decode error: %v", err)
+	}
+
+	var samplesCount int
+	var sampleRate int
+
+	switch format {
+	case "pcm16":
+		samplesCount = len(audioData) / 2 // 16位 = 2字节每样本
+		sampleRate = 24000                // 24kHz
+	case "g711_ulaw", "g711_alaw":
+		samplesCount = len(audioData) // 8位 = 1字节每样本
+		sampleRate = 8000             // 8kHz
+	default:
+		samplesCount = len(audioData) // 8位 = 1字节每样本
+		sampleRate = 8000             // 8kHz
+	}
+
+	duration = float64(samplesCount) / float64(sampleRate)
+	return duration, nil
+}

+ 2 - 2
service/relay.go

@@ -48,7 +48,7 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error {
 		common.LogError(c, "websocket connection is nil")
 		return errors.New("websocket connection is nil")
 	}
-	common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
+	//common.LogInfo(c, fmt.Sprintf("sending message: %s", str))
 	return ws.WriteMessage(1, []byte(str))
 }
 
@@ -61,7 +61,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
 		common.LogError(c, "websocket connection is nil")
 		return errors.New("websocket connection is nil")
 	}
-	common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
+	//common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData))
 	return ws.WriteMessage(1, jsonData)
 }
 

+ 78 - 37
service/token_counter.go

@@ -11,6 +11,7 @@ import (
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
+	relaycommon "one-api/relay/common"
 	"strings"
 	"unicode/utf8"
 )
@@ -191,43 +192,55 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int,
 	return tkm, nil
 }
 
-func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) {
-	tkm := 0
-	ratio := 1
-	if request.Session != nil {
-		msgTokens, err := CountTokenText(request.Session.Instructions, model)
-		if err != nil {
-			return 0, err
-		}
-		ratio = len(request.Session.Modalities)
-		tkm += msgTokens
-		if request.Session.Tools != nil {
-			toolsData, _ := json.Marshal(request.Session.Tools)
-			var openaiTools []dto.OpenAITools
-			err := json.Unmarshal(toolsData, &openaiTools)
+func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) {
+	audioToken := 0
+	textToken := 0
+	switch request.Type {
+	case dto.RealtimeEventTypeSessionUpdate:
+		if request.Session != nil {
+			msgTokens, err := CountTextToken(request.Session.Instructions, model)
 			if err != nil {
-				return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
+				return 0, 0, err
 			}
-			countStr := ""
-			for _, tool := range openaiTools {
-				countStr = tool.Function.Name
-				if tool.Function.Description != "" {
-					countStr += tool.Function.Description
-				}
-				if tool.Function.Parameters != nil {
-					countStr += fmt.Sprintf("%v", tool.Function.Parameters)
+			textToken += msgTokens
+		}
+	case dto.RealtimeEventResponseAudioDelta:
+		// count audio token
+		atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat)
+		if err != nil {
+			return 0, 0, fmt.Errorf("error counting audio token: %v", err)
+		}
+		audioToken += atk
+	case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
+		// count text token
+		tkm, err := CountTextToken(request.Delta, model)
+		if err != nil {
+			return 0, 0, fmt.Errorf("error counting text token: %v", err)
+		}
+		textToken += tkm
+	case dto.RealtimeEventInputAudioBufferAppend:
+		// count audio token
+		atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat)
+		if err != nil {
+			return 0, 0, fmt.Errorf("error counting audio token: %v", err)
+		}
+		audioToken += atk
+	case dto.RealtimeEventTypeResponseDone:
+		// count tools token
+		if !info.IsFirstRequest {
+			if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
+				for _, tool := range info.RealtimeTools {
+					toolTokens, err := CountTokenInput(tool, model)
+					if err != nil {
+						return 0, 0, err
+					}
+					textToken += 8
+					textToken += toolTokens
 				}
 			}
-			toolTokens, err := CountTokenInput(countStr, model)
-			if err != nil {
-				return 0, err
-			}
-			tkm += 8
-			tkm += toolTokens
 		}
 	}
-	tkm *= ratio
-	return tkm, nil
+	return textToken, audioToken, nil
 }
 
 func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) {
@@ -287,13 +300,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int,
 func CountTokenInput(input any, model string) (int, error) {
 	switch v := input.(type) {
 	case string:
-		return CountTokenText(v, model)
+		return CountTextToken(v, model)
 	case []string:
 		text := ""
 		for _, s := range v {
 			text += s
 		}
-		return CountTokenText(text, model)
+		return CountTextToken(text, model)
 	}
 	return CountTokenInput(fmt.Sprintf("%v", input), model)
 }
@@ -315,16 +328,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
 	return tokens
 }
 
-func CountAudioToken(text string, model string) (int, error) {
+func CountTTSToken(text string, model string) (int, error) {
 	if strings.HasPrefix(model, "tts") {
 		return utf8.RuneCountInString(text), nil
 	} else {
-		return CountTokenText(text, model)
+		return CountTextToken(text, model)
+	}
+}
+
+func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) {
+	if audioBase64 == "" {
+		return 0, nil
+	}
+	duration, err := parseAudio(audioBase64, audioFormat)
+	if err != nil {
+		return 0, err
 	}
+	return int(duration / 60 * 100 / 0.06), nil
 }
 
-// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
-func CountTokenText(text string, model string) (int, error) {
+func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) {
+	if audioBase64 == "" {
+		return 0, nil
+	}
+	duration, err := parseAudio(audioBase64, audioFormat)
+	if err != nil {
+		return 0, err
+	}
+	return int(duration / 60 * 200 / 0.24), nil
+}
+
+//func CountAudioToken(sec float64, audioType string) {
+//	if audioType == "input" {
+//
+//	}
+//}
+
+// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
+func CountTextToken(text string, model string) (int, error) {
 	var err error
 	tokenEncoder := getTokenEncoder(model)
 	return getTokenNum(tokenEncoder, text), err

+ 1 - 1
service/usage_helpr.go

@@ -19,7 +19,7 @@ import (
 func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
 	usage := &dto.Usage{}
 	usage.PromptTokens = promptTokens
-	ctkm, err := CountTokenText(responseText, modeName)
+	ctkm, err := CountTextToken(responseText, modeName)
 	usage.CompletionTokens = ctkm
 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	return usage, err