text.go 6.1 KB


  1. package ali
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. "one-api/relay/helper"
  10. "strings"
  11. "one-api/types"
  12. "github.com/gin-gonic/gin"
  13. )
  // EnableSearchModelSuffix holds the model-name suffix "-internet".
// It is not referenced in this file; presumably callers elsewhere detect and
// strip it from a model name to turn on Ali's web-search option — TODO confirm.
// See https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
const (
	EnableSearchModelSuffix = "-internet"
)
  16. func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
  17. if request.TopP >= 1 {
  18. request.TopP = 0.999
  19. } else if request.TopP <= 0 {
  20. request.TopP = 0.001
  21. }
  22. return &request
  23. }
  24. func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
  25. return &AliEmbeddingRequest{
  26. Model: request.Model,
  27. Input: struct {
  28. Texts []string `json:"texts"`
  29. }{
  30. Texts: request.ParseInput(),
  31. },
  32. }
  33. }
  34. func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  35. var fullTextResponse dto.FlexibleEmbeddingResponse
  36. err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
  37. if err != nil {
  38. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  39. }
  40. common.CloseResponseBodyGracefully(resp)
  41. model := c.GetString("model")
  42. if model == "" {
  43. model = "text-embedding-v4"
  44. }
  45. jsonResponse, err := json.Marshal(fullTextResponse)
  46. if err != nil {
  47. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  48. }
  49. c.Writer.Header().Set("Content-Type", "application/json")
  50. c.Writer.WriteHeader(resp.StatusCode)
  51. c.Writer.Write(jsonResponse)
  52. return nil, &fullTextResponse.Usage
  53. }
  54. func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
  55. openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
  56. Object: "list",
  57. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
  58. Model: model,
  59. Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
  60. }
  61. for _, item := range response.Output.Embeddings {
  62. openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
  63. Object: `embedding`,
  64. Index: item.TextIndex,
  65. Embedding: item.Embedding,
  66. })
  67. }
  68. return &openAIEmbeddingResponse
  69. }
  70. func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
  71. choice := dto.OpenAITextResponseChoice{
  72. Index: 0,
  73. Message: dto.Message{
  74. Role: "assistant",
  75. Content: response.Output.Text,
  76. },
  77. FinishReason: response.Output.FinishReason,
  78. }
  79. fullTextResponse := dto.OpenAITextResponse{
  80. Id: response.RequestId,
  81. Object: "chat.completion",
  82. Created: common.GetTimestamp(),
  83. Choices: []dto.OpenAITextResponseChoice{choice},
  84. Usage: dto.Usage{
  85. PromptTokens: response.Usage.InputTokens,
  86. CompletionTokens: response.Usage.OutputTokens,
  87. TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
  88. },
  89. }
  90. return &fullTextResponse
  91. }
  92. func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
  93. var choice dto.ChatCompletionsStreamResponseChoice
  94. choice.Delta.SetContentString(aliResponse.Output.Text)
  95. if aliResponse.Output.FinishReason != "null" {
  96. finishReason := aliResponse.Output.FinishReason
  97. choice.FinishReason = &finishReason
  98. }
  99. response := dto.ChatCompletionsStreamResponse{
  100. Id: aliResponse.RequestId,
  101. Object: "chat.completion.chunk",
  102. Created: common.GetTimestamp(),
  103. Model: "ernie-bot",
  104. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  105. }
  106. return &response
  107. }
  108. func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  109. var usage dto.Usage
  110. scanner := bufio.NewScanner(resp.Body)
  111. scanner.Split(bufio.ScanLines)
  112. dataChan := make(chan string)
  113. stopChan := make(chan bool)
  114. go func() {
  115. for scanner.Scan() {
  116. data := scanner.Text()
  117. if len(data) < 5 { // ignore blank line or wrong format
  118. continue
  119. }
  120. if data[:5] != "data:" {
  121. continue
  122. }
  123. data = data[5:]
  124. dataChan <- data
  125. }
  126. stopChan <- true
  127. }()
  128. helper.SetEventStreamHeaders(c)
  129. lastResponseText := ""
  130. c.Stream(func(w io.Writer) bool {
  131. select {
  132. case data := <-dataChan:
  133. var aliResponse AliResponse
  134. err := json.Unmarshal([]byte(data), &aliResponse)
  135. if err != nil {
  136. common.SysError("error unmarshalling stream response: " + err.Error())
  137. return true
  138. }
  139. if aliResponse.Usage.OutputTokens != 0 {
  140. usage.PromptTokens = aliResponse.Usage.InputTokens
  141. usage.CompletionTokens = aliResponse.Usage.OutputTokens
  142. usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
  143. }
  144. response := streamResponseAli2OpenAI(&aliResponse)
  145. response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
  146. lastResponseText = aliResponse.Output.Text
  147. jsonResponse, err := json.Marshal(response)
  148. if err != nil {
  149. common.SysError("error marshalling stream response: " + err.Error())
  150. return true
  151. }
  152. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  153. return true
  154. case <-stopChan:
  155. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  156. return false
  157. }
  158. })
  159. common.CloseResponseBodyGracefully(resp)
  160. return nil, &usage
  161. }
  162. func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  163. var aliResponse AliResponse
  164. responseBody, err := io.ReadAll(resp.Body)
  165. if err != nil {
  166. return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
  167. }
  168. common.CloseResponseBodyGracefully(resp)
  169. err = json.Unmarshal(responseBody, &aliResponse)
  170. if err != nil {
  171. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  172. }
  173. if aliResponse.Code != "" {
  174. return types.WithOpenAIError(types.OpenAIError{
  175. Message: aliResponse.Message,
  176. Type: "ali_error",
  177. Param: aliResponse.RequestId,
  178. Code: aliResponse.Code,
  179. }, resp.StatusCode), nil
  180. }
  181. fullTextResponse := responseAli2OpenAI(&aliResponse)
  182. jsonResponse, err := common.Marshal(fullTextResponse)
  183. if err != nil {
  184. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  185. }
  186. c.Writer.Header().Set("Content-Type", "application/json")
  187. c.Writer.WriteHeader(resp.StatusCode)
  188. _, err = c.Writer.Write(jsonResponse)
  189. return nil, &fullTextResponse.Usage
  190. }