Parcourir la source

feat: support /v1/moderations now (close #117)

JustSong il y a 2 ans
Parent
commit
4339f45f74
5 fichiers modifiés avec 39 ajouts et 3 suppressions
  1. 2 2
      common/model-ratio.go
  2. 18 0
      controller/model.go
  3. 12 0
      controller/relay.go
  4. 6 0
      middleware/distributor.go
  5. 1 1
      router/relay-router.go

+ 2 - 2
common/model-ratio.go

@@ -26,8 +26,8 @@ var ModelRatio = map[string]float64{
 	"ada":                     10,
 	"text-embedding-ada-002":  0.2,
 	"text-search-ada-doc-001": 10,
-	"text-moderation-stable":  10,
-	"text-moderation-latest":  10,
+	"text-moderation-stable":  0.1,
+	"text-moderation-latest":  0.1,
 }
 
 func ModelRatio2JSONString() string {

+ 18 - 0
controller/model.go

@@ -161,6 +161,24 @@ func init() {
 			Root:       "text-ada-001",
 			Parent:     nil,
 		},
+		{
+			Id:         "text-moderation-latest",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "text-moderation-latest",
+			Parent:     nil,
+		},
+		{
+			Id:         "text-moderation-stable",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "text-moderation-stable",
+			Parent:     nil,
+		},
 	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {

+ 12 - 0
controller/relay.go

@@ -24,6 +24,7 @@ const (
 	RelayModeChatCompletions
 	RelayModeCompletions
 	RelayModeEmbeddings
+	RelayModeModeration
 )
 
 // https://platform.openai.com/docs/api-reference/chat
@@ -37,6 +38,7 @@ type GeneralOpenAIRequest struct {
 	Temperature float64   `json:"temperature"`
 	TopP        float64   `json:"top_p"`
 	N           int       `json:"n"`
+	Input       string    `json:"input"`
 }
 
 type ChatRequest struct {
@@ -100,6 +102,8 @@ func Relay(c *gin.Context) {
 		relayMode = RelayModeCompletions
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
 		relayMode = RelayModeEmbeddings
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
+		relayMode = RelayModeModeration
 	}
 	err := relayHelper(c, relayMode)
 	if err != nil {
@@ -143,6 +147,9 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 		}
 	}
+	if relayMode == RelayModeModeration && textRequest.Model == "" {
+		textRequest.Model = "text-moderation-latest"
+	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
 	if channelType == common.ChannelTypeCustom {
@@ -180,6 +187,8 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
 	case RelayModeCompletions:
 		promptTokens = countTokenText(textRequest.Prompt, textRequest.Model)
+	case RelayModeModeration:
+		promptTokens = countTokenText(textRequest.Input, textRequest.Model)
 	}
 	preConsumedTokens := common.PreConsumedQuota
 	if textRequest.MaxTokens != 0 {
@@ -239,6 +248,9 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 				quota = textResponse.Usage.PromptTokens + textResponse.Usage.CompletionTokens*completionRatio
 			}
 			quota = int(float64(quota) * ratio)
+			if ratio != 0 && quota <= 0 {
+				quota = 1
+			}
 			quotaDelta := quota - preConsumedQuota
 			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 			if err != nil {

+ 6 - 0
middleware/distributor.go

@@ -7,6 +7,7 @@ import (
 	"one-api/common"
 	"one-api/model"
 	"strconv"
+	"strings"
 )
 
 type ModelRequest struct {
@@ -64,6 +65,11 @@ func Distribute() func(c *gin.Context) {
 				c.Abort()
 				return
 			}
+			if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
+				if modelRequest.Model == "" {
+					modelRequest.Model = "text-moderation-stable"
+				}
+			}
 			userId := c.GetInt("id")
 			userGroup, _ := model.GetUserGroup(userId)
 			channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)

+ 1 - 1
router/relay-router.go

@@ -37,6 +37,6 @@ func SetRelayRouter(router *gin.Engine) {
 		relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
 		relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
 		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
-		relayV1Router.POST("/moderations", controller.RelayNotImplemented)
+		relayV1Router.POST("/moderations", controller.Relay)
 	}
 }