relay-openai.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704
  1. package openai
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "strings"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/constant"
  10. "github.com/QuantumNous/new-api/dto"
  11. "github.com/QuantumNous/new-api/logger"
  12. "github.com/QuantumNous/new-api/relay/channel/openrouter"
  13. relaycommon "github.com/QuantumNous/new-api/relay/common"
  14. "github.com/QuantumNous/new-api/relay/helper"
  15. "github.com/QuantumNous/new-api/service"
  16. "github.com/QuantumNous/new-api/types"
  17. "github.com/bytedance/gopkg/util/gopool"
  18. "github.com/gin-gonic/gin"
  19. "github.com/gorilla/websocket"
  20. )
// sendStreamData forwards a single SSE data chunk to the client.
//
// Behavior depends on two channel settings:
//   - forceFormat: re-marshal the chunk through dto.ChatCompletionsStreamResponse
//     instead of echoing the raw upstream string.
//   - thinkToContent: rewrite reasoning ("thinking") deltas into regular
//     content wrapped in <think> ... </think> tags.
//
// It mutates info.ThinkingContentInfo to remember, across the chunks of one
// stream, whether the opening <think> and closing </think> tags have already
// been emitted.
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
	if data == "" {
		return nil
	}
	// Fast path: no rewriting requested, pass the raw payload through.
	if !forceFormat && !thinkToContent {
		return helper.StringData(c, data)
	}
	var lastStreamResponse dto.ChatCompletionsStreamResponse
	if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
		return err
	}
	// forceFormat only: re-emit the parsed object without thinking conversion.
	if !thinkToContent {
		return helper.ObjectData(c, lastStreamResponse)
	}
	// Scan all choices to learn what kinds of delta this chunk carries.
	hasThinkingContent := false
	hasContent := false
	var thinkingContent strings.Builder
	for _, choice := range lastStreamResponse.Choices {
		if len(choice.Delta.GetReasoningContent()) > 0 {
			hasThinkingContent = true
			thinkingContent.WriteString(choice.Delta.GetReasoningContent())
		}
		if len(choice.Delta.GetContentString()) > 0 {
			hasContent = true
		}
	}
	// Handle think to content conversion: the very first thinking delta opens
	// the <think> block.
	if info.ThinkingContentInfo.IsFirstThinkingContent {
		if hasThinkingContent {
			response := lastStreamResponse.Copy()
			for i := range response.Choices {
				// send `think` tag with thinking content
				response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
				response.Choices[i].Delta.ReasoningContent = nil
				response.Choices[i].Delta.Reasoning = nil
			}
			info.ThinkingContentInfo.IsFirstThinkingContent = false
			info.ThinkingContentInfo.HasSentThinkingContent = true
			return helper.ObjectData(c, response)
		}
	}
	// Chunks without choices (e.g. a usage-only final chunk) go out unchanged.
	if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
		return helper.ObjectData(c, lastStreamResponse)
	}
	// Process each choice
	for i, choice := range lastStreamResponse.Choices {
		// Handle transition from thinking to content
		// only send `</think>` tag when previous thinking content has been sent
		if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
			response := lastStreamResponse.Copy()
			for j := range response.Choices {
				response.Choices[j].Delta.SetContentString("\n</think>\n")
				response.Choices[j].Delta.ReasoningContent = nil
				response.Choices[j].Delta.Reasoning = nil
			}
			info.ThinkingContentInfo.SendLastThinkingContent = true
			// NOTE(review): the write error of this extra </think> chunk is
			// deliberately(?) ignored — confirm this is intended best-effort.
			helper.ObjectData(c, response)
		}
		// Convert reasoning content to regular content if any
		if len(choice.Delta.GetReasoningContent()) > 0 {
			lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
			lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
			lastStreamResponse.Choices[i].Delta.Reasoning = nil
		} else if !hasThinkingContent && !hasContent {
			// flush thinking content
			lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
			lastStreamResponse.Choices[i].Delta.Reasoning = nil
		}
	}
	return helper.ObjectData(c, lastStreamResponse)
}
// OaiStreamHandler consumes an OpenAI-compatible SSE response stream, relays
// each chunk to the client, and computes the usage for billing.
//
// Chunks are forwarded with a one-chunk delay (lastStreamData) so the final
// chunk can be inspected and adjusted before it is sent. If the upstream
// reports no usage, tokens are counted locally from the accumulated text.
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	if resp == nil || resp.Body == nil {
		logger.LogError(c, "invalid response or response body")
		return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
	}
	defer service.CloseResponseBodyGracefully(resp)
	model := info.UpstreamModelName
	var responseId string
	var createAt int64 = 0
	var systemFingerprint string
	var containStreamUsage bool
	var responseTextBuilder strings.Builder
	var toolCount int
	var usage = &dto.Usage{}
	var streamItems []string // store stream items
	var lastStreamData string
	var secondLastStreamData string // second-to-last stream chunk, used for audio models
	// Detect whether this is an audio model (its usage arrives in the
	// second-to-last SSE chunk rather than the last).
	isAudioModel := strings.Contains(strings.ToLower(model), "audio")
	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
		// Forward the previously buffered chunk; the current one is held back
		// until the next callback (or the post-scan handling below).
		if lastStreamData != "" {
			err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
			if err != nil {
				common.SysLog("error handling stream format: " + err.Error())
			}
		}
		if len(data) > 0 {
			// For audio models, remember the second-to-last chunk.
			if isAudioModel && lastStreamData != "" {
				secondLastStreamData = lastStreamData
			}
			lastStreamData = data
			streamItems = append(streamItems, data)
		}
		return true
	})
	// For audio models, extract usage info from the second-to-last chunk.
	if isAudioModel && secondLastStreamData != "" {
		var streamResp struct {
			Usage *dto.Usage `json:"usage"`
		}
		err := json.Unmarshal([]byte(secondLastStreamData), &streamResp)
		if err == nil && streamResp.Usage != nil && service.ValidUsage(streamResp.Usage) {
			usage = streamResp.Usage
			containStreamUsage = true
			if common.DebugEnabled {
				logger.LogDebug(c, fmt.Sprintf("Audio model usage extracted from second last SSE: PromptTokens=%d, CompletionTokens=%d, TotalTokens=%d, InputTokens=%d, OutputTokens=%d",
					usage.PromptTokens, usage.CompletionTokens, usage.TotalTokens,
					usage.InputTokens, usage.OutputTokens))
			}
		}
	}
	// Handle the final buffered chunk (may update id/model/usage metadata and
	// decide whether the chunk still needs to be sent to the client).
	shouldSendLastResp := true
	if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
		&containStreamUsage, info, &shouldSendLastResp); err != nil {
		logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
	}
	if info.RelayFormat == types.RelayFormatOpenAI {
		if shouldSendLastResp {
			_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
		}
	}
	// Token accounting: rebuild the full response text and count tool calls.
	if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
		logger.LogError(c, "error processing tokens: "+err.Error())
	}
	if !containStreamUsage {
		// Upstream supplied no usage; estimate completion tokens locally.
		usage = service.ResponseText2Usage(c, responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
		// Flat per-tool-call surcharge (empirical constant used here).
		usage.CompletionTokens += toolCount * 7
	}
	applyUsagePostProcessing(info, usage, nil)
	HandleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
	return usage, nil
}
  167. func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  168. defer service.CloseResponseBodyGracefully(resp)
  169. var simpleResponse dto.OpenAITextResponse
  170. responseBody, err := io.ReadAll(resp.Body)
  171. if err != nil {
  172. return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
  173. }
  174. if common.DebugEnabled {
  175. println("upstream response body:", string(responseBody))
  176. }
  177. // Unmarshal to simpleResponse
  178. if info.ChannelType == constant.ChannelTypeOpenRouter && info.ChannelOtherSettings.IsOpenRouterEnterprise() {
  179. // 尝试解析为 openrouter enterprise
  180. var enterpriseResponse openrouter.OpenRouterEnterpriseResponse
  181. err = common.Unmarshal(responseBody, &enterpriseResponse)
  182. if err != nil {
  183. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  184. }
  185. if enterpriseResponse.Success {
  186. responseBody = enterpriseResponse.Data
  187. } else {
  188. logger.LogError(c, fmt.Sprintf("openrouter enterprise response success=false, data: %s", enterpriseResponse.Data))
  189. return nil, types.NewOpenAIError(fmt.Errorf("openrouter response success=false"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  190. }
  191. }
  192. err = common.Unmarshal(responseBody, &simpleResponse)
  193. if err != nil {
  194. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  195. }
  196. if oaiError := simpleResponse.GetOpenAIError(); oaiError != nil && oaiError.Type != "" {
  197. return nil, types.WithOpenAIError(*oaiError, resp.StatusCode)
  198. }
  199. forceFormat := false
  200. if info.ChannelSetting.ForceFormat {
  201. forceFormat = true
  202. }
  203. usageModified := false
  204. if simpleResponse.Usage.PromptTokens == 0 {
  205. completionTokens := simpleResponse.Usage.CompletionTokens
  206. if completionTokens == 0 {
  207. for _, choice := range simpleResponse.Choices {
  208. ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
  209. completionTokens += ctkm
  210. }
  211. }
  212. simpleResponse.Usage = dto.Usage{
  213. PromptTokens: info.PromptTokens,
  214. CompletionTokens: completionTokens,
  215. TotalTokens: info.PromptTokens + completionTokens,
  216. }
  217. usageModified = true
  218. }
  219. applyUsagePostProcessing(info, &simpleResponse.Usage, responseBody)
  220. switch info.RelayFormat {
  221. case types.RelayFormatOpenAI:
  222. if usageModified {
  223. var bodyMap map[string]interface{}
  224. err = common.Unmarshal(responseBody, &bodyMap)
  225. if err != nil {
  226. return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
  227. }
  228. bodyMap["usage"] = simpleResponse.Usage
  229. responseBody, _ = common.Marshal(bodyMap)
  230. }
  231. if forceFormat {
  232. responseBody, err = common.Marshal(simpleResponse)
  233. if err != nil {
  234. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  235. }
  236. } else {
  237. break
  238. }
  239. case types.RelayFormatClaude:
  240. claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
  241. claudeRespStr, err := common.Marshal(claudeResp)
  242. if err != nil {
  243. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  244. }
  245. responseBody = claudeRespStr
  246. case types.RelayFormatGemini:
  247. geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
  248. geminiRespStr, err := common.Marshal(geminiResp)
  249. if err != nil {
  250. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  251. }
  252. responseBody = geminiRespStr
  253. }
  254. service.IOCopyBytesGracefully(c, resp, responseBody)
  255. return &simpleResponse.Usage, nil
  256. }
  257. func streamTTSResponse(c *gin.Context, resp *http.Response) {
  258. c.Writer.WriteHeaderNow()
  259. flusher, ok := c.Writer.(http.Flusher)
  260. if !ok {
  261. logger.LogWarn(c, "streaming not supported")
  262. _, err := io.Copy(c.Writer, resp.Body)
  263. if err != nil {
  264. logger.LogWarn(c, err.Error())
  265. }
  266. return
  267. }
  268. buffer := make([]byte, 4096)
  269. for {
  270. n, err := resp.Body.Read(buffer)
  271. //logger.LogInfo(c, fmt.Sprintf("streamTTSResponse read %d bytes", n))
  272. if n > 0 {
  273. if _, writeErr := c.Writer.Write(buffer[:n]); writeErr != nil {
  274. logger.LogError(c, writeErr.Error())
  275. break
  276. }
  277. flusher.Flush()
  278. }
  279. if err != nil {
  280. if err != io.EOF {
  281. logger.LogError(c, err.Error())
  282. }
  283. break
  284. }
  285. }
  286. }
// OpenaiTTSHandler relays a text-to-speech (audio) response to the client.
//
// The status code has been judged before; if there is a body reading failure
// it should be regarded as a non-recoverable error, so it should not return
// err for external retry. Analogous to nginx's load balancing, it will only
// retry if it can't be requested or if the upstream returns a specific status
// code — once the upstream has already written the header, a subsequent
// response-body failure is non-recoverable and can be terminated directly.
//
// TTS usage is billed from the prompt alone: completion tokens are zero.
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
	defer service.CloseResponseBodyGracefully(resp)
	usage := &dto.Usage{}
	usage.PromptTokens = info.PromptTokens
	usage.TotalTokens = info.PromptTokens
	// Mirror upstream headers to the client.
	// NOTE(review): only the first value of each header is copied — confirm
	// multi-valued headers are never needed for TTS responses.
	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)
	// An unknown/absent Content-Length means the upstream is streaming.
	isStreaming := resp.ContentLength == -1 || resp.Header.Get("Content-Length") == ""
	if isStreaming {
		streamTTSResponse(c, resp)
	} else {
		c.Writer.WriteHeaderNow()
		_, err := io.Copy(c.Writer, resp.Body)
		if err != nil {
			logger.LogError(c, err.Error())
		}
	}
	return usage
}
// OpenaiSTTHandler relays a speech-to-text (transcription) response and
// derives usage from it.
//
// The body is forwarded to the client first; afterwards a usage object is
// extracted from the JSON when present, normalizing input/output token
// aliases into prompt/completion tokens. When the body carries no usable
// usage, billing falls back to the locally counted prompt tokens.
//
// NOTE(review): the responseFormat parameter is currently unused — confirm
// whether it is kept for interface compatibility.
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
	defer service.CloseResponseBodyGracefully(resp)
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
	}
	// Write the response body back to the client.
	service.IOCopyBytesGracefully(c, resp, responseBody)
	var responseData struct {
		Usage *dto.Usage `json:"usage"`
	}
	if err := json.Unmarshal(responseBody, &responseData); err == nil && responseData.Usage != nil {
		if responseData.Usage.TotalTokens > 0 {
			usage := responseData.Usage
			// Some upstreams report input_tokens/output_tokens instead of
			// prompt_tokens/completion_tokens; map them over.
			if usage.PromptTokens == 0 {
				usage.PromptTokens = usage.InputTokens
			}
			if usage.CompletionTokens == 0 {
				usage.CompletionTokens = usage.OutputTokens
			}
			return nil, usage
		}
	}
	// Fallback: bill only the locally counted prompt tokens.
	usage := &dto.Usage{}
	usage.PromptTokens = info.PromptTokens
	usage.CompletionTokens = 0
	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
	return nil, usage
}
  343. func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
  344. if info == nil || info.ClientWs == nil || info.TargetWs == nil {
  345. return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
  346. }
  347. info.IsStream = true
  348. clientConn := info.ClientWs
  349. targetConn := info.TargetWs
  350. clientClosed := make(chan struct{})
  351. targetClosed := make(chan struct{})
  352. sendChan := make(chan []byte, 100)
  353. receiveChan := make(chan []byte, 100)
  354. errChan := make(chan error, 2)
  355. usage := &dto.RealtimeUsage{}
  356. localUsage := &dto.RealtimeUsage{}
  357. sumUsage := &dto.RealtimeUsage{}
  358. gopool.Go(func() {
  359. defer func() {
  360. if r := recover(); r != nil {
  361. errChan <- fmt.Errorf("panic in client reader: %v", r)
  362. }
  363. }()
  364. for {
  365. select {
  366. case <-c.Done():
  367. return
  368. default:
  369. _, message, err := clientConn.ReadMessage()
  370. if err != nil {
  371. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  372. errChan <- fmt.Errorf("error reading from client: %v", err)
  373. }
  374. close(clientClosed)
  375. return
  376. }
  377. realtimeEvent := &dto.RealtimeEvent{}
  378. err = common.Unmarshal(message, realtimeEvent)
  379. if err != nil {
  380. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  381. return
  382. }
  383. if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
  384. if realtimeEvent.Session != nil {
  385. if realtimeEvent.Session.Tools != nil {
  386. info.RealtimeTools = realtimeEvent.Session.Tools
  387. }
  388. }
  389. }
  390. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  391. if err != nil {
  392. errChan <- fmt.Errorf("error counting text token: %v", err)
  393. return
  394. }
  395. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  396. localUsage.TotalTokens += textToken + audioToken
  397. localUsage.InputTokens += textToken + audioToken
  398. localUsage.InputTokenDetails.TextTokens += textToken
  399. localUsage.InputTokenDetails.AudioTokens += audioToken
  400. err = helper.WssString(c, targetConn, string(message))
  401. if err != nil {
  402. errChan <- fmt.Errorf("error writing to target: %v", err)
  403. return
  404. }
  405. select {
  406. case sendChan <- message:
  407. default:
  408. }
  409. }
  410. }
  411. })
  412. gopool.Go(func() {
  413. defer func() {
  414. if r := recover(); r != nil {
  415. errChan <- fmt.Errorf("panic in target reader: %v", r)
  416. }
  417. }()
  418. for {
  419. select {
  420. case <-c.Done():
  421. return
  422. default:
  423. _, message, err := targetConn.ReadMessage()
  424. if err != nil {
  425. if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
  426. errChan <- fmt.Errorf("error reading from target: %v", err)
  427. }
  428. close(targetClosed)
  429. return
  430. }
  431. info.SetFirstResponseTime()
  432. realtimeEvent := &dto.RealtimeEvent{}
  433. err = common.Unmarshal(message, realtimeEvent)
  434. if err != nil {
  435. errChan <- fmt.Errorf("error unmarshalling message: %v", err)
  436. return
  437. }
  438. if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
  439. realtimeUsage := realtimeEvent.Response.Usage
  440. if realtimeUsage != nil {
  441. usage.TotalTokens += realtimeUsage.TotalTokens
  442. usage.InputTokens += realtimeUsage.InputTokens
  443. usage.OutputTokens += realtimeUsage.OutputTokens
  444. usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
  445. usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
  446. usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
  447. usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
  448. usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
  449. err := preConsumeUsage(c, info, usage, sumUsage)
  450. if err != nil {
  451. errChan <- fmt.Errorf("error consume usage: %v", err)
  452. return
  453. }
  454. // 本次计费完成,清除
  455. usage = &dto.RealtimeUsage{}
  456. localUsage = &dto.RealtimeUsage{}
  457. } else {
  458. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  459. if err != nil {
  460. errChan <- fmt.Errorf("error counting text token: %v", err)
  461. return
  462. }
  463. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  464. localUsage.TotalTokens += textToken + audioToken
  465. info.IsFirstRequest = false
  466. localUsage.InputTokens += textToken + audioToken
  467. localUsage.InputTokenDetails.TextTokens += textToken
  468. localUsage.InputTokenDetails.AudioTokens += audioToken
  469. err = preConsumeUsage(c, info, localUsage, sumUsage)
  470. if err != nil {
  471. errChan <- fmt.Errorf("error consume usage: %v", err)
  472. return
  473. }
  474. // 本次计费完成,清除
  475. localUsage = &dto.RealtimeUsage{}
  476. // print now usage
  477. }
  478. logger.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
  479. logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  480. logger.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
  481. } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
  482. realtimeSession := realtimeEvent.Session
  483. if realtimeSession != nil {
  484. // update audio format
  485. info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
  486. info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
  487. }
  488. } else {
  489. textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
  490. if err != nil {
  491. errChan <- fmt.Errorf("error counting text token: %v", err)
  492. return
  493. }
  494. logger.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
  495. localUsage.TotalTokens += textToken + audioToken
  496. localUsage.OutputTokens += textToken + audioToken
  497. localUsage.OutputTokenDetails.TextTokens += textToken
  498. localUsage.OutputTokenDetails.AudioTokens += audioToken
  499. }
  500. err = helper.WssString(c, clientConn, string(message))
  501. if err != nil {
  502. errChan <- fmt.Errorf("error writing to client: %v", err)
  503. return
  504. }
  505. select {
  506. case receiveChan <- message:
  507. default:
  508. }
  509. }
  510. }
  511. })
  512. select {
  513. case <-clientClosed:
  514. case <-targetClosed:
  515. case err := <-errChan:
  516. //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
  517. logger.LogError(c, "realtime error: "+err.Error())
  518. case <-c.Done():
  519. }
  520. if usage.TotalTokens != 0 {
  521. _ = preConsumeUsage(c, info, usage, sumUsage)
  522. }
  523. if localUsage.TotalTokens != 0 {
  524. _ = preConsumeUsage(c, info, localUsage, sumUsage)
  525. }
  526. // check usage total tokens, if 0, use local usage
  527. return nil, sumUsage
  528. }
  529. func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
  530. if usage == nil || totalUsage == nil {
  531. return fmt.Errorf("invalid usage pointer")
  532. }
  533. totalUsage.TotalTokens += usage.TotalTokens
  534. totalUsage.InputTokens += usage.InputTokens
  535. totalUsage.OutputTokens += usage.OutputTokens
  536. totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
  537. totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
  538. totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
  539. totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
  540. totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
  541. // clear usage
  542. err := service.PreWssConsumeQuota(ctx, info, usage)
  543. return err
  544. }
