rerank.go 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. package controller
  2. import (
  3. "errors"
  4. "github.com/gin-gonic/gin"
  5. "github.com/labring/aiproxy/core/model"
  6. "github.com/labring/aiproxy/core/relay/adaptor/openai"
  7. relaymodel "github.com/labring/aiproxy/core/relay/model"
  8. "github.com/labring/aiproxy/core/relay/utils"
  9. )
  10. func getRerankRequest(c *gin.Context) (*relaymodel.RerankRequest, error) {
  11. rerankRequest, err := utils.UnmarshalRerankRequest(c.Request)
  12. if err != nil {
  13. return nil, err
  14. }
  15. if rerankRequest.Model == "" {
  16. return nil, errors.New("model parameter must be provided")
  17. }
  18. if rerankRequest.Query == "" {
  19. return nil, errors.New("query must not be empty")
  20. }
  21. if len(rerankRequest.Documents) == 0 {
  22. return nil, errors.New("document list must not be empty")
  23. }
  24. return rerankRequest, nil
  25. }
  26. func rerankPromptTokens(rerankRequest *relaymodel.RerankRequest) int64 {
  27. tokens := openai.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
  28. for _, d := range rerankRequest.Documents {
  29. tokens += openai.CountTokenInput(d, rerankRequest.Model)
  30. }
  31. return tokens
  32. }
  33. func GetRerankRequestUsage(c *gin.Context, _ model.ModelConfig) (model.Usage, error) {
  34. rerankRequest, err := getRerankRequest(c)
  35. if err != nil {
  36. return model.Usage{}, err
  37. }
  38. return model.Usage{
  39. InputTokens: model.ZeroNullInt64(rerankPromptTokens(rerankRequest)),
  40. }, nil
  41. }