|
@@ -27,7 +27,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
|
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
- if relayMode == RelayModeModeration && textRequest.Model == "" {
|
|
|
|
|
|
|
+ if relayMode == RelayModeModerations && textRequest.Model == "" {
|
|
|
textRequest.Model = "text-moderation-latest"
|
|
textRequest.Model = "text-moderation-latest"
|
|
|
}
|
|
}
|
|
|
// request validation
|
|
// request validation
|
|
@@ -37,16 +37,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
switch relayMode {
|
|
switch relayMode {
|
|
|
case RelayModeCompletions:
|
|
case RelayModeCompletions:
|
|
|
if textRequest.Prompt == "" {
|
|
if textRequest.Prompt == "" {
|
|
|
- return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
|
|
|
|
|
|
|
+ return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
|
|
|
}
|
|
}
|
|
|
case RelayModeChatCompletions:
|
|
case RelayModeChatCompletions:
|
|
|
- if len(textRequest.Messages) == 0 {
|
|
|
|
|
- return errorWrapper(errors.New("messages is required"), "required_field_missing", http.StatusBadRequest)
|
|
|
|
|
|
|
+ if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
|
|
|
|
+ return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
|
|
|
}
|
|
}
|
|
|
case RelayModeEmbeddings:
|
|
case RelayModeEmbeddings:
|
|
|
- case RelayModeModeration:
|
|
|
|
|
|
|
+ case RelayModeModerations:
|
|
|
if textRequest.Input == "" {
|
|
if textRequest.Input == "" {
|
|
|
- return errorWrapper(errors.New("input is required"), "required_field_missing", http.StatusBadRequest)
|
|
|
|
|
|
|
+ return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
|
|
|
|
|
+ }
|
|
|
|
|
+ case RelayModeEdits:
|
|
|
|
|
+ if textRequest.Instruction == "" {
|
|
|
|
|
+ return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
baseURL := common.ChannelBaseURLs[channelType]
|
|
baseURL := common.ChannelBaseURLs[channelType]
|
|
@@ -84,7 +88,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
|
|
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
|
|
|
case RelayModeCompletions:
|
|
case RelayModeCompletions:
|
|
|
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
|
|
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
|
|
|
- case RelayModeModeration:
|
|
|
|
|
|
|
+ case RelayModeModerations:
|
|
|
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
|
|
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
|
|
|
}
|
|
}
|
|
|
preConsumedTokens := common.PreConsumedQuota
|
|
preConsumedTokens := common.PreConsumedQuota
|
|
@@ -144,7 +148,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
defer func() {
|
|
defer func() {
|
|
|
if consumeQuota {
|
|
if consumeQuota {
|
|
|
quota := 0
|
|
quota := 0
|
|
|
- completionRatio := 1.333333 // default for gpt-3
|
|
|
|
|
|
|
+ completionRatio := 1.0
|
|
|
|
|
+ if strings.HasPrefix(textRequest.Model, "gpt-3.5") {
|
|
|
|
|
+ completionRatio = 1.333333
|
|
|
|
|
+ }
|
|
|
if strings.HasPrefix(textRequest.Model, "gpt-4") {
|
|
if strings.HasPrefix(textRequest.Model, "gpt-4") {
|
|
|
completionRatio = 2
|
|
completionRatio = 2
|
|
|
}
|
|
}
|
|
@@ -172,7 +179,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
}
|
|
}
|
|
|
if quota != 0 {
|
|
if quota != 0 {
|
|
|
tokenName := c.GetString("token_name")
|
|
tokenName := c.GetString("token_name")
|
|
|
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
|
|
|
|
|
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
|
|
|
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
|
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
|
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
|
channelId := c.GetInt("channel_id")
|
|
channelId := c.GetInt("channel_id")
|