embedding_handler.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. package relay
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. relaycommon "one-api/relay/common"
  10. relayconstant "one-api/relay/constant"
  11. "one-api/relay/helper"
  12. "one-api/service"
  13. "one-api/types"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
  17. token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
  18. return token
  19. }
  20. func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
  21. if embeddingRequest.Input == nil {
  22. return fmt.Errorf("input is empty")
  23. }
  24. if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
  25. embeddingRequest.Model = "omni-moderation-latest"
  26. }
  27. if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
  28. embeddingRequest.Model = c.Param("model")
  29. }
  30. return nil
  31. }
  32. func EmbeddingHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
  33. relayInfo := relaycommon.GenRelayInfoEmbedding(c)
  34. var embeddingRequest *dto.EmbeddingRequest
  35. err := common.UnmarshalBodyReusable(c, &embeddingRequest)
  36. if err != nil {
  37. common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
  38. return types.NewError(err, types.ErrorCodeInvalidRequest)
  39. }
  40. err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
  41. if err != nil {
  42. return types.NewError(err, types.ErrorCodeInvalidRequest)
  43. }
  44. err = helper.ModelMappedHelper(c, relayInfo, embeddingRequest)
  45. if err != nil {
  46. return types.NewError(err, types.ErrorCodeChannelModelMappedError)
  47. }
  48. promptToken := getEmbeddingPromptToken(*embeddingRequest)
  49. relayInfo.PromptTokens = promptToken
  50. priceData, err := helper.ModelPriceHelper(c, relayInfo, promptToken, 0)
  51. if err != nil {
  52. return types.NewError(err, types.ErrorCodeModelPriceError)
  53. }
  54. // pre-consume quota 预消耗配额
  55. preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
  56. if newAPIError != nil {
  57. return newAPIError
  58. }
  59. defer func() {
  60. if newAPIError != nil {
  61. returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
  62. }
  63. }()
  64. adaptor := GetAdaptor(relayInfo.ApiType)
  65. if adaptor == nil {
  66. return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
  67. }
  68. adaptor.Init(relayInfo)
  69. convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
  70. if err != nil {
  71. return types.NewError(err, types.ErrorCodeConvertRequestFailed)
  72. }
  73. jsonData, err := json.Marshal(convertedRequest)
  74. if err != nil {
  75. return types.NewError(err, types.ErrorCodeConvertRequestFailed)
  76. }
  77. requestBody := bytes.NewBuffer(jsonData)
  78. statusCodeMappingStr := c.GetString("status_code_mapping")
  79. resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
  80. if err != nil {
  81. return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
  82. }
  83. var httpResp *http.Response
  84. if resp != nil {
  85. httpResp = resp.(*http.Response)
  86. if httpResp.StatusCode != http.StatusOK {
  87. newAPIError = service.RelayErrorHandler(httpResp, false)
  88. // reset status code 重置状态码
  89. service.ResetStatusCode(newAPIError, statusCodeMappingStr)
  90. return newAPIError
  91. }
  92. }
  93. usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
  94. if newAPIError != nil {
  95. // reset status code 重置状态码
  96. service.ResetStatusCode(newAPIError, statusCodeMappingStr)
  97. return newAPIError
  98. }
  99. postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
  100. return nil
  101. }