relay-cohere.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. package cohere
  2. import (
  3. "bufio"
  4. "encoding/json"
  5. "io"
  6. "net/http"
  7. "one-api/common"
  8. "one-api/dto"
  9. relaycommon "one-api/relay/common"
  10. "one-api/relay/helper"
  11. "one-api/service"
  12. "one-api/types"
  13. "strings"
  14. "time"
  15. "github.com/gin-gonic/gin"
  16. )
  17. func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
  18. cohereReq := CohereRequest{
  19. Model: textRequest.Model,
  20. ChatHistory: []ChatHistory{},
  21. Message: "",
  22. Stream: textRequest.Stream,
  23. MaxTokens: textRequest.GetMaxTokens(),
  24. }
  25. if common.CohereSafetySetting != "NONE" {
  26. cohereReq.SafetyMode = common.CohereSafetySetting
  27. }
  28. if cohereReq.MaxTokens == 0 {
  29. cohereReq.MaxTokens = 4000
  30. }
  31. for _, msg := range textRequest.Messages {
  32. if msg.Role == "user" {
  33. cohereReq.Message = msg.StringContent()
  34. } else {
  35. var role string
  36. if msg.Role == "assistant" {
  37. role = "CHATBOT"
  38. } else if msg.Role == "system" {
  39. role = "SYSTEM"
  40. } else {
  41. role = "USER"
  42. }
  43. cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
  44. Role: role,
  45. Message: msg.StringContent(),
  46. })
  47. }
  48. }
  49. return &cohereReq
  50. }
  51. func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
  52. if rerankRequest.TopN == 0 {
  53. rerankRequest.TopN = 1
  54. }
  55. cohereReq := CohereRerankRequest{
  56. Query: rerankRequest.Query,
  57. Documents: rerankRequest.Documents,
  58. Model: rerankRequest.Model,
  59. TopN: rerankRequest.TopN,
  60. ReturnDocuments: true,
  61. }
  62. return &cohereReq
  63. }
  64. func stopReasonCohere2OpenAI(reason string) string {
  65. switch reason {
  66. case "COMPLETE":
  67. return "stop"
  68. case "MAX_TOKENS":
  69. return "max_tokens"
  70. default:
  71. return reason
  72. }
  73. }
  74. func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  75. responseId := helper.GetResponseID(c)
  76. createdTime := common.GetTimestamp()
  77. usage := &dto.Usage{}
  78. responseText := ""
  79. scanner := bufio.NewScanner(resp.Body)
  80. scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
  81. if atEOF && len(data) == 0 {
  82. return 0, nil, nil
  83. }
  84. if i := strings.Index(string(data), "\n"); i >= 0 {
  85. return i + 1, data[0:i], nil
  86. }
  87. if atEOF {
  88. return len(data), data, nil
  89. }
  90. return 0, nil, nil
  91. })
  92. dataChan := make(chan string)
  93. stopChan := make(chan bool)
  94. go func() {
  95. for scanner.Scan() {
  96. data := scanner.Text()
  97. dataChan <- data
  98. }
  99. stopChan <- true
  100. }()
  101. helper.SetEventStreamHeaders(c)
  102. isFirst := true
  103. c.Stream(func(w io.Writer) bool {
  104. select {
  105. case data := <-dataChan:
  106. if isFirst {
  107. isFirst = false
  108. info.FirstResponseTime = time.Now()
  109. }
  110. data = strings.TrimSuffix(data, "\r")
  111. var cohereResp CohereResponse
  112. err := json.Unmarshal([]byte(data), &cohereResp)
  113. if err != nil {
  114. common.SysError("error unmarshalling stream response: " + err.Error())
  115. return true
  116. }
  117. var openaiResp dto.ChatCompletionsStreamResponse
  118. openaiResp.Id = responseId
  119. openaiResp.Created = createdTime
  120. openaiResp.Object = "chat.completion.chunk"
  121. openaiResp.Model = info.UpstreamModelName
  122. if cohereResp.IsFinished {
  123. finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
  124. openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
  125. {
  126. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
  127. Index: 0,
  128. FinishReason: &finishReason,
  129. },
  130. }
  131. if cohereResp.Response != nil {
  132. usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
  133. usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
  134. }
  135. } else {
  136. openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
  137. {
  138. Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
  139. Role: "assistant",
  140. Content: &cohereResp.Text,
  141. },
  142. Index: 0,
  143. },
  144. }
  145. responseText += cohereResp.Text
  146. }
  147. jsonStr, err := json.Marshal(openaiResp)
  148. if err != nil {
  149. common.SysError("error marshalling stream response: " + err.Error())
  150. return true
  151. }
  152. c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
  153. return true
  154. case <-stopChan:
  155. c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
  156. return false
  157. }
  158. })
  159. if usage.PromptTokens == 0 {
  160. usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
  161. }
  162. return usage, nil
  163. }
  164. func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
  165. createdTime := common.GetTimestamp()
  166. responseBody, err := io.ReadAll(resp.Body)
  167. if err != nil {
  168. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  169. }
  170. common.CloseResponseBodyGracefully(resp)
  171. var cohereResp CohereResponseResult
  172. err = json.Unmarshal(responseBody, &cohereResp)
  173. if err != nil {
  174. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  175. }
  176. usage := dto.Usage{}
  177. usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
  178. usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
  179. usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
  180. var openaiResp dto.TextResponse
  181. openaiResp.Id = cohereResp.ResponseId
  182. openaiResp.Created = createdTime
  183. openaiResp.Object = "chat.completion"
  184. openaiResp.Model = info.UpstreamModelName
  185. openaiResp.Usage = usage
  186. openaiResp.Choices = []dto.OpenAITextResponseChoice{
  187. {
  188. Index: 0,
  189. Message: dto.Message{Content: cohereResp.Text, Role: "assistant"},
  190. FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
  191. },
  192. }
  193. jsonResponse, err := json.Marshal(openaiResp)
  194. if err != nil {
  195. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  196. }
  197. c.Writer.Header().Set("Content-Type", "application/json")
  198. c.Writer.WriteHeader(resp.StatusCode)
  199. _, _ = c.Writer.Write(jsonResponse)
  200. return &usage, nil
  201. }
  202. func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
  203. responseBody, err := io.ReadAll(resp.Body)
  204. if err != nil {
  205. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  206. }
  207. common.CloseResponseBodyGracefully(resp)
  208. var cohereResp CohereRerankResponseResult
  209. err = json.Unmarshal(responseBody, &cohereResp)
  210. if err != nil {
  211. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  212. }
  213. usage := dto.Usage{}
  214. if cohereResp.Meta.BilledUnits.InputTokens == 0 {
  215. usage.PromptTokens = info.PromptTokens
  216. usage.CompletionTokens = 0
  217. usage.TotalTokens = info.PromptTokens
  218. } else {
  219. usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
  220. usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
  221. usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
  222. }
  223. var rerankResp dto.RerankResponse
  224. rerankResp.Results = cohereResp.Results
  225. rerankResp.Usage = usage
  226. jsonResponse, err := json.Marshal(rerankResp)
  227. if err != nil {
  228. return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
  229. }
  230. c.Writer.Header().Set("Content-Type", "application/json")
  231. c.Writer.WriteHeader(resp.StatusCode)
  232. _, err = c.Writer.Write(jsonResponse)
  233. return &usage, nil
  234. }