
feat: support /v1/edits now (close #196)

JustSong committed 2 years ago
commit 9b178a28a3
4 changed files with 41 additions and 12 deletions

  1. controller/model.go (+18 -0)
  2. controller/relay-text.go (+16 -9)
  3. controller/relay.go (+6 -2)
  4. router/relay-router.go (+1 -1)

controller/model.go (+18 -0)

@@ -224,6 +224,24 @@ func init() {
 			Root:       "text-moderation-stable",
 			Parent:     nil,
 		},
+		{
+			Id:         "text-davinci-edit-001",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "text-davinci-edit-001",
+			Parent:     nil,
+		},
+		{
+			Id:         "code-davinci-edit-001",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "openai",
+			Permission: permission,
+			Root:       "code-davinci-edit-001",
+			Parent:     nil,
+		},
 	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {
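
These two entries register OpenAI's edit models alongside the existing ones, so they now appear in the model listing. A hedged sketch of the JSON they should contribute, using a trimmed stand-in for the repo's OpenAIModels struct (the JSON tags follow OpenAI's models schema and are an assumption here; the real struct also carries Permission and Parent, and the map-population loop is truncated in the hunk above):

package main

import (
	"encoding/json"
	"fmt"
)

// openAIModel is a simplified stand-in for the repo's OpenAIModels struct,
// keeping only the fields visible in the diff above.
type openAIModel struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int    `json:"created"`
	OwnedBy string `json:"owned_by"`
	Root    string `json:"root"`
}

func main() {
	// The two models this commit appends to openAIModels.
	edits := []openAIModel{
		{Id: "text-davinci-edit-001", Object: "model", Created: 1677649963, OwnedBy: "openai", Root: "text-davinci-edit-001"},
		{Id: "code-davinci-edit-001", Object: "model", Created: 1677649963, OwnedBy: "openai", Root: "code-davinci-edit-001"},
	}
	out, _ := json.MarshalIndent(edits, "", "  ")
	fmt.Println(string(out)) // roughly what GET /v1/models gains for these two ids
}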

controller/relay-text.go (+16 -9)

@@ -27,7 +27,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 		}
 	}
-	if relayMode == RelayModeModeration && textRequest.Model == "" {
+	if relayMode == RelayModeModerations && textRequest.Model == "" {
 		textRequest.Model = "text-moderation-latest"
 	}
 	// request validation
@@ -37,16 +37,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	switch relayMode {
 	case RelayModeCompletions:
 		if textRequest.Prompt == "" {
-			return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
+			return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
 		}
 	case RelayModeChatCompletions:
-		if len(textRequest.Messages) == 0 {
-			return errorWrapper(errors.New("messages is required"), "required_field_missing", http.StatusBadRequest)
+		if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
+			return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
 		}
 	case RelayModeEmbeddings:
-	case RelayModeModeration:
+	case RelayModeModerations:
 		if textRequest.Input == "" {
-			return errorWrapper(errors.New("input is required"), "required_field_missing", http.StatusBadRequest)
+			return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
+		}
+	case RelayModeEdits:
+		if textRequest.Instruction == "" {
+			return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
 		}
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
@@ -84,7 +88,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
 	case RelayModeCompletions:
 		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
-	case RelayModeModeration:
+	case RelayModeModerations:
 		promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
 	}
 	preConsumedTokens := common.PreConsumedQuota
@@ -144,7 +148,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	defer func() {
 		if consumeQuota {
 			quota := 0
-			completionRatio := 1.333333 // default for gpt-3
+			completionRatio := 1.0
+			if strings.HasPrefix(textRequest.Model, "gpt-3.5") {
+				completionRatio = 1.333333
+			}
 			if strings.HasPrefix(textRequest.Model, "gpt-4") {
 				completionRatio = 2
 			}
@@ -172,7 +179,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
-				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
 				model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
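
Besides the Moderation→Moderations rename and the new instruction check for /v1/edits, the notable change in this file is that completionRatio now defaults to 1.0 instead of the GPT-3.5 surcharge. A self-contained sketch of the prefix logic this hunk introduces; how the ratio then feeds into the quota arithmetic happens further down relayTextHelper and is not part of this diff:

package main

import (
	"fmt"
	"strings"
)

// completionRatioFor mirrors the checks added above: everything defaults to 1.0,
// gpt-3.5* keeps the old 1.333333 surcharge, and gpt-4* is charged at 2x.
func completionRatioFor(model string) float64 {
	ratio := 1.0
	if strings.HasPrefix(model, "gpt-3.5") {
		ratio = 1.333333
	}
	if strings.HasPrefix(model, "gpt-4") {
		ratio = 2
	}
	return ratio
}

func main() {
	for _, m := range []string{"text-davinci-edit-001", "gpt-3.5-turbo", "gpt-4"} {
		fmt.Printf("%-22s -> %.6f\n", m, completionRatioFor(m))
	}
}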

controller/relay.go (+6 -2)

@@ -19,8 +19,9 @@ const (
 	RelayModeChatCompletions
 	RelayModeCompletions
 	RelayModeEmbeddings
-	RelayModeModeration
+	RelayModeModerations
 	RelayModeImagesGenerations
+	RelayModeEdits
 )
 
 // https://platform.openai.com/docs/api-reference/chat
@@ -35,6 +36,7 @@ type GeneralOpenAIRequest struct {
 	TopP        float64   `json:"top_p"`
 	N           int       `json:"n"`
 	Input       any       `json:"input"`
+	Instruction string    `json:"instruction"`
 }
 
 type ChatRequest struct {
@@ -99,9 +101,11 @@ func Relay(c *gin.Context) {
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
 		relayMode = RelayModeEmbeddings
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
-		relayMode = RelayModeModeration
+		relayMode = RelayModeModerations
 	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
 		relayMode = RelayModeImagesGenerations
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
+		relayMode = RelayModeEdits
 	}
 	var err *OpenAIErrorWithStatusCode
 	switch relayMode {
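
The rename and the new mode live in an iota block whose opening members sit above this hunk. A hedged reconstruction of the constants after this commit (the leading RelayModeUnknown member is an assumption, since it is not visible in the diff); because RelayModeEdits is appended at the end, every pre-existing constant keeps its numeric value:

package controller

// Assumed shape of the relay-mode const block after this commit.
const (
	RelayModeUnknown = iota // assumed zero member, not shown in the hunk
	RelayModeChatCompletions
	RelayModeCompletions
	RelayModeEmbeddings
	RelayModeModerations // renamed from RelayModeModeration
	RelayModeImagesGenerations
	RelayModeEdits // new: chosen when the request path starts with /v1/edits
)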

router/relay-router.go (+1 -1)

@@ -19,7 +19,7 @@ func SetRelayRouter(router *gin.Engine) {
 	{
 		relayV1Router.POST("/completions", controller.Relay)
 		relayV1Router.POST("/chat/completions", controller.Relay)
-		relayV1Router.POST("/edits", controller.RelayNotImplemented)
+		relayV1Router.POST("/edits", controller.Relay)
 		relayV1Router.POST("/images/generations", controller.RelayNotImplemented)
 		relayV1Router.POST("/images/generations", controller.RelayNotImplemented)
 		relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
 		relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
 		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)