|
|
@@ -2,6 +2,7 @@ package controller
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
+ "errors"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
"log"
|
|
|
@@ -104,7 +105,12 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|
|
|
|
|
request, err := helper.GetAndValidateRequest(c, relayFormat)
|
|
|
if err != nil {
|
|
|
- newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
|
|
+ // Map "request body too large" to 413 so clients can handle it correctly
|
|
|
+ if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
|
|
|
+ newAPIError = types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry())
|
|
|
+ } else {
|
|
|
+ newAPIError = types.NewError(err, types.ErrorCodeInvalidRequest)
|
|
|
+ }
|
|
|
return
|
|
|
}
|
|
|
|
|
|
@@ -114,9 +120,17 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- meta := request.GetTokenCountMeta()
|
|
|
+ needSensitiveCheck := setting.ShouldCheckPromptSensitive()
|
|
|
+ needCountToken := constant.CountToken
|
|
|
+ // Avoid building huge CombineText (strings.Join) when token counting and sensitive check are both disabled.
|
|
|
+ var meta *types.TokenCountMeta
|
|
|
+ if needSensitiveCheck || needCountToken {
|
|
|
+ meta = request.GetTokenCountMeta()
|
|
|
+ } else {
|
|
|
+ meta = fastTokenCountMetaForPricing(request)
|
|
|
+ }
|
|
|
|
|
|
- if setting.ShouldCheckPromptSensitive() {
|
|
|
+ if needSensitiveCheck && meta != nil {
|
|
|
contains, words := service.CheckSensitiveText(meta.CombineText)
|
|
|
if contains {
|
|
|
logger.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ", ")))
|
|
|
@@ -165,15 +179,24 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
|
|
|
}
|
|
|
|
|
|
for ; retryParam.GetRetry() <= common.RetryTimes; retryParam.IncreaseRetry() {
|
|
|
- channel, err := getChannel(c, relayInfo, retryParam)
|
|
|
- if err != nil {
|
|
|
- logger.LogError(c, err.Error())
|
|
|
- newAPIError = err
|
|
|
+ channel, channelErr := getChannel(c, relayInfo, retryParam)
|
|
|
+ if channelErr != nil {
|
|
|
+ logger.LogError(c, channelErr.Error())
|
|
|
+ newAPIError = channelErr
|
|
|
break
|
|
|
}
|
|
|
|
|
|
addUsedChannel(c, channel.Id)
|
|
|
- requestBody, _ := common.GetRequestBody(c)
|
|
|
+ requestBody, bodyErr := common.GetRequestBody(c)
|
|
|
+ if bodyErr != nil {
|
|
|
+ // Ensure consistent 413 for oversized bodies even when error occurs later (e.g., retry path)
|
|
|
+ if common.IsRequestBodyTooLargeError(bodyErr) || errors.Is(bodyErr, common.ErrRequestBodyTooLarge) {
|
|
|
+ newAPIError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusRequestEntityTooLarge, types.ErrOptionWithSkipRetry())
|
|
|
+ } else {
|
|
|
+ newAPIError = types.NewErrorWithStatusCode(bodyErr, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
|
|
|
+ }
|
|
|
+ break
|
|
|
+ }
|
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
|
|
|
|
switch relayFormat {
|
|
|
@@ -218,6 +241,33 @@ func addUsedChannel(c *gin.Context, channelId int) {
|
|
|
c.Set("use_channel", useChannel)
|
|
|
}
|
|
|
|
|
|
+func fastTokenCountMetaForPricing(request dto.Request) *types.TokenCountMeta {
|
|
|
+ if request == nil {
|
|
|
+ return &types.TokenCountMeta{}
|
|
|
+ }
|
|
|
+ meta := &types.TokenCountMeta{
|
|
|
+ TokenType: types.TokenTypeTokenizer,
|
|
|
+ }
|
|
|
+ switch r := request.(type) {
|
|
|
+ case *dto.GeneralOpenAIRequest:
|
|
|
+ if r.MaxCompletionTokens > r.MaxTokens {
|
|
|
+ meta.MaxTokens = int(r.MaxCompletionTokens)
|
|
|
+ } else {
|
|
|
+ meta.MaxTokens = int(r.MaxTokens)
|
|
|
+ }
|
|
|
+ case *dto.OpenAIResponsesRequest:
|
|
|
+ meta.MaxTokens = int(r.MaxOutputTokens)
|
|
|
+ case *dto.ClaudeRequest:
|
|
|
+ meta.MaxTokens = int(r.MaxTokens)
|
|
|
+ case *dto.ImageRequest:
|
|
|
+ // Pricing for image requests depends on ImagePriceRatio; safe to compute even when CountToken is disabled.
|
|
|
+ return r.GetTokenCountMeta()
|
|
|
+ default:
|
|
|
+ // Best-effort: leave CombineText empty to avoid large allocations.
|
|
|
+ }
|
|
|
+ return meta
|
|
|
+}
|
|
|
+
|
|
|
func getChannel(c *gin.Context, info *relaycommon.RelayInfo, retryParam *service.RetryParam) (*model.Channel, *types.NewAPIError) {
|
|
|
if info.ChannelMeta == nil {
|
|
|
autoBan := c.GetBool("auto_ban")
|
|
|
@@ -432,7 +482,15 @@ func RelayTask(c *gin.Context) {
|
|
|
logger.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, retryParam.GetRetry()))
|
|
|
//middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
|
|
|
|
|
- requestBody, _ := common.GetRequestBody(c)
|
|
|
+ requestBody, err := common.GetRequestBody(c)
|
|
|
+ if err != nil {
|
|
|
+ if common.IsRequestBodyTooLargeError(err) || errors.Is(err, common.ErrRequestBodyTooLarge) {
|
|
|
+ taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusRequestEntityTooLarge)
|
|
|
+ } else {
|
|
|
+ taskErr = service.TaskErrorWrapperLocal(err, "read_request_body_failed", http.StatusBadRequest)
|
|
|
+ }
|
|
|
+ break
|
|
|
+ }
|
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
|
taskErr = taskRelayHandler(c, relayInfo)
|
|
|
}
|