|
|
@@ -24,6 +24,7 @@ const (
|
|
|
RelayModeChatCompletions
|
|
|
RelayModeCompletions
|
|
|
RelayModeEmbeddings
|
|
|
+ RelayModeModeration
|
|
|
)
|
|
|
|
|
|
// https://platform.openai.com/docs/api-reference/chat
|
|
|
@@ -37,6 +38,7 @@ type GeneralOpenAIRequest struct {
|
|
|
Temperature float64 `json:"temperature"`
|
|
|
TopP float64 `json:"top_p"`
|
|
|
N int `json:"n"`
|
|
|
+ Input string `json:"input"`
|
|
|
}
|
|
|
|
|
|
type ChatRequest struct {
|
|
|
@@ -100,6 +102,8 @@ func Relay(c *gin.Context) {
|
|
|
relayMode = RelayModeCompletions
|
|
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
|
|
relayMode = RelayModeEmbeddings
|
|
|
+ } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
|
|
+ relayMode = RelayModeModeration
|
|
|
}
|
|
|
err := relayHelper(c, relayMode)
|
|
|
if err != nil {
|
|
|
@@ -143,6 +147,9 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
|
|
}
|
|
|
}
|
|
|
+ if relayMode == RelayModeModeration && textRequest.Model == "" {
|
|
|
+ textRequest.Model = "text-moderation-latest"
|
|
|
+ }
|
|
|
baseURL := common.ChannelBaseURLs[channelType]
|
|
|
requestURL := c.Request.URL.String()
|
|
|
if channelType == common.ChannelTypeCustom {
|
|
|
@@ -180,6 +187,8 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
|
|
|
case RelayModeCompletions:
|
|
|
promptTokens = countTokenText(textRequest.Prompt, textRequest.Model)
|
|
|
+ case RelayModeModeration:
|
|
|
+ promptTokens = countTokenText(textRequest.Input, textRequest.Model)
|
|
|
}
|
|
|
preConsumedTokens := common.PreConsumedQuota
|
|
|
if textRequest.MaxTokens != 0 {
|
|
|
@@ -239,6 +248,9 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|
|
quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
|
|
|
}
|
|
|
quota = int(float64(quota) * ratio)
|
|
|
+ if ratio != 0 && quota <= 0 {
|
|
|
+ quota = 1
|
|
|
+ }
|
|
|
quotaDelta := quota - preConsumedQuota
|
|
|
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
|
|
if err != nil {
|