relay-ali.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. package controller
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "github.com/gin-gonic/gin"
  6. "io"
  7. "net/http"
  8. "one-api/common"
  9. "strings"
  10. )
  11. // https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
  12. type AliMessage struct {
  13. User string `json:"user"`
  14. Bot string `json:"bot"`
  15. }
  16. type AliInput struct {
  17. Prompt string `json:"prompt"`
  18. History []AliMessage `json:"history"`
  19. }
  20. type AliParameters struct {
  21. TopP float64 `json:"top_p,omitempty"`
  22. TopK int `json:"top_k,omitempty"`
  23. Seed uint64 `json:"seed,omitempty"`
  24. EnableSearch bool `json:"enable_search,omitempty"`
  25. }
  26. type AliChatRequest struct {
  27. Model string `json:"model"`
  28. Input AliInput `json:"input"`
  29. Parameters AliParameters `json:"parameters,omitempty"`
  30. }
  31. type AliEmbeddingRequest struct {
  32. Model string `json:"model"`
  33. Input struct {
  34. Texts []string `json:"texts"`
  35. } `json:"input"`
  36. Parameters *struct {
  37. TextType string `json:"text_type,omitempty"`
  38. } `json:"parameters,omitempty"`
  39. }
  40. type AliEmbedding struct {
  41. Embedding []float64 `json:"embedding"`
  42. TextIndex int `json:"text_index"`
  43. }
  44. type AliEmbeddingResponse struct {
  45. Output struct {
  46. Embeddings []AliEmbedding `json:"embeddings"`
  47. } `json:"output"`
  48. Usage AliUsage `json:"usage"`
  49. AliError
  50. }
  51. type AliError struct {
  52. Code string `json:"code"`
  53. Message string `json:"message"`
  54. RequestId string `json:"request_id"`
  55. }
  56. type AliUsage struct {
  57. InputTokens int `json:"input_tokens"`
  58. OutputTokens int `json:"output_tokens"`
  59. TotalTokens int `json:"total_tokens"`
  60. }
  61. type AliOutput struct {
  62. Text string `json:"text"`
  63. FinishReason string `json:"finish_reason"`
  64. }
  65. type AliChatResponse struct {
  66. Output AliOutput `json:"output"`
  67. Usage AliUsage `json:"usage"`
  68. AliError
  69. }
  70. func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
  71. messages := make([]AliMessage, 0, len(request.Messages))
  72. prompt := ""
  73. for i := 0; i < len(request.Messages); i++ {
  74. message := request.Messages[i]
  75. if message.Role == "system" {
  76. messages = append(messages, AliMessage{
  77. User: message.Content,
  78. Bot: "Okay",
  79. })
  80. continue
  81. } else {
  82. if i == len(request.Messages)-1 {
  83. prompt = message.Content
  84. break
  85. }
  86. messages = append(messages, AliMessage{
  87. User: message.Content,
  88. Bot: request.Messages[i+1].Content,
  89. })
  90. i++
  91. }
  92. }
  93. return &AliChatRequest{
  94. Model: request.Model,
  95. Input: AliInput{
  96. Prompt: prompt,
  97. History: messages,
  98. },
  99. //Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
  100. // TopP: request.TopP,
  101. // TopK: 50,
  102. // //Seed: 0,
  103. // //EnableSearch: false,
  104. //},
  105. }
  106. }
  107. func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
  108. return &AliEmbeddingRequest{
  109. Model: "text-embedding-v1",
  110. Input: struct {
  111. Texts []string `json:"texts"`
  112. }{
  113. Texts: request.ParseInput(),
  114. },
  115. }
  116. }
  117. func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  118. var aliResponse AliEmbeddingResponse
  119. err := json.NewDecoder(resp.Body).Decode(&aliResponse)
  120. if err != nil {
  121. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  122. }
  123. err = resp.Body.Close()
  124. if err != nil {
  125. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  126. }
  127. if aliResponse.Code != "" {
  128. return &OpenAIErrorWithStatusCode{
  129. OpenAIError: OpenAIError{
  130. Message: aliResponse.Message,
  131. Type: aliResponse.Code,
  132. Param: aliResponse.RequestId,
  133. Code: aliResponse.Code,
  134. },
  135. StatusCode: resp.StatusCode,
  136. }, nil
  137. }
  138. fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
  139. jsonResponse, err := json.Marshal(fullTextResponse)
  140. if err != nil {
  141. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  142. }
  143. c.Writer.Header().Set("Content-Type", "application/json")
  144. c.Writer.WriteHeader(resp.StatusCode)
  145. _, err = c.Writer.Write(jsonResponse)
  146. return nil, &fullTextResponse.Usage
  147. }
  148. func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
  149. openAIEmbeddingResponse := OpenAIEmbeddingResponse{
  150. Object: "list",
  151. Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
  152. Model: "text-embedding-v1",
  153. Usage: Usage{TotalTokens: response.Usage.TotalTokens},
  154. }
  155. for _, item := range response.Output.Embeddings {
  156. openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
  157. Object: `embedding`,
  158. Index: item.TextIndex,
  159. Embedding: item.Embedding,
  160. })
  161. }
  162. return &openAIEmbeddingResponse
  163. }
  164. func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
  165. choice := OpenAITextResponseChoice{
  166. Index: 0,
  167. Message: Message{
  168. Role: "assistant",
  169. Content: response.Output.Text,
  170. },
  171. FinishReason: response.Output.FinishReason,
  172. }
  173. fullTextResponse := OpenAITextResponse{
  174. Id: response.RequestId,
  175. Object: "chat.completion",
  176. Created: common.GetTimestamp(),
  177. Choices: []OpenAITextResponseChoice{choice},
  178. Usage: Usage{
  179. PromptTokens: response.Usage.InputTokens,
  180. CompletionTokens: response.Usage.OutputTokens,
  181. TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
  182. },
  183. }
  184. return &fullTextResponse
  185. }
  186. func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
  187. var choice ChatCompletionsStreamResponseChoice
  188. choice.Delta.Content = aliResponse.Output.Text
  189. if aliResponse.Output.FinishReason != "null" {
  190. finishReason := aliResponse.Output.FinishReason
  191. choice.FinishReason = &finishReason
  192. }
  193. response := ChatCompletionsStreamResponse{
  194. Id: aliResponse.RequestId,
  195. Object: "chat.completion.chunk",
  196. Created: common.GetTimestamp(),
  197. Model: "ernie-bot",
  198. Choices: []ChatCompletionsStreamResponseChoice{choice},
  199. }
  200. return &response
  201. }
  202. func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  203. var usage Usage
  204. scanner := bufio.NewScanner(resp.Body)
  205. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  206. if atEOF && len(data) == 0 {
  207. return 0, nil, nil
  208. }
  209. if i := strings.Index(string(data), "\n"); i >= 0 {
  210. return i + 1, data[0:i], nil
  211. }
  212. if atEOF {
  213. return len(data), data, nil
  214. }
  215. return 0, nil, nil
  216. })
  217. dataChan := make(chan string)
  218. stopChan := make(chan bool)
  219. go func() {
  220. for scanner.Scan() {
  221. data := scanner.Text()
  222. if len(data) < 5 { // ignore blank line or wrong format
  223. continue
  224. }
  225. if data[:5] != "data:" {
  226. continue
  227. }
  228. data = data[5:]
  229. dataChan <- data
  230. }
  231. stopChan <- true
  232. }()
  233. setEventStreamHeaders(c)
  234. lastResponseText := ""
  235. c.Stream(func(w io.Writer) bool {
  236. select {
  237. case data := <-dataChan:
  238. var aliResponse AliChatResponse
  239. err := json.Unmarshal([]byte(data), &aliResponse)
  240. if err != nil {
  241. common.SysError("error unmarshalling stream response: " + err.Error())
  242. return true
  243. }
  244. if aliResponse.Usage.OutputTokens != 0 {
  245. usage.PromptTokens = aliResponse.Usage.InputTokens
  246. usage.CompletionTokens = aliResponse.Usage.OutputTokens
  247. usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
  248. }
  249. response := streamResponseAli2OpenAI(&aliResponse)
  250. response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
  251. lastResponseText = aliResponse.Output.Text
  252. jsonResponse, err := json.Marshal(response)
  253. if err != nil {
  254. common.SysError("error marshalling stream response: " + err.Error())
  255. return true
  256. }
  257. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  258. return true
  259. case <-stopChan:
  260. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  261. return false
  262. }
  263. })
  264. err := resp.Body.Close()
  265. if err != nil {
  266. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  267. }
  268. return nil, &usage
  269. }
  270. func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  271. var aliResponse AliChatResponse
  272. responseBody, err := io.ReadAll(resp.Body)
  273. if err != nil {
  274. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  275. }
  276. err = resp.Body.Close()
  277. if err != nil {
  278. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  279. }
  280. err = json.Unmarshal(responseBody, &aliResponse)
  281. if err != nil {
  282. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  283. }
  284. if aliResponse.Code != "" {
  285. return &OpenAIErrorWithStatusCode{
  286. OpenAIError: OpenAIError{
  287. Message: aliResponse.Message,
  288. Type: aliResponse.Code,
  289. Param: aliResponse.RequestId,
  290. Code: aliResponse.Code,
  291. },
  292. StatusCode: resp.StatusCode,
  293. }, nil
  294. }
  295. fullTextResponse := responseAli2OpenAI(&aliResponse)
  296. jsonResponse, err := json.Marshal(fullTextResponse)
  297. if err != nil {
  298. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  299. }
  300. c.Writer.Header().Set("Content-Type", "application/json")
  301. c.Writer.WriteHeader(resp.StatusCode)
  302. _, err = c.Writer.Write(jsonResponse)
  303. return nil, &fullTextResponse.Usage
  304. }