Просмотр исходного кода

refactor: refactor claude related code

JustSong 2 лет назад
Родитель
Сommit
e628b643cd
2 измененных файлов с 123 добавлено и 99 удалено
  1. 117 0
      controller/relay-claude.go
  2. 6 99
      controller/relay-text.go

+ 117 - 0
controller/relay-claude.go

@@ -1,7 +1,12 @@
 package controller
 
 import (
+	"bufio"
+	"encoding/json"
 	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
 	"one-api/common"
 	"strings"
 )
@@ -102,3 +107,115 @@ func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
 	}
 	return &fullTextResponse
 }
+
+func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+	responseText := ""
+	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	createdTime := common.GetTimestamp()
+	scanner := bufio.NewScanner(resp.Body)
+	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+		if atEOF && len(data) == 0 {
+			return 0, nil, nil
+		}
+		if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
+			return i + 4, data[0:i], nil
+		}
+		if atEOF {
+			return len(data), data, nil
+		}
+		return 0, nil, nil
+	})
+	dataChan := make(chan string)
+	stopChan := make(chan bool)
+	go func() {
+		for scanner.Scan() {
+			data := scanner.Text()
+			if !strings.HasPrefix(data, "event: completion") {
+				continue
+			}
+			data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
+			dataChan <- data
+		}
+		stopChan <- true
+	}()
+	c.Writer.Header().Set("Content-Type", "text/event-stream")
+	c.Writer.Header().Set("Cache-Control", "no-cache")
+	c.Writer.Header().Set("Connection", "keep-alive")
+	c.Writer.Header().Set("Transfer-Encoding", "chunked")
+	c.Writer.Header().Set("X-Accel-Buffering", "no")
+	c.Stream(func(w io.Writer) bool {
+		select {
+		case data := <-dataChan:
+			// some implementations may add \r at the end of data
+			data = strings.TrimSuffix(data, "\r")
+			var claudeResponse ClaudeResponse
+			err := json.Unmarshal([]byte(data), &claudeResponse)
+			if err != nil {
+				common.SysError("error unmarshalling stream response: " + err.Error())
+				return true
+			}
+			responseText += claudeResponse.Completion
+			response := streamResponseClaude2OpenAI(&claudeResponse)
+			response.Id = responseId
+			response.Created = createdTime
+			jsonStr, err := json.Marshal(response)
+			if err != nil {
+				common.SysError("error marshalling stream response: " + err.Error())
+				return true
+			}
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
+			return true
+		case <-stopChan:
+			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+			return false
+		}
+	})
+	err := resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+	}
+	return nil, responseText
+}
+
+func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+	var claudeResponse ClaudeResponse
+	err = json.Unmarshal(responseBody, &claudeResponse)
+	if err != nil {
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if claudeResponse.Error.Type != "" {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: claudeResponse.Error.Message,
+				Type:    claudeResponse.Error.Type,
+				Param:   "",
+				Code:    claudeResponse.Error.Type,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseClaude2OpenAI(&claudeResponse)
+	completionTokens := countTokenText(claudeResponse.Completion, model)
+	usage := Usage{
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		TotalTokens:      promptTokens + completionTokens,
+	}
+	fullTextResponse.Usage = usage
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &usage
+}

+ 6 - 99
controller/relay-text.go

@@ -373,111 +373,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		}
 	case APITypeClaude:
 		if isStream {
-			responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
-			createdTime := common.GetTimestamp()
-			scanner := bufio.NewScanner(resp.Body)
-			scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-				if atEOF && len(data) == 0 {
-					return 0, nil, nil
-				}
-				if i := strings.Index(string(data), "\r\n\r\n"); i >= 0 {
-					return i + 4, data[0:i], nil
-				}
-				if atEOF {
-					return len(data), data, nil
-				}
-				return 0, nil, nil
-			})
-			dataChan := make(chan string)
-			stopChan := make(chan bool)
-			go func() {
-				for scanner.Scan() {
-					data := scanner.Text()
-					if !strings.HasPrefix(data, "event: completion") {
-						continue
-					}
-					data = strings.TrimPrefix(data, "event: completion\r\ndata: ")
-					dataChan <- data
-				}
-				stopChan <- true
-			}()
-			c.Writer.Header().Set("Content-Type", "text/event-stream")
-			c.Writer.Header().Set("Cache-Control", "no-cache")
-			c.Writer.Header().Set("Connection", "keep-alive")
-			c.Writer.Header().Set("Transfer-Encoding", "chunked")
-			c.Writer.Header().Set("X-Accel-Buffering", "no")
-			c.Stream(func(w io.Writer) bool {
-				select {
-				case data := <-dataChan:
-					// some implementations may add \r at the end of data
-					data = strings.TrimSuffix(data, "\r")
-					var claudeResponse ClaudeResponse
-					err = json.Unmarshal([]byte(data), &claudeResponse)
-					if err != nil {
-						common.SysError("error unmarshalling stream response: " + err.Error())
-						return true
-					}
-					streamResponseText += claudeResponse.Completion
-					response := streamResponseClaude2OpenAI(&claudeResponse)
-					response.Id = responseId
-					response.Created = createdTime
-					jsonStr, err := json.Marshal(response)
-					if err != nil {
-						common.SysError("error marshalling stream response: " + err.Error())
-						return true
-					}
-					c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
-					return true
-				case <-stopChan:
-					c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-					return false
-				}
-			})
-			err = resp.Body.Close()
+			err, responseText := claudeStreamHandler(c, resp)
 			if err != nil {
-				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+				return err
 			}
+			streamResponseText = responseText
 			return nil
 		} else {
-			responseBody, err := io.ReadAll(resp.Body)
-			if err != nil {
-				return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
-			}
-			err = resp.Body.Close()
-			if err != nil {
-				return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
-			}
-			var claudeResponse ClaudeResponse
-			err = json.Unmarshal(responseBody, &claudeResponse)
+			err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
 			if err != nil {
-				return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+				return err
 			}
-			if claudeResponse.Error.Type != "" {
-				return &OpenAIErrorWithStatusCode{
-					OpenAIError: OpenAIError{
-						Message: claudeResponse.Error.Message,
-						Type:    claudeResponse.Error.Type,
-						Param:   "",
-						Code:    claudeResponse.Error.Type,
-					},
-					StatusCode: resp.StatusCode,
-				}
-			}
-			fullTextResponse := responseClaude2OpenAI(&claudeResponse)
-			completionTokens := countTokenText(claudeResponse.Completion, textRequest.Model)
-			fullTextResponse.Usage = Usage{
-				PromptTokens:     promptTokens,
-				CompletionTokens: completionTokens,
-				TotalTokens:      promptTokens + completionTokens,
-			}
-			textResponse.Usage = fullTextResponse.Usage
-			jsonResponse, err := json.Marshal(fullTextResponse)
-			if err != nil {
-				return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
-			}
-			c.Writer.Header().Set("Content-Type", "application/json")
-			c.Writer.WriteHeader(resp.StatusCode)
-			_, err = c.Writer.Write(jsonResponse)
+			textResponse.Usage = *usage
 			return nil
 		}
 	default: