Преглед изворни кода

fix: 修复流模式错误扣费的问题 (close #95)

[email protected] пре 1 годину
родитељ
комит
626217fbd4

+ 5 - 2
controller/channel-test.go

@@ -24,6 +24,9 @@ import (
 )
 
 func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
+	if channel.Type == common.ChannelTypeMidjourney {
+		return errors.New("midjourney channel test is not supported"), nil
+	}
 	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
 	w := httptest.NewRecorder()
 	c, _ := gin.CreateTestContext(w)
@@ -68,11 +71,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 	}
 	if resp.StatusCode != http.StatusOK {
 		err := relaycommon.RelayErrorHandler(resp)
-		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.OpenAIError.Message), &err.OpenAIError
+		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
 	}
 	usage, respErr := adaptor.DoResponse(c, resp, meta)
 	if respErr != nil {
-		return fmt.Errorf("%s", respErr.OpenAIError.Message), &respErr.OpenAIError
+		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
 	}
 	if usage == nil {
 		return errors.New("usage is nil"), nil

+ 8 - 8
controller/relay.go

@@ -38,24 +38,24 @@ func Relay(c *gin.Context) {
 			retryTimes = common.RetryTimes
 		}
 		if retryTimes > 0 {
-			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Message))
+			c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d&error=%s", c.Request.URL.Path, retryTimes-1, err.Error.Message))
 		} else {
 			if err.StatusCode == http.StatusTooManyRequests {
-				//err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
+				//err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
 			}
-			err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
+			err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
 			c.JSON(err.StatusCode, gin.H{
-				"error": err.OpenAIError,
+				"error": err.Error,
 			})
 		}
 		channelId := c.GetInt("channel_id")
 		autoBan := c.GetBool("auto_ban")
-		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
+		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Error.Message))
 		// https://platform.openai.com/docs/guides/error-codes/api-errors
-		if service.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
+		if service.ShouldDisableChannel(&err.Error, err.StatusCode) && autoBan {
 			channelId := c.GetInt("channel_id")
 			channelName := c.GetString("channel_name")
-			service.DisableChannel(channelId, channelName, err.Message)
+			service.DisableChannel(channelId, channelName, err.Error.Message)
 		}
 	}
 }
@@ -110,7 +110,7 @@ func RelayMidjourney(c *gin.Context) {
 		}
 		channelId := c.GetInt("channel_id")
 		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, fmt.Sprintf("%s %s", err.Description, err.Result)))
-		//if shouldDisableChannel(&err.OpenAIError) {
+		//if shouldDisableChannel(&err.Error) {
 		//	channelId := c.GetInt("channel_id")
 		//	channelName := c.GetString("channel_name")
 		//	disableChannel(channelId, channelName, err.Result)

+ 43 - 2
dto/error.go

@@ -8,6 +8,47 @@ type OpenAIError struct {
 }
 
 type OpenAIErrorWithStatusCode struct {
-	OpenAIError
-	StatusCode int `json:"status_code"`
+	// Error is now a named field (was embedded) so call sites read
+	// err.Error.Message and JSON nests the payload under "error".
+	Error      OpenAIError `json:"error"`
+	StatusCode int         `json:"status_code"`
+}
+
+// GeneralErrorResponse mirrors the assorted error-body shapes returned by
+// different upstream providers; typically only one of these fields is
+// populated for any given response.
+// NOTE(review): field set inferred from this diff alone — confirm against
+// each provider's actual error payloads.
+type GeneralErrorResponse struct {
+	Error    OpenAIError `json:"error"`
+	Message  string      `json:"message"`
+	Msg      string      `json:"msg"`
+	Err      string      `json:"err"`
+	ErrorMsg string      `json:"error_msg"`
+	Header   struct {
+		Message string `json:"message"`
+	} `json:"header"`
+	Response struct {
+		Error struct {
+			Message string `json:"message"`
+		} `json:"error"`
+	} `json:"response"`
+}
+
+// ToMessage returns the first non-empty error message among the known
+// upstream error fields, checked in priority order; "" when none is set.
+func (e GeneralErrorResponse) ToMessage() string {
+	candidates := []string{
+		e.Error.Message,
+		e.Message,
+		e.Msg,
+		e.Err,
+		e.ErrorMsg,
+		e.Header.Message,
+		e.Response.Error.Message,
+	}
+	for _, msg := range candidates {
+		if msg != "" {
+			return msg
+		}
+	}
+	return ""
 }

+ 2 - 2
relay/channel/ali/relay-ali.go

