rerank_handler.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package relay
  2. import (
  3. "bytes"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. relaycommon "one-api/relay/common"
  10. "one-api/relay/helper"
  11. "one-api/service"
  12. "one-api/setting/model_setting"
  13. "one-api/types"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
  17. info.InitChannelMeta(c)
  18. rerankReq, ok := info.Request.(*dto.RerankRequest)
  19. if !ok {
  20. return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.RerankRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
  21. }
  22. request, err := common.DeepCopy(rerankReq)
  23. if err != nil {
  24. return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
  25. }
  26. err = helper.ModelMappedHelper(c, info, request)
  27. if err != nil {
  28. return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
  29. }
  30. adaptor := GetAdaptor(info.ApiType)
  31. if adaptor == nil {
  32. return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
  33. }
  34. adaptor.Init(info)
  35. var requestBody io.Reader
  36. if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
  37. body, err := common.GetRequestBody(c)
  38. if err != nil {
  39. return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
  40. }
  41. requestBody = bytes.NewBuffer(body)
  42. } else {
  43. convertedRequest, err := adaptor.ConvertRerankRequest(c, info.RelayMode, *request)
  44. if err != nil {
  45. return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
  46. }
  47. jsonData, err := common.Marshal(convertedRequest)
  48. if err != nil {
  49. return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
  50. }
  51. // apply param override
  52. if len(info.ParamOverride) > 0 {
  53. jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
  54. if err != nil {
  55. return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
  56. }
  57. }
  58. if common.DebugEnabled {
  59. println(fmt.Sprintf("Rerank request body: %s", string(jsonData)))
  60. }
  61. requestBody = bytes.NewBuffer(jsonData)
  62. }
  63. resp, err := adaptor.DoRequest(c, info, requestBody)
  64. if err != nil {
  65. return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
  66. }
  67. statusCodeMappingStr := c.GetString("status_code_mapping")
  68. var httpResp *http.Response
  69. if resp != nil {
  70. httpResp = resp.(*http.Response)
  71. if httpResp.StatusCode != http.StatusOK {
  72. newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
  73. // reset status code 重置状态码
  74. service.ResetStatusCode(newAPIError, statusCodeMappingStr)
  75. return newAPIError
  76. }
  77. }
  78. usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
  79. if newAPIError != nil {
  80. // reset status code 重置状态码
  81. service.ResetStatusCode(newAPIError, statusCodeMappingStr)
  82. return newAPIError
  83. }
  84. postConsumeQuota(c, info, usage.(*dto.Usage), "")
  85. return nil
  86. }