compatible_handler.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. package relay
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "strings"
  8. "github.com/QuantumNous/new-api/common"
  9. "github.com/QuantumNous/new-api/constant"
  10. "github.com/QuantumNous/new-api/dto"
  11. "github.com/QuantumNous/new-api/logger"
  12. relaycommon "github.com/QuantumNous/new-api/relay/common"
  13. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  14. "github.com/QuantumNous/new-api/relay/helper"
  15. "github.com/QuantumNous/new-api/service"
  16. "github.com/QuantumNous/new-api/setting/model_setting"
  17. "github.com/QuantumNous/new-api/setting/ratio_setting"
  18. "github.com/QuantumNous/new-api/types"
  19. "github.com/samber/lo"
  20. "github.com/gin-gonic/gin"
  21. )
  22. func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
  23. info.InitChannelMeta(c)
  24. textReq, ok := info.Request.(*dto.GeneralOpenAIRequest)
  25. if !ok {
  26. return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
  27. }
  28. request, err := common.DeepCopy(textReq)
  29. if err != nil {
  30. return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
  31. }
  32. if request.WebSearchOptions != nil {
  33. c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize)
  34. }
  35. err = helper.ModelMappedHelper(c, info, request)
  36. if err != nil {
  37. return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
  38. }
  39. includeUsage := true
  40. // 判断用户是否需要返回使用情况
  41. if request.StreamOptions != nil {
  42. includeUsage = request.StreamOptions.IncludeUsage
  43. }
  44. // 如果不支持StreamOptions,将StreamOptions设置为nil
  45. if !info.SupportStreamOptions || !lo.FromPtrOr(request.Stream, false) {
  46. request.StreamOptions = nil
  47. } else {
  48. // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions
  49. if constant.ForceStreamOption {
  50. request.StreamOptions = &dto.StreamOptions{
  51. IncludeUsage: true,
  52. }
  53. }
  54. }
  55. info.ShouldIncludeUsage = includeUsage
  56. adaptor := GetAdaptor(info.ApiType)
  57. if adaptor == nil {
  58. return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
  59. }
  60. adaptor.Init(info)
  61. passThroughGlobal := model_setting.GetGlobalSettings().PassThroughRequestEnabled
  62. if info.RelayMode == relayconstant.RelayModeChatCompletions &&
  63. !passThroughGlobal &&
  64. !info.ChannelSetting.PassThroughBodyEnabled &&
  65. service.ShouldChatCompletionsUseResponsesGlobal(info.ChannelId, info.ChannelType, info.OriginModelName) {
  66. applySystemPromptIfNeeded(c, info, request)
  67. usage, newApiErr := chatCompletionsViaResponses(c, info, adaptor, request)
  68. if newApiErr != nil {
  69. return newApiErr
  70. }
  71. var containAudioTokens = usage.CompletionTokenDetails.AudioTokens > 0 || usage.PromptTokensDetails.AudioTokens > 0
  72. var containsAudioRatios = ratio_setting.ContainsAudioRatio(info.OriginModelName) || ratio_setting.ContainsAudioCompletionRatio(info.OriginModelName)
  73. if containAudioTokens && containsAudioRatios {
  74. service.PostAudioConsumeQuota(c, info, usage, "")
  75. } else {
  76. service.PostTextConsumeQuota(c, info, usage, nil)
  77. }
  78. return nil
  79. }
  80. var requestBody io.Reader
  81. if passThroughGlobal || info.ChannelSetting.PassThroughBodyEnabled {
  82. storage, err := common.GetBodyStorage(c)
  83. if err != nil {
  84. return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
  85. }
  86. if common.DebugEnabled {
  87. if debugBytes, bErr := storage.Bytes(); bErr == nil {
  88. println("requestBody: ", string(debugBytes))
  89. }
  90. }
  91. requestBody = common.ReaderOnly(storage)
  92. } else {
  93. convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request)
  94. if err != nil {
  95. return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
  96. }
  97. relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)
  98. if info.ChannelSetting.SystemPrompt != "" {
  99. // 如果有系统提示,则将其添加到请求中
  100. request, ok := convertedRequest.(*dto.GeneralOpenAIRequest)
  101. if ok {
  102. containSystemPrompt := false
  103. for _, message := range request.Messages {
  104. if message.Role == request.GetSystemRoleName() {
  105. containSystemPrompt = true
  106. break
  107. }
  108. }
  109. if !containSystemPrompt {
  110. // 如果没有系统提示,则添加系统提示
  111. systemMessage := dto.Message{
  112. Role: request.GetSystemRoleName(),
  113. Content: info.ChannelSetting.SystemPrompt,
  114. }
  115. request.Messages = append([]dto.Message{systemMessage}, request.Messages...)
  116. } else if info.ChannelSetting.SystemPromptOverride {
  117. common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true)
  118. // 如果有系统提示,且允许覆盖,则拼接到前面
  119. for i, message := range request.Messages {
  120. if message.Role == request.GetSystemRoleName() {
  121. if message.IsStringContent() {
  122. request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent())
  123. } else {
  124. contents := message.ParseContent()
  125. contents = append([]dto.MediaContent{
  126. {
  127. Type: dto.ContentTypeText,
  128. Text: info.ChannelSetting.SystemPrompt,
  129. },
  130. }, contents...)
  131. request.Messages[i].Content = contents
  132. }
  133. break
  134. }
  135. }
  136. }
  137. }
  138. }
  139. jsonData, err := common.Marshal(convertedRequest)
  140. if err != nil {
  141. return types.NewError(err, types.ErrorCodeJsonMarshalFailed, types.ErrOptionWithSkipRetry())
  142. }
  143. // remove disabled fields for OpenAI API
  144. jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled)
  145. if err != nil {
  146. return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
  147. }
  148. // apply param override
  149. if len(info.ParamOverride) > 0 {
  150. jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
  151. if err != nil {
  152. return newAPIErrorFromParamOverride(err)
  153. }
  154. }
  155. logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData)))
  156. requestBody = bytes.NewBuffer(jsonData)
  157. }
  158. var httpResp *http.Response
  159. resp, err := adaptor.DoRequest(c, info, requestBody)
  160. if err != nil {
  161. return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
  162. }
  163. statusCodeMappingStr := c.GetString("status_code_mapping")
  164. if resp != nil {
  165. httpResp = resp.(*http.Response)
  166. info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
  167. if httpResp.StatusCode != http.StatusOK {
  168. newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false)
  169. // reset status code 重置状态码
  170. service.ResetStatusCode(newApiErr, statusCodeMappingStr)
  171. return newApiErr
  172. }
  173. }
  174. usage, newApiErr := adaptor.DoResponse(c, httpResp, info)
  175. if newApiErr != nil {
  176. // reset status code 重置状态码
  177. service.ResetStatusCode(newApiErr, statusCodeMappingStr)
  178. return newApiErr
  179. }
  180. var containAudioTokens = usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0
  181. var containsAudioRatios = ratio_setting.ContainsAudioRatio(info.OriginModelName) || ratio_setting.ContainsAudioCompletionRatio(info.OriginModelName)
  182. if containAudioTokens && containsAudioRatios {
  183. service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "")
  184. } else {
  185. service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil)
  186. }
  187. return nil
  188. }