rerank.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. package common_handler
  2. import (
  3. "io"
  4. "net/http"
  5. "one-api/common"
  6. "one-api/constant"
  7. "one-api/dto"
  8. "one-api/relay/channel/xinference"
  9. relaycommon "one-api/relay/common"
  10. "one-api/types"
  11. "github.com/gin-gonic/gin"
  12. )
  13. func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  14. responseBody, err := io.ReadAll(resp.Body)
  15. if err != nil {
  16. return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
  17. }
  18. common.CloseResponseBodyGracefully(resp)
  19. if common.DebugEnabled {
  20. println("reranker response body: ", string(responseBody))
  21. }
  22. var jinaResp dto.RerankResponse
  23. if info.ChannelType == constant.ChannelTypeXinference {
  24. var xinRerankResponse xinference.XinRerankResponse
  25. err = common.Unmarshal(responseBody, &xinRerankResponse)
  26. if err != nil {
  27. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  28. }
  29. jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
  30. for i, result := range xinRerankResponse.Results {
  31. respResult := dto.RerankResponseResult{
  32. Index: result.Index,
  33. RelevanceScore: result.RelevanceScore,
  34. }
  35. if info.ReturnDocuments {
  36. var document any
  37. if result.Document != nil {
  38. if doc, ok := result.Document.(string); ok {
  39. if doc == "" {
  40. document = info.Documents[result.Index]
  41. } else {
  42. document = doc
  43. }
  44. } else {
  45. document = result.Document
  46. }
  47. }
  48. respResult.Document = document
  49. }
  50. jinaRespResults[i] = respResult
  51. }
  52. jinaResp = dto.RerankResponse{
  53. Results: jinaRespResults,
  54. Usage: dto.Usage{
  55. PromptTokens: info.PromptTokens,
  56. TotalTokens: info.PromptTokens,
  57. },
  58. }
  59. } else {
  60. err = common.Unmarshal(responseBody, &jinaResp)
  61. if err != nil {
  62. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  63. }
  64. jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
  65. }
  66. c.Writer.Header().Set("Content-Type", "application/json")
  67. c.JSON(http.StatusOK, jinaResp)
  68. return &jinaResp.Usage, nil
  69. }