helper.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package openai
  2. import (
  3. "encoding/json"
  4. "strings"
  5. "github.com/QuantumNous/new-api/common"
  6. "github.com/QuantumNous/new-api/dto"
  7. "github.com/QuantumNous/new-api/logger"
  8. relaycommon "github.com/QuantumNous/new-api/relay/common"
  9. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  10. "github.com/QuantumNous/new-api/relay/helper"
  11. "github.com/QuantumNous/new-api/service"
  12. "github.com/QuantumNous/new-api/types"
  13. "github.com/samber/lo"
  14. "github.com/gin-gonic/gin"
  15. )
  16. // Helper functions
  17. func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
  18. info.SendResponseCount++
  19. switch info.RelayFormat {
  20. case types.RelayFormatOpenAI:
  21. return sendStreamData(c, info, data, forceFormat, thinkToContent)
  22. case types.RelayFormatClaude:
  23. return handleClaudeFormat(c, data, info)
  24. case types.RelayFormatGemini:
  25. return handleGeminiFormat(c, data, info)
  26. }
  27. return nil
  28. }
  29. func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
  30. var streamResponse dto.ChatCompletionsStreamResponse
  31. if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
  32. return err
  33. }
  34. if streamResponse.Usage != nil {
  35. info.ClaudeConvertInfo.Usage = streamResponse.Usage
  36. }
  37. claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
  38. for _, resp := range claudeResponses {
  39. helper.ClaudeData(c, *resp)
  40. }
  41. return nil
  42. }
  43. func handleGeminiFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
  44. var streamResponse dto.ChatCompletionsStreamResponse
  45. if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
  46. logger.LogError(c, "failed to unmarshal stream response: "+err.Error())
  47. return err
  48. }
  49. geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
  50. // 如果返回 nil,表示没有实际内容,跳过发送
  51. if geminiResponse == nil {
  52. return nil
  53. }
  54. geminiResponseStr, err := common.Marshal(geminiResponse)
  55. if err != nil {
  56. logger.LogError(c, "failed to marshal gemini response: "+err.Error())
  57. return err
  58. }
  59. // send gemini format response
  60. c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
  61. _ = helper.FlushWriter(c)
  62. return nil
  63. }
  64. func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
  65. for _, choice := range streamResponse.Choices {
  66. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  67. responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
  68. if choice.Delta.ToolCalls != nil {
  69. if len(choice.Delta.ToolCalls) > *toolCount {
  70. *toolCount = len(choice.Delta.ToolCalls)
  71. }
  72. for _, tool := range choice.Delta.ToolCalls {
  73. responseTextBuilder.WriteString(tool.Function.Name)
  74. responseTextBuilder.WriteString(tool.Function.Arguments)
  75. }
  76. }
  77. }
  78. return nil
  79. }
  80. func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
  81. streamResp := "[" + strings.Join(streamItems, ",") + "]"
  82. switch relayMode {
  83. case relayconstant.RelayModeChatCompletions:
  84. return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
  85. case relayconstant.RelayModeCompletions:
  86. return processCompletions(streamResp, streamItems, responseTextBuilder)
  87. }
  88. return nil
  89. }
  90. func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
  91. var streamResponses []dto.ChatCompletionsStreamResponse
  92. if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
  93. // 一次性解析失败,逐个解析
  94. common.SysLog("error unmarshalling stream response: " + err.Error())
  95. for _, item := range streamItems {
  96. var streamResponse dto.ChatCompletionsStreamResponse
  97. if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
  98. return err
  99. }
  100. if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
  101. common.SysLog("error processing stream response: " + err.Error())
  102. }
  103. }
  104. return nil
  105. }
  106. // 批量处理所有响应
  107. for _, streamResponse := range streamResponses {
  108. for _, choice := range streamResponse.Choices {
  109. responseTextBuilder.WriteString(choice.Delta.GetContentString())
  110. responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
  111. if choice.Delta.ToolCalls != nil {
  112. if len(choice.Delta.ToolCalls) > *toolCount {
  113. *toolCount = len(choice.Delta.ToolCalls)
  114. }
  115. for _, tool := range choice.Delta.ToolCalls {
  116. responseTextBuilder.WriteString(tool.Function.Name)
  117. responseTextBuilder.WriteString(tool.Function.Arguments)
  118. }
  119. }
  120. }
  121. }
  122. return nil
  123. }
  124. func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
  125. var streamResponses []dto.CompletionsStreamResponse
  126. if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
  127. // 一次性解析失败,逐个解析
  128. common.SysLog("error unmarshalling stream response: " + err.Error())
  129. for _, item := range streamItems {
  130. var streamResponse dto.CompletionsStreamResponse
  131. if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
  132. continue
  133. }
  134. for _, choice := range streamResponse.Choices {
  135. responseTextBuilder.WriteString(choice.Text)
  136. }
  137. }
  138. return nil
  139. }
  140. // 批量处理所有响应
  141. for _, streamResponse := range streamResponses {
  142. for _, choice := range streamResponse.Choices {
  143. responseTextBuilder.WriteString(choice.Text)
  144. }
  145. }
  146. return nil
  147. }
  148. func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
  149. systemFingerprint *string, model *string, usage **dto.Usage,
  150. containStreamUsage *bool, info *relaycommon.RelayInfo,
  151. shouldSendLastResp *bool) error {
  152. var lastStreamResponse dto.ChatCompletionsStreamResponse
  153. if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
  154. return err
  155. }
  156. *responseId = lastStreamResponse.Id
  157. *createAt = lastStreamResponse.Created
  158. *systemFingerprint = lastStreamResponse.GetSystemFingerprint()
  159. *model = lastStreamResponse.Model
  160. if service.ValidUsage(lastStreamResponse.Usage) {
  161. *containStreamUsage = true
  162. *usage = lastStreamResponse.Usage
  163. if !info.ShouldIncludeUsage {
  164. *shouldSendLastResp = lo.SomeBy(lastStreamResponse.Choices, func(choice dto.ChatCompletionsStreamResponseChoice) bool {
  165. return choice.Delta.GetContentString() != "" || choice.Delta.GetReasoningContent() != ""
  166. })
  167. }
  168. }
  169. return nil
  170. }
  171. func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
  172. responseId string, createAt int64, model string, systemFingerprint string,
  173. usage *dto.Usage, containStreamUsage bool) {
  174. switch info.RelayFormat {
  175. case types.RelayFormatOpenAI:
  176. if info.ShouldIncludeUsage && !containStreamUsage {
  177. response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
  178. response.SetSystemFingerprint(systemFingerprint)
  179. helper.ObjectData(c, response)
  180. }
  181. helper.Done(c)
  182. case types.RelayFormatClaude:
  183. info.ClaudeConvertInfo.Done = true
  184. var streamResponse dto.ChatCompletionsStreamResponse
  185. if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
  186. common.SysLog("error unmarshalling stream response: " + err.Error())
  187. return
  188. }
  189. info.ClaudeConvertInfo.Usage = usage
  190. claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
  191. for _, resp := range claudeResponses {
  192. _ = helper.ClaudeData(c, *resp)
  193. }
  194. case types.RelayFormatGemini:
  195. var streamResponse dto.ChatCompletionsStreamResponse
  196. if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
  197. common.SysLog("error unmarshalling stream response: " + err.Error())
  198. return
  199. }
  200. // 这里处理的是 openai 最后一个流响应,其 delta 为空,有 finish_reason 字段
  201. // 因此相比较于 google 官方的流响应,由 openai 转换而来会多一个 parts 为空,finishReason 为 STOP 的响应
  202. // 而包含最后一段文本输出的响应(倒数第二个)的 finishReason 为 null
  203. // 暂不知是否有程序会不兼容。
  204. geminiResponse := service.StreamResponseOpenAI2Gemini(&streamResponse, info)
  205. // openai 流响应开头的空数据
  206. if geminiResponse == nil {
  207. return
  208. }
  209. geminiResponseStr, err := common.Marshal(geminiResponse)
  210. if err != nil {
  211. common.SysLog("error marshalling gemini response: " + err.Error())
  212. return
  213. }
  214. // 发送最终的 Gemini 响应
  215. c.Render(-1, common.CustomEvent{Data: "data: " + string(geminiResponseStr)})
  216. _ = helper.FlushWriter(c)
  217. }
  218. }
  219. func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
  220. if data == "" {
  221. return
  222. }
  223. helper.ResponseChunkData(c, streamResponse, data)
  224. }