@@ -71,7 +71,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorW
 
 	if aliResponse.Code != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			OpenAIError: dto.OpenAIError{
+			Error: dto.OpenAIError{
 				Message: aliResponse.Message,
 				Type:    aliResponse.Code,
 				Param:   aliResponse.RequestId,
@@ -236,7 +236,7 @@ func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatus
 	}
 	if aliResponse.Code != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			OpenAIError: dto.OpenAIError{
+			Error: dto.OpenAIError{
 				Message: aliResponse.Message,
 				Type:    aliResponse.Code,
 				Param:   aliResponse.RequestId,

+ 2 - 2
relay/channel/baidu/relay-baidu.go

@@ -173,7 +173,7 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
 	}
 	if baiduResponse.ErrorMsg != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			OpenAIError: dto.OpenAIError{
+			Error: dto.OpenAIError{
 				Message: baiduResponse.ErrorMsg,
 				Type:    "baidu_error",
 				Param:   "",
@@ -209,7 +209,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErro
 	}
 	if baiduResponse.ErrorMsg != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			OpenAIError: dto.OpenAIError{
+			Error: dto.OpenAIError{
 				Message: baiduResponse.ErrorMsg,
 				Type:    "baidu_error",
 				Param:   "",

+ 22 - 2
relay/channel/claude/adaptor.go

@@ -10,17 +10,32 @@ import (
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
+	"strings"
+)
+
+// Request modes for the Claude adaptor: the legacy /v1/complete text API
+// versus the newer /v1/messages API.
+const (
+	RequestModeCompletion = 1
+	RequestModeMessage    = 2
 )
 
 type Adaptor struct {
+	// RequestMode is selected in Init from the upstream model name.
+	RequestMode int
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
-
+	// claude-3* models are served via the Messages API; older models fall
+	// back to the legacy text-completion mode.
+	if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
+		a.RequestMode = RequestModeMessage
+	} else {
+		a.RequestMode = RequestModeCompletion
+	}
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
+	// The two request modes live on different upstream endpoints.
+	if a.RequestMode == RequestModeMessage {
+		return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
+	} else {
+		return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
+	}
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
@@ -38,6 +53,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	//if a.RequestMode == RequestModeCompletion {
+	//	return requestOpenAI2ClaudeComplete(*request), nil
+	//} else {
+	//	return requestOpenAI2ClaudeMessage(*request), nil
+	//}
 	return request, nil
 }
 

+ 9 - 3
relay/channel/claude/relay-claude.go

@@ -24,7 +24,7 @@ func stopReasonClaude2OpenAI(reason string) string {
 	}
 }
 
-func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
+func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
 	claudeRequest := ClaudeRequest{
 		Model:             textRequest.Model,
 		Prompt:            "",
@@ -44,7 +44,9 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
 		} else if message.Role == "assistant" {
 			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
 		} else if message.Role == "system" {
-			prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
+			if prompt == "" {
+				prompt = message.StringContent()
+			}
 		}
 	}
 	prompt += "\n\nAssistant:"
@@ -52,6 +54,10 @@ func requestOpenAI2Claude(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
 	return &claudeRequest
 }
 
+//func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
+//
+//}
+
 func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *dto.ChatCompletionsStreamResponse {
 	var choice dto.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = claudeResponse.Completion
@@ -167,7 +173,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	}
 	if claudeResponse.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			OpenAIError: dto.OpenAIError{
+			Error: dto.OpenAIError{
 				Message: claudeResponse.Error.Message,
 				Type:    claudeResponse.Error.Type,
 				Param:   "",

+ 1 - 1
relay/channel/gemini/relay-gemini.go

@@ -246,7 +246,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 	}
 	if len(geminiResponse.Candidates) == 0 {
 		return &dto.OpenAIErrorWithStatusCode{
-			OpenAIError: dto.OpenAIError{
+			Error: dto.OpenAIError{
 				Message: "No candidates returned",
 				Type:    "server_error",
 				Param:   "",

+ 2 - 2
relay/channel/openai/relay-openai.go

@@ -127,8 +127,8 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 	}
 	if textResponse.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			OpenAIError: textResponse.Error,
-			StatusCode:  resp.StatusCode,
+			Error:      textResponse.Error,
+			StatusCode: resp.StatusCode,
 		}, nil
 	}
 	// Reset response body

+ 1 - 1
relay/channel/palm/relay-palm.go

@@ -146,7 +146,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
 	}
 	if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
 		return &dto.OpenAIErrorWithStatusCode{
-			OpenAIError: dto.OpenAIError{
+			Error: dto.OpenAIError{
 				Message: palmResponse.Error.Message,
 				Type:    palmResponse.Error.Status,
 				Param:   "",

+ 1 - 1
relay/channel/tencent/relay-tencent.go

@@ -175,7 +175,7 @@ func tencentHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithSt
 	}
 	if TencentResponse.Error.Code != 0 {
 		return &dto.OpenAIErrorWithStatusCode{
-			OpenAIError: dto.OpenAIError{
+			Error: dto.OpenAIError{
 				Message: TencentResponse.Error.Message,
 				Code:    TencentResponse.Error.Code,
 			},

+ 1 - 1
relay/channel/zhipu/relay-zhipu.go

@@ -244,7 +244,7 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
 	}
 	if !zhipuResponse.Success {
 		return &dto.OpenAIErrorWithStatusCode{
-			OpenAIError: dto.OpenAIError{
+			Error: dto.OpenAIError{
 				Message: zhipuResponse.Msg,
 				Type:    "zhipu_error",
 				Param:   "",

+ 2 - 2
relay/channel/zhipu_4v/relay-zhipu_v4.go

@@ -234,8 +234,8 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStat
 	}
 	if textResponse.Error.Type != "" {
 		return &dto.OpenAIErrorWithStatusCode{
-			OpenAIError: textResponse.Error,
-			StatusCode:  resp.StatusCode,
+			Error:      textResponse.Error,
+			StatusCode: resp.StatusCode,
 		}, nil
 	}
 	// Reset response body

+ 4 - 4
relay/common/relay_utils.go

@@ -17,10 +17,10 @@ import (
 
 var StopFinishReason = "stop"
 
-func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
-	openAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
+func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) {
+	OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{
 		StatusCode: resp.StatusCode,
-		OpenAIError: dto.OpenAIError{
+		Error: dto.OpenAIError{
 			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
 			Type:    "upstream_error",
 			Code:    "bad_response_status_code",
@@ -40,7 +40,7 @@ func RelayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *dto.Open
 	if err != nil {
 		return
 	}
-	openAIErrorWithStatusCode.OpenAIError = textResponse.Error
+	OpenAIErrorWithStatusCode.Error = textResponse.Error
 	return
 }
 

+ 22 - 0
relay/relay-text.go

@@ -2,6 +2,7 @@ package relay
 
 import (
 	"bytes"
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -148,10 +149,19 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 	}
 
 	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
 	relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 
+	if resp.StatusCode != http.StatusOK {
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
+		return service.RelayErrorHandler(resp)
+	}
+
 	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
 	if openaiErr != nil {
+		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
 		return openaiErr
 	}
 	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
@@ -218,6 +228,18 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
 	return preConsumedQuota, userQuota, nil
 }
 
+// returnPreConsumedQuota asynchronously refunds quota that was pre-charged
+// for a request that subsequently failed, so users are not billed for
+// errored relays (the point of this commit).
+func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsumedQuota int) {
+	if preConsumedQuota != 0 {
+		go func(ctx context.Context) {
+			// Negative delta credits the pre-consumed amount back.
+			err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
+			if err != nil {
+				common.SysError("error return pre-consumed quota: " + err.Error())
+			}
+		}(context.Background())
+		// NOTE: pass context.Background(), not c — gin recycles *gin.Context
+		// after the handler returns, so it must not leak into a goroutine.
+	}
+}
+
 func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) {
 	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 	promptTokens := usage.PromptTokens

+ 41 - 2
service/error.go

@@ -1,9 +1,13 @@
 package service
 
 import (
+	"encoding/json"
 	"fmt"
+	"io"
+	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	"strconv"
 	"strings"
 )
 
@@ -23,7 +27,42 @@ func OpenAIErrorWrapper(err error, code string, statusCode int) *dto.OpenAIError
 		Code:    code,
 	}
 	return &dto.OpenAIErrorWithStatusCode{
-		OpenAIError: openAIError,
-		StatusCode:  statusCode,
+		Error:      openAIError,
+		StatusCode: statusCode,
 	}
 }
+
+// RelayErrorHandler converts a non-2xx upstream response into an
+// OpenAIErrorWithStatusCode, extracting the most specific error message
+// available from the (provider-dependent) response body.
+func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorWithStatusCode) {
+	errWithStatusCode = &dto.OpenAIErrorWithStatusCode{
+		StatusCode: resp.StatusCode,
+		Error: dto.OpenAIError{
+			// Default message set up front so the early returns below never
+			// yield an error with an empty message (previous code only set
+			// the fallback on the happy path).
+			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
+			Type:    "upstream_error",
+			Code:    "bad_response_status_code",
+			Param:   strconv.Itoa(resp.StatusCode),
+		},
+	}
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return
+	}
+	if err = resp.Body.Close(); err != nil {
+		return
+	}
+	var errResponse dto.GeneralErrorResponse
+	if err = json.Unmarshal(responseBody, &errResponse); err != nil {
+		return
+	}
+	if errResponse.Error.Message != "" {
+		// OpenAI-format error: adopt the structured error wholesale.
+		errWithStatusCode.Error = errResponse.Error
+	} else if msg := errResponse.ToMessage(); msg != "" {
+		// Other provider formats: keep default type/code, override message.
+		errWithStatusCode.Error.Message = msg
+	}
+	return
+}