2
0

relay-baidu.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. package baidu
  import (
  	"encoding/json"
  	"errors"
  	"fmt"
  	"io"
  	"net/http"
  	"net/url"
  	"one-api/common"
  	"one-api/constant"
  	"one-api/dto"
  	relaycommon "one-api/relay/common"
  	"one-api/relay/helper"
  	"one-api/service"
  	"one-api/types"
  	"strings"
  	"sync"
  	"time"

  	"github.com/gin-gonic/gin"
  )
  // Baidu Wenxin Workshop (ERNIE) API reference:
  // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2

  // baiduTokenStore caches access tokens obtained from Baidu's OAuth endpoint.
  // Keys are the raw "client_id|client_secret" API key strings and values are
  // BaiduAccessToken structs, as stored by getBaiduAccessTokenHelper.
  var (
  	baiduTokenStore sync.Map
  )
  22. func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
  23. baiduRequest := BaiduChatRequest{
  24. Temperature: request.Temperature,
  25. TopP: request.TopP,
  26. PenaltyScore: request.FrequencyPenalty,
  27. Stream: request.Stream,
  28. DisableSearch: false,
  29. EnableCitation: false,
  30. UserId: request.User,
  31. }
  32. if request.MaxTokens != 0 {
  33. maxTokens := int(request.MaxTokens)
  34. if request.MaxTokens == 1 {
  35. maxTokens = 2
  36. }
  37. baiduRequest.MaxOutputTokens = &maxTokens
  38. }
  39. for _, message := range request.Messages {
  40. if message.Role == "system" {
  41. baiduRequest.System = message.StringContent()
  42. } else {
  43. baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
  44. Role: message.Role,
  45. Content: message.StringContent(),
  46. })
  47. }
  48. }
  49. return &baiduRequest
  50. }
  51. func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
  52. choice := dto.OpenAITextResponseChoice{
  53. Index: 0,
  54. Message: dto.Message{
  55. Role: "assistant",
  56. Content: response.Result,
  57. },
  58. FinishReason: "stop",
  59. }
  60. fullTextResponse := dto.OpenAITextResponse{
  61. Id: response.Id,
  62. Object: "chat.completion",
  63. Created: response.Created,
  64. Choices: []dto.OpenAITextResponseChoice{choice},
  65. Usage: response.Usage,
  66. }
  67. return &fullTextResponse
  68. }
  69. func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
  70. var choice dto.ChatCompletionsStreamResponseChoice
  71. choice.Delta.SetContentString(baiduResponse.Result)
  72. if baiduResponse.IsEnd {
  73. choice.FinishReason = &constant.FinishReasonStop
  74. }
  75. response := dto.ChatCompletionsStreamResponse{
  76. Id: baiduResponse.Id,
  77. Object: "chat.completion.chunk",
  78. Created: baiduResponse.Created,
  79. Model: "ernie-bot",
  80. Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
  81. }
  82. return &response
  83. }
  84. func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
  85. return &BaiduEmbeddingRequest{
  86. Input: request.ParseInput(),
  87. }
  88. }
  89. func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
  90. openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
  91. Object: "list",
  92. Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
  93. Model: "baidu-embedding",
  94. Usage: response.Usage,
  95. }
  96. for _, item := range response.Data {
  97. openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
  98. Object: item.Object,
  99. Index: item.Index,
  100. Embedding: item.Embedding,
  101. })
  102. }
  103. return &openAIEmbeddingResponse
  104. }
  105. func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  106. usage := &dto.Usage{}
  107. helper.StreamScannerHandler(c, resp, info, func(data string) bool {
  108. var baiduResponse BaiduChatStreamResponse
  109. err := common.Unmarshal([]byte(data), &baiduResponse)
  110. if err != nil {
  111. common.SysError("error unmarshalling stream response: " + err.Error())
  112. return true
  113. }
  114. if baiduResponse.Usage.TotalTokens != 0 {
  115. usage.TotalTokens = baiduResponse.Usage.TotalTokens
  116. usage.PromptTokens = baiduResponse.Usage.PromptTokens
  117. usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
  118. }
  119. response := streamResponseBaidu2OpenAI(&baiduResponse)
  120. err = helper.ObjectData(c, response)
  121. if err != nil {
  122. common.SysError("error sending stream response: " + err.Error())
  123. }
  124. return true
  125. })
  126. common.CloseResponseBodyGracefully(resp)
  127. return nil, usage
  128. }
  129. func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  130. var baiduResponse BaiduChatResponse
  131. responseBody, err := io.ReadAll(resp.Body)
  132. if err != nil {
  133. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  134. }
  135. common.CloseResponseBodyGracefully(resp)
  136. err = json.Unmarshal(responseBody, &baiduResponse)
  137. if err != nil {
  138. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  139. }
  140. if baiduResponse.ErrorMsg != "" {
  141. return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
  142. }
  143. fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
  144. jsonResponse, err := json.Marshal(fullTextResponse)
  145. if err != nil {
  146. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  147. }
  148. c.Writer.Header().Set("Content-Type", "application/json")
  149. c.Writer.WriteHeader(resp.StatusCode)
  150. _, err = c.Writer.Write(jsonResponse)
  151. return nil, &fullTextResponse.Usage
  152. }
  153. func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
  154. var baiduResponse BaiduEmbeddingResponse
  155. responseBody, err := io.ReadAll(resp.Body)
  156. if err != nil {
  157. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  158. }
  159. common.CloseResponseBodyGracefully(resp)
  160. err = json.Unmarshal(responseBody, &baiduResponse)
  161. if err != nil {
  162. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  163. }
  164. if baiduResponse.ErrorMsg != "" {
  165. return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
  166. }
  167. fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
  168. jsonResponse, err := json.Marshal(fullTextResponse)
  169. if err != nil {
  170. return types.NewError(err, types.ErrorCodeBadResponseBody), nil
  171. }
  172. c.Writer.Header().Set("Content-Type", "application/json")
  173. c.Writer.WriteHeader(resp.StatusCode)
  174. _, err = c.Writer.Write(jsonResponse)
  175. return nil, &fullTextResponse.Usage
  176. }
  177. func getBaiduAccessToken(apiKey string) (string, error) {
  178. if val, ok := baiduTokenStore.Load(apiKey); ok {
  179. var accessToken BaiduAccessToken
  180. if accessToken, ok = val.(BaiduAccessToken); ok {
  181. // soon this will expire
  182. if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
  183. go func() {
  184. _, _ = getBaiduAccessTokenHelper(apiKey)
  185. }()
  186. }
  187. return accessToken.AccessToken, nil
  188. }
  189. }
  190. accessToken, err := getBaiduAccessTokenHelper(apiKey)
  191. if err != nil {
  192. return "", err
  193. }
  194. if accessToken == nil {
  195. return "", errors.New("getBaiduAccessToken return a nil token")
  196. }
  197. return (*accessToken).AccessToken, nil
  198. }
  199. func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
  200. parts := strings.Split(apiKey, "|")
  201. if len(parts) != 2 {
  202. return nil, errors.New("invalid baidu apikey")
  203. }
  204. req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
  205. parts[0], parts[1]), nil)
  206. if err != nil {
  207. return nil, err
  208. }
  209. req.Header.Add("Content-Type", "application/json")
  210. req.Header.Add("Accept", "application/json")
  211. res, err := service.GetHttpClient().Do(req)
  212. if err != nil {
  213. return nil, err
  214. }
  215. defer res.Body.Close()
  216. var accessToken BaiduAccessToken
  217. err = json.NewDecoder(res.Body).Decode(&accessToken)
  218. if err != nil {
  219. return nil, err
  220. }
  221. if accessToken.Error != "" {
  222. return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
  223. }
  224. if accessToken.AccessToken == "" {
  225. return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
  226. }
  227. accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
  228. baiduTokenStore.Store(apiKey, accessToken)
  229. return &accessToken, nil
  230. }