// OpenaiHandlerWithUsage relays a non-streaming response body to the client
// as-is and extracts only its usage object for billing, normalizing
// input/output token fields into prompt/completion token fields.
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
	defer service.CloseResponseBodyGracefully(resp)
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
	}
	var usageResp dto.SimpleResponse
	err = common.Unmarshal(responseBody, &usageResp)
	if err != nil {
		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
	}
	// Write the response body back to the client.
	service.IOCopyBytesGracefully(c, resp, responseBody)
	// Once we've written to the client, we should not return errors anymore
	// because the upstream has already consumed resources and returned content.
	// We should still perform billing even if parsing fails.
	// Normalize usage fields reported as input_tokens/output_tokens into the
	// prompt/completion token fields used for billing.
	if usageResp.InputTokens > 0 {
		usageResp.PromptTokens += usageResp.InputTokens
	}
	if usageResp.OutputTokens > 0 {
		usageResp.CompletionTokens += usageResp.OutputTokens
	}
	if usageResp.InputTokensDetails != nil {
		usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
		usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
	}
	applyUsagePostProcessing(info, &usageResp.Usage, responseBody)
	return &usageResp.Usage, nil
}
  575. func applyUsagePostProcessing(info *relaycommon.RelayInfo, usage *dto.Usage, responseBody []byte) {
  576. if info == nil || usage == nil {
  577. return
  578. }
  579. switch info.ChannelType {
  580. case constant.ChannelTypeDeepSeek:
  581. if usage.PromptTokensDetails.CachedTokens == 0 && usage.PromptCacheHitTokens != 0 {
  582. usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
  583. }
  584. case constant.ChannelTypeZhipu_v4:
  585. if usage.PromptTokensDetails.CachedTokens == 0 {
  586. if usage.InputTokensDetails != nil && usage.InputTokensDetails.CachedTokens > 0 {
  587. usage.PromptTokensDetails.CachedTokens = usage.InputTokensDetails.CachedTokens
  588. } else if cachedTokens, ok := extractCachedTokensFromBody(responseBody); ok {
  589. usage.PromptTokensDetails.CachedTokens = cachedTokens
  590. } else if usage.PromptCacheHitTokens > 0 {
  591. usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
  592. }
  593. }
  594. }
  595. }
  596. func extractCachedTokensFromBody(body []byte) (int, bool) {
  597. if len(body) == 0 {
  598. return 0, false
  599. }
  600. var payload struct {
  601. Usage struct {
  602. PromptTokensDetails struct {
  603. CachedTokens *int `json:"cached_tokens"`
  604. } `json:"prompt_tokens_details"`
  605. CachedTokens *int `json:"cached_tokens"`
  606. PromptCacheHitTokens *int `json:"prompt_cache_hit_tokens"`
  607. } `json:"usage"`
  608. }
  609. if err := json.Unmarshal(body, &payload); err != nil {
  610. return 0, false
  611. }
  612. if payload.Usage.PromptTokensDetails.CachedTokens != nil {
  613. return *payload.Usage.PromptTokensDetails.CachedTokens, true
  614. }
  615. if payload.Usage.CachedTokens != nil {
  616. return *payload.Usage.CachedTokens, true
  617. }
  618. if payload.Usage.PromptCacheHitTokens != nil {
  619. return *payload.Usage.PromptCacheHitTokens, true
  620. }
  621. return 0, false
  622. }