rerank.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. package textembeddingsinference
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "strconv"
  8. "github.com/bytedance/sonic"
  9. "github.com/bytedance/sonic/ast"
  10. "github.com/gin-gonic/gin"
  11. "github.com/labring/aiproxy/core/common"
  12. "github.com/labring/aiproxy/core/model"
  13. "github.com/labring/aiproxy/core/relay/adaptor"
  14. "github.com/labring/aiproxy/core/relay/meta"
  15. relaymodel "github.com/labring/aiproxy/core/relay/model"
  16. )
  17. func ConvertRerankRequest(
  18. meta *meta.Meta,
  19. req *http.Request,
  20. ) (adaptor.ConvertResult, error) {
  21. node, err := common.UnmarshalRequest2NodeReusable(req)
  22. if err != nil {
  23. return adaptor.ConvertResult{}, fmt.Errorf("failed to parse request body: %w", err)
  24. }
  25. // Set the actual model in the request
  26. _, err = node.Set("model", ast.NewString(meta.ActualModel))
  27. if err != nil {
  28. return adaptor.ConvertResult{}, err
  29. }
  30. // Get the documents array and rename it to texts
  31. documentsNode := node.Get("documents")
  32. if !documentsNode.Exists() {
  33. return adaptor.ConvertResult{}, errors.New("documents field not found")
  34. }
  35. // Set the texts field with the documents value
  36. _, err = node.Set("texts", *documentsNode)
  37. if err != nil {
  38. return adaptor.ConvertResult{}, fmt.Errorf("failed to set texts field: %w", err)
  39. }
  40. // Remove the documents field
  41. _, err = node.Unset("documents")
  42. if err != nil {
  43. return adaptor.ConvertResult{}, fmt.Errorf(
  44. "failed to remove documents field: %w",
  45. err,
  46. )
  47. }
  48. returnDocumentsNode := node.Get("return_documents")
  49. if returnDocumentsNode.Exists() {
  50. returnDocuments, err := returnDocumentsNode.Bool()
  51. if err != nil {
  52. return adaptor.ConvertResult{}, fmt.Errorf(
  53. "failed to unmarshal return_documents field: %w",
  54. err,
  55. )
  56. }
  57. _, err = node.Unset("return_documents")
  58. if err != nil {
  59. return adaptor.ConvertResult{}, fmt.Errorf(
  60. "failed to remove return_documents field: %w",
  61. err,
  62. )
  63. }
  64. _, err = node.Set("return_text", ast.NewBool(returnDocuments))
  65. if err != nil {
  66. return adaptor.ConvertResult{}, fmt.Errorf(
  67. "failed to set return_text field: %w",
  68. err,
  69. )
  70. }
  71. }
  72. // Convert back to JSON
  73. jsonData, err := node.MarshalJSON()
  74. if err != nil {
  75. return adaptor.ConvertResult{}, fmt.Errorf("failed to marshal request: %w", err)
  76. }
  77. return adaptor.ConvertResult{
  78. Header: http.Header{
  79. "Content-Type": {"application/json"},
  80. "Content-Length": {strconv.Itoa(len(jsonData))},
  81. },
  82. Body: bytes.NewReader(jsonData),
  83. }, nil
  84. }
  85. type RerankResponse []RerankResponseItem
  86. type RerankResponseItem struct {
  87. Index int `json:"index"`
  88. Score float64 `json:"score"`
  89. Text string `json:"text,omitempty"`
  90. }
  91. func (rri *RerankResponseItem) ToRerankModel() *relaymodel.RerankResult {
  92. var document *relaymodel.Document
  93. if rri.Text != "" {
  94. document = &relaymodel.Document{
  95. Text: rri.Text,
  96. }
  97. }
  98. return &relaymodel.RerankResult{
  99. Index: rri.Index,
  100. RelevanceScore: rri.Score,
  101. Document: document,
  102. }
  103. }
  104. func RerankHandler(
  105. meta *meta.Meta,
  106. c *gin.Context,
  107. resp *http.Response,
  108. ) (model.Usage, adaptor.Error) {
  109. if resp.StatusCode != http.StatusOK {
  110. return model.Usage{}, RerankErrorHanlder(resp)
  111. }
  112. defer resp.Body.Close()
  113. log := common.GetLogger(c)
  114. respSlice := RerankResponse{}
  115. err := sonic.ConfigDefault.NewDecoder(resp.Body).Decode(&respSlice)
  116. if err != nil {
  117. return model.Usage{}, relaymodel.WrapperOpenAIError(
  118. err,
  119. "read_response_body_failed",
  120. http.StatusInternalServerError,
  121. )
  122. }
  123. usage := model.Usage{
  124. InputTokens: meta.RequestUsage.InputTokens,
  125. TotalTokens: meta.RequestUsage.InputTokens,
  126. }
  127. results := make([]*relaymodel.RerankResult, len(respSlice))
  128. for i, v := range respSlice {
  129. results[i] = v.ToRerankModel()
  130. }
  131. rerankResp := relaymodel.RerankResponse{
  132. Meta: relaymodel.RerankMeta{
  133. Tokens: &relaymodel.RerankMetaTokens{
  134. InputTokens: int64(usage.InputTokens),
  135. },
  136. },
  137. Results: results,
  138. ID: meta.RequestID,
  139. }
  140. jsonResponse, err := sonic.Marshal(rerankResp)
  141. if err != nil {
  142. return usage, relaymodel.WrapperOpenAIError(
  143. err,
  144. "marshal_response_body_failed",
  145. http.StatusInternalServerError,
  146. )
  147. }
  148. c.Writer.Header().Set("Content-Type", "application/json")
  149. c.Writer.Header().Set("Content-Length", strconv.Itoa(len(jsonResponse)))
  150. _, err = c.Writer.Write(jsonResponse)
  151. if err != nil {
  152. log.Warnf("write response body failed: %v", err)
  153. }
  154. return usage, nil
  155. }