relay-tencent.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. package controller
  2. import (
  3. "bufio"
  4. "crypto/hmac"
  5. "crypto/sha1"
  6. "encoding/base64"
  7. "encoding/json"
  8. "errors"
  9. "fmt"
  10. "github.com/gin-gonic/gin"
  11. "io"
  12. "net/http"
  13. "one-api/common"
  14. "sort"
  15. "strconv"
  16. "strings"
  17. )
  18. // https://cloud.tencent.com/document/product/1729/97732
  // TencentMessage is one chat message in the Tencent wire format: the role of
  // the speaker and the text it produced. It is used both for the messages of a
  // request and for the message/delta carried by a response choice.
  type TencentMessage struct {
  	Role    string `json:"role"`
  	Content string `json:"content"`
  }
  23. type TencentChatRequest struct {
  24. AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
  25. SecretId string `json:"secret_id"` // 官网 SecretId
  26. // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
  27. // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
  28. Timestamp int64 `json:"timestamp"`
  29. // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
  30. // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
  31. Expired int64 `json:"expired"`
  32. QueryID string `json:"query_id"` //请求 Id,用于问题排查
  33. // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
  34. // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
  35. // 建议该参数和 top_p 只设置1个,不要同时更改 top_p
  36. Temperature float64 `json:"temperature"`
  37. // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
  38. // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
  39. // 建议该参数和 temperature 只设置1个,不要同时更改
  40. TopP float64 `json:"top_p"`
  41. // Stream 0:同步,1:流式 (默认,协议:SSE)
  42. // 同步请求超时:60s,如果内容较长建议使用流式
  43. Stream int `json:"stream"`
  44. // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
  45. // 输入 content 总数最大支持 3000 token。
  46. Messages []TencentMessage `json:"messages"`
  47. }
  // TencentError is the error object embedded in a Tencent chat response: a
  // numeric code plus a human-readable message.
  type TencentError struct {
  	Code    int    `json:"code"`
  	Message string `json:"message"`
  }
  // TencentUsage carries token counts under the JSON keys input_tokens,
  // output_tokens and total_tokens.
  type TencentUsage struct {
  	InputTokens  int `json:"input_tokens"`
  	OutputTokens int `json:"output_tokens"`
  	TotalTokens  int `json:"total_tokens"`
  }
  57. type TencentResponseChoices struct {
  58. FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
  59. Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
  60. Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
  61. }
  62. type TencentChatResponse struct {
  63. Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
  64. Created string `json:"created,omitempty"` // unix 时间戳的字符串
  65. Id string `json:"id,omitempty"` // 会话 id
  66. Usage Usage `json:"usage,omitempty"` // token 数量
  67. Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
  68. Note string `json:"note,omitempty"` // 注释
  69. ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
  70. }
  71. func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
  72. messages := make([]TencentMessage, 0, len(request.Messages))
  73. for i := 0; i < len(request.Messages); i++ {
  74. message := request.Messages[i]
  75. if message.Role == "system" {
  76. messages = append(messages, TencentMessage{
  77. Role: "user",
  78. Content: message.Content,
  79. })
  80. messages = append(messages, TencentMessage{
  81. Role: "assistant",
  82. Content: "Okay",
  83. })
  84. continue
  85. }
  86. messages = append(messages, TencentMessage{
  87. Content: message.Content,
  88. Role: message.Role,
  89. })
  90. }
  91. stream := 0
  92. if request.Stream {
  93. stream = 1
  94. }
  95. return &TencentChatRequest{
  96. Timestamp: common.GetTimestamp(),
  97. Expired: common.GetTimestamp() + 24*60*60,
  98. QueryID: common.GetUUID(),
  99. Temperature: request.Temperature,
  100. TopP: request.TopP,
  101. Stream: stream,
  102. Messages: messages,
  103. }
  104. }
  105. func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
  106. fullTextResponse := OpenAITextResponse{
  107. Object: "chat.completion",
  108. Created: common.GetTimestamp(),
  109. Usage: response.Usage,
  110. }
  111. if len(response.Choices) > 0 {
  112. choice := OpenAITextResponseChoice{
  113. Index: 0,
  114. Message: Message{
  115. Role: "assistant",
  116. Content: response.Choices[0].Messages.Content,
  117. },
  118. FinishReason: response.Choices[0].FinishReason,
  119. }
  120. fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
  121. }
  122. return &fullTextResponse
  123. }
  124. func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
  125. response := ChatCompletionsStreamResponse{
  126. Object: "chat.completion.chunk",
  127. Created: common.GetTimestamp(),
  128. Model: "tencent-hunyuan",
  129. }
  130. if len(TencentResponse.Choices) > 0 {
  131. var choice ChatCompletionsStreamResponseChoice
  132. choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
  133. if TencentResponse.Choices[0].FinishReason == "stop" {
  134. choice.FinishReason = &stopFinishReason
  135. }
  136. response.Choices = append(response.Choices, choice)
  137. }
  138. return &response
  139. }
  140. func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
  141. var responseText string
  142. scanner := bufio.NewScanner(resp.Body)
  143. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  144. if atEOF && len(data) == 0 {
  145. return 0, nil, nil
  146. }
  147. if i := strings.Index(string(data), "\n"); i >= 0 {
  148. return i + 1, data[0:i], nil
  149. }
  150. if atEOF {
  151. return len(data), data, nil
  152. }
  153. return 0, nil, nil
  154. })
  155. dataChan := make(chan string)
  156. stopChan := make(chan bool)
  157. go func() {
  158. for scanner.Scan() {
  159. data := scanner.Text()
  160. if len(data) < 5 { // ignore blank line or wrong format
  161. continue
  162. }
  163. if data[:5] != "data:" {
  164. continue
  165. }
  166. data = data[5:]
  167. dataChan <- data
  168. }
  169. stopChan <- true
  170. }()
  171. setEventStreamHeaders(c)
  172. c.Stream(func(w io.Writer) bool {
  173. select {
  174. case data := <-dataChan:
  175. var TencentResponse TencentChatResponse
  176. err := json.Unmarshal([]byte(data), &TencentResponse)
  177. if err != nil {
  178. common.SysError("error unmarshalling stream response: " + err.Error())
  179. return true
  180. }
  181. response := streamResponseTencent2OpenAI(&TencentResponse)
  182. if len(response.Choices) != 0 {
  183. responseText += response.Choices[0].Delta.Content
  184. }
  185. jsonResponse, err := json.Marshal(response)
  186. if err != nil {
  187. common.SysError("error marshalling stream response: " + err.Error())
  188. return true
  189. }
  190. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
  191. return true
  192. case <-stopChan:
  193. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  194. return false
  195. }
  196. })
  197. err := resp.Body.Close()
  198. if err != nil {
  199. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
  200. }
  201. return nil, responseText
  202. }
  203. func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
  204. var TencentResponse TencentChatResponse
  205. responseBody, err := io.ReadAll(resp.Body)
  206. if err != nil {
  207. return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
  208. }
  209. err = resp.Body.Close()
  210. if err != nil {
  211. return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
  212. }
  213. err = json.Unmarshal(responseBody, &TencentResponse)
  214. if err != nil {
  215. return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
  216. }
  217. if TencentResponse.Error.Code != 0 {
  218. return &OpenAIErrorWithStatusCode{
  219. OpenAIError: OpenAIError{
  220. Message: TencentResponse.Error.Message,
  221. Code: TencentResponse.Error.Code,
  222. },
  223. StatusCode: resp.StatusCode,
  224. }, nil
  225. }
  226. fullTextResponse := responseTencent2OpenAI(&TencentResponse)
  227. jsonResponse, err := json.Marshal(fullTextResponse)
  228. if err != nil {
  229. return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
  230. }
  231. c.Writer.Header().Set("Content-Type", "application/json")
  232. c.Writer.WriteHeader(resp.StatusCode)
  233. _, err = c.Writer.Write(jsonResponse)
  234. return nil, &fullTextResponse.Usage
  235. }
  // parseTencentConfig splits a channel key of the form
  // "appId|secretId|secretKey" into its three parts.
  //
  // It returns an error when the config does not consist of exactly three
  // "|"-separated parts or when the first part is not a base-10 int64. On error
  // all other results are zero values.
  func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
  	parts := strings.Split(config, "|")
  	if len(parts) != 3 {
  		return 0, "", "", errors.New("invalid tencent config")
  	}
  	appId, err = strconv.ParseInt(parts[0], 10, 64)
  	if err != nil {
  		// Do not hand back half a credential set together with an error.
  		return 0, "", "", fmt.Errorf("invalid tencent config: app id %q: %w", parts[0], err)
  	}
  	return appId, parts[1], parts[2], nil
  }
  247. func getTencentSign(req TencentChatRequest, secretKey string) string {
  248. params := make([]string, 0)
  249. params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
  250. params = append(params, "secret_id="+req.SecretId)
  251. params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
  252. params = append(params, "query_id="+req.QueryID)
  253. params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
  254. params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
  255. params = append(params, "stream="+strconv.Itoa(req.Stream))
  256. params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
  257. var messageStr string
  258. for _, msg := range req.Messages {
  259. messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
  260. }
  261. messageStr = strings.TrimSuffix(messageStr, ",")
  262. params = append(params, "messages=["+messageStr+"]")
  263. sort.Sort(sort.StringSlice(params))
  264. url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
  265. mac := hmac.New(sha1.New, []byte(secretKey))
  266. signURL := url
  267. mac.Write([]byte(signURL))
  268. sign := mac.Sum([]byte(nil))
  269. return base64.StdEncoding.EncodeToString(sign)
  270. }