relay-ali.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. package controller
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "strings"
  10. )
  // Wire types for the Alibaba Cloud (Aliyun) chat completion API.
// Reference: https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
type (
	// AliMessage is one earlier conversation turn: what the user said and
	// what the bot answered.
	AliMessage struct {
		User string `json:"user"`
		Bot  string `json:"bot"`
	}

	// AliInput holds the current prompt plus the preceding turns.
	AliInput struct {
		Prompt  string       `json:"prompt"`
		History []AliMessage `json:"history"`
	}

	// AliParameters are optional sampling settings; zero values are
	// dropped from the JSON encoding via omitempty.
	AliParameters struct {
		TopP         float64 `json:"top_p,omitempty"`
		TopK         int     `json:"top_k,omitempty"`
		Seed         uint64  `json:"seed,omitempty"`
		EnableSearch bool    `json:"enable_search,omitempty"`
	}

	// AliChatRequest is the request body sent upstream.
	AliChatRequest struct {
		Model string   `json:"model"`
		Input AliInput `json:"input"`
		// NOTE(review): omitempty has no effect on a non-pointer struct
		// field, so `"parameters":{}` is always encoded. A *AliParameters
		// would allow omission, but that changes the field's type.
		Parameters AliParameters `json:"parameters,omitempty"`
	}

	// AliError carries the error fields of an upstream reply.
	AliError struct {
		Code      string `json:"code"`
		Message   string `json:"message"`
		RequestId string `json:"request_id"`
	}

	// AliUsage reports token counts for a reply.
	AliUsage struct {
		InputTokens  int `json:"input_tokens"`
		OutputTokens int `json:"output_tokens"`
	}

	// AliOutput is the generated text and the reason generation stopped.
	AliOutput struct {
		Text         string `json:"text"`
		FinishReason string `json:"finish_reason"`
	}

	// AliChatResponse is an upstream reply. AliError is embedded without a
	// tag, so code/message/request_id decode from the top-level object.
	AliChatResponse struct {
		Output AliOutput `json:"output"`
		Usage  AliUsage  `json:"usage"`
		AliError
	}
)
  49. func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
  50. messages := make([]AliMessage, 0, len(request.Messages))
  51. prompt := ""
  52. for i := 0; i < len(request.Messages); i++ {
  53. message := request.Messages[i]
  54. if message.Role == "system" {
  55. messages = append(messages, AliMessage{
  56. User: message.Content,
  57. Bot: "Okay",
  58. })
  59. continue
  60. } else {
  61. if i == len(request.Messages)-1 {
  62. prompt = message.Content
  63. break
  64. }
  65. messages = append(messages, AliMessage{
  66. User: message.Content,
  67. Bot: request.Messages[i+1].Content,
  68. })
  69. i++
  70. }
  71. }
  72. return &AliChatRequest{
  73. Model: request.Model,
  74. Input: AliInput{
  75. Prompt: prompt,
  76. History: messages,
  77. },
  78. //Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
  79. // TopP: request.TopP,
  80. // TopK: 50,
  81. // //Seed: 0,
  82. // //EnableSearch: false,
  83. //},
  84. }
  85. }
  86. func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
  87. choice := OpenAITextResponseChoice{
  88. Index: 0,
  89. Message: Message{
  90. Role: "assistant",
  91. Content: response.Output.Text,
  92. },
  93. FinishReason: response.Output.FinishReason,
  94. }
  95. fullTextResponse := OpenAITextResponse{
  96. Id: response.RequestId,
  97. Object: "chat.completion",
  98. Created: common.GetTimestamp(),
  99. Choices: []OpenAITextResponseChoice{choice},
  100. Usage: Usage{
  101. PromptTokens: response.Usage.InputTokens,
  102. CompletionTokens: response.Usage.OutputTokens,
  103. TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
  104. },
  105. }
  106. return &fullTextResponse
  107. }
  108. func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
  109. var choice ChatCompletionsStreamResponseChoice
  110. choice.Delta.Content = aliResponse.Output.Text
  111. if aliResponse.Output.FinishReason != "null" {
  112. finishReason := aliResponse.Output.FinishReason
  113. choice.FinishReason = &finishReason
  114. }
  115. response := ChatCompletionsStreamResponse{
  116. Id: aliResponse.RequestId,
  117. Object: "chat.completion.chunk",
  118. Created: common.GetTimestamp(),
  119. Model: "ernie-bot",
  120. Choices: []ChatCompletionsStreamResponseChoice{choice},
  121. }
  122. return &response
  123. }
  124. func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  125. var usage Usage
  126. scanner := bufio.NewScanner(resp.Body)
  127. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  128. if atEOF && len(data) == 0 {
  129. return 0, nil, nil
  130. }
  131. if i := strings.Index(string(data), "\n"); i >= 0 {
  132. return i + 1, data[0:i], nil
  133. }
  134. if atEOF {
  135. return len(data), data, nil
  136. }
  137. return 0, nil, nil
  138. })
  139. dataChan := make(chan string)
  140. stopChan := make(chan bool)
  141. go func() {
  142. for scanner.Scan() {
  143. data := scanner.Text()
  144. if len(data) < 5 { // ignore blank line or wrong format
  145. continue
  146. }
  147. if data[:5] != "data:" {
  148. continue
  149. }
  150. data = data[5:]
  151. dataChan <- data
  152. }
  153. stopChan <- true
  154. }()
  155. c.Writer.Header().Set("Content-Type", "text/event-stream")
  156. c.Writer.Header().Set("Cache-Control", "no-cache")
  157. c.Writer.Header().Set("Connection", "keep-alive")
  158. c.Writer.Header().Set("Transfer-Encoding", "chunked")
  159. c.Writer.Header().Set("X-Accel-Buffering", "no")
  160. lastResponseText := ""
  161. c.Stream(func(w io.Writer) bool {
  162. select {
  163. case data := <-dataChan:
  164. var aliResponse AliChatResponse
  165. err := json.Unmarshal([]byte(data), &aliResponse)
  166. if err != nil {
  167. common.SysError("error unmarshalling stream response: " + err.Error())
  168. return true
  169. }
  170. usage.PromptTokens += aliResponse.Usage.InputTokens
  171. usage.CompletionTokens += aliResponse.Usage.OutputTokens
  172. usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
  173. response := streamResponseAli2OpenAI(&aliResponse)
  174. response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
  175. lastResponseText = aliResponse.Output.Text
  176. jsonResponse, err := json.Marshal(response)
  177. if err != nil {
  178. common.SysError("error marshalling stream response: " + err.Error())
  179. return true
  180. }
  181. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  182. return true
  183. case <-stopChan:
  184. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  185. return false
  186. }
  187. })
  188. err := resp.Body.Close()
  189. if err != nil {
  190. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  191. }
  192. return nil, &usage
  193. }
  194. func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  195. var aliResponse AliChatResponse
  196. responseBody, err := io.ReadAll(resp.Body)
  197. if err != nil {
  198. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  199. }
  200. err = resp.Body.Close()
  201. if err != nil {
  202. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  203. }
  204. err = json.Unmarshal(responseBody, &aliResponse)
  205. if err != nil {
  206. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  207. }
  208. if aliResponse.Code != "" {
  209. return &OpenAIErrorWithStatusCode{
  210. OpenAIError: OpenAIError{
  211. Message: aliResponse.Message,
  212. Type: aliResponse.Code,
  213. Param: aliResponse.RequestId,
  214. Code: aliResponse.Code,
  215. },
  216. StatusCode: resp.StatusCode,
  217. }, nil
  218. }
  219. fullTextResponse := responseAli2OpenAI(&aliResponse)
  220. jsonResponse, err := json.Marshal(fullTextResponse)
  221. if err != nil {
  222. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  223. }
  224. c.Writer.Header().Set("Content-Type", "application/json")
  225. c.Writer.WriteHeader(resp.StatusCode)
  226. _, err = c.Writer.Write(jsonResponse)
  227. return nil, &fullTextResponse.Usage
  228. }