|
|
@@ -0,0 +1,189 @@
|
|
|
+package cohere
|
|
|
+
|
|
|
+import (
|
|
|
+ "bufio"
|
|
|
+ "encoding/json"
|
|
|
+ "fmt"
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
+ "io"
|
|
|
+ "net/http"
|
|
|
+ "one-api/common"
|
|
|
+ "one-api/dto"
|
|
|
+ "one-api/service"
|
|
|
+ "strings"
|
|
|
+)
|
|
|
+
|
|
|
+func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
|
|
+ cohereReq := CohereRequest{
|
|
|
+ Model: textRequest.Model,
|
|
|
+ ChatHistory: []ChatHistory{},
|
|
|
+ Message: "",
|
|
|
+ Stream: textRequest.Stream,
|
|
|
+ MaxTokens: textRequest.GetMaxTokens(),
|
|
|
+ }
|
|
|
+ if cohereReq.MaxTokens == 0 {
|
|
|
+ cohereReq.MaxTokens = 4000
|
|
|
+ }
|
|
|
+ for _, msg := range textRequest.Messages {
|
|
|
+ if msg.Role == "user" {
|
|
|
+ cohereReq.Message = msg.StringContent()
|
|
|
+ } else {
|
|
|
+ var role string
|
|
|
+ if msg.Role == "assistant" {
|
|
|
+ role = "CHATBOT"
|
|
|
+ } else if msg.Role == "system" {
|
|
|
+ role = "SYSTEM"
|
|
|
+ } else {
|
|
|
+ role = "USER"
|
|
|
+ }
|
|
|
+ cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
|
|
|
+ Role: role,
|
|
|
+ Message: msg.StringContent(),
|
|
|
+ })
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return &cohereReq
|
|
|
+}
|
|
|
+
|
|
|
+func stopReasonCohere2OpenAI(reason string) string {
|
|
|
+ switch reason {
|
|
|
+ case "COMPLETE":
|
|
|
+ return "stop"
|
|
|
+ case "MAX_TOKENS":
|
|
|
+ return "max_tokens"
|
|
|
+ default:
|
|
|
+ return reason
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
+ responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
|
|
+ createdTime := common.GetTimestamp()
|
|
|
+ usage := &dto.Usage{}
|
|
|
+ responseText := ""
|
|
|
+ scanner := bufio.NewScanner(resp.Body)
|
|
|
+ scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
|
+ if atEOF && len(data) == 0 {
|
|
|
+ return 0, nil, nil
|
|
|
+ }
|
|
|
+ if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
|
+ return i + 1, data[0:i], nil
|
|
|
+ }
|
|
|
+ if atEOF {
|
|
|
+ return len(data), data, nil
|
|
|
+ }
|
|
|
+ return 0, nil, nil
|
|
|
+ })
|
|
|
+ dataChan := make(chan string)
|
|
|
+ stopChan := make(chan bool)
|
|
|
+ go func() {
|
|
|
+ for scanner.Scan() {
|
|
|
+ data := scanner.Text()
|
|
|
+ dataChan <- data
|
|
|
+ }
|
|
|
+ stopChan <- true
|
|
|
+ }()
|
|
|
+ service.SetEventStreamHeaders(c)
|
|
|
+ c.Stream(func(w io.Writer) bool {
|
|
|
+ select {
|
|
|
+ case data := <-dataChan:
|
|
|
+ data = strings.TrimSuffix(data, "\r")
|
|
|
+ var cohereResp CohereResponse
|
|
|
+ err := json.Unmarshal([]byte(data), &cohereResp)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error unmarshalling stream response: " + err.Error())
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ var openaiResp dto.ChatCompletionsStreamResponse
|
|
|
+ openaiResp.Id = responseId
|
|
|
+ openaiResp.Created = createdTime
|
|
|
+ openaiResp.Object = "chat.completion.chunk"
|
|
|
+ openaiResp.Model = modelName
|
|
|
+ if cohereResp.IsFinished {
|
|
|
+ finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
|
|
|
+ openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
|
|
+ {
|
|
|
+ Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
|
|
|
+ Index: 0,
|
|
|
+ FinishReason: &finishReason,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ if cohereResp.Response != nil {
|
|
|
+ usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
|
|
|
+ usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
|
|
+ {
|
|
|
+ Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
|
|
+ Role: "assistant",
|
|
|
+ Content: cohereResp.Text,
|
|
|
+ },
|
|
|
+ Index: 0,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ responseText += cohereResp.Text
|
|
|
+ }
|
|
|
+ jsonStr, err := json.Marshal(openaiResp)
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("error marshalling stream response: " + err.Error())
|
|
|
+ return true
|
|
|
+ }
|
|
|
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
|
|
+ return true
|
|
|
+ case <-stopChan:
|
|
|
+ c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ })
|
|
|
+ if usage.PromptTokens == 0 {
|
|
|
+ usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
|
|
|
+ }
|
|
|
+ return nil, usage
|
|
|
+}
|
|
|
+
|
|
|
+func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
+ createdTime := common.GetTimestamp()
|
|
|
+ responseBody, err := io.ReadAll(resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+ err = resp.Body.Close()
|
|
|
+ if err != nil {
|
|
|
+ return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+ var cohereResp CohereResponseResult
|
|
|
+ err = json.Unmarshal(responseBody, &cohereResp)
|
|
|
+ if err != nil {
|
|
|
+ return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+ usage := dto.Usage{}
|
|
|
+ usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
|
|
+ usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
|
|
|
+ usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
|
|
|
+
|
|
|
+ var openaiResp dto.TextResponse
|
|
|
+ openaiResp.Id = cohereResp.ResponseId
|
|
|
+ openaiResp.Created = createdTime
|
|
|
+ openaiResp.Object = "chat.completion"
|
|
|
+ openaiResp.Model = modelName
|
|
|
+ openaiResp.Usage = usage
|
|
|
+
|
|
|
+ content, _ := json.Marshal(cohereResp.Text)
|
|
|
+ openaiResp.Choices = []dto.OpenAITextResponseChoice{
|
|
|
+ {
|
|
|
+ Index: 0,
|
|
|
+ Message: dto.Message{Content: content, Role: "assistant"},
|
|
|
+ FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ jsonResponse, err := json.Marshal(openaiResp)
|
|
|
+ if err != nil {
|
|
|
+ return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ }
|
|
|
+ c.Writer.Header().Set("Content-Type", "application/json")
|
|
|
+ c.Writer.WriteHeader(resp.StatusCode)
|
|
|
+ _, err = c.Writer.Write(jsonResponse)
|
|
|
+ return nil, &usage
|
|
|
+}
|