Просмотр исходного кода

Merge pull request #1814 from ShibaInu64/fix/volcengine-image-model

fix: VolcEngine渠道-图片生成 API-渠道测试报错
Calcium-Ion 3 месяцев назад
Родитель
Сommit
45a6a779e5

+ 39 - 0
controller/channel-test.go

@@ -90,6 +90,11 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 		requestPath = "/v1/embeddings" // 修改请求路径
 	}
 
+	// VolcEngine 图像生成模型
+	if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
+		requestPath = "/v1/images/generations"
+	}
+
 	c.Request = &http.Request{
 		Method: "POST",
 		URL:    &url.URL{Path: requestPath}, // 使用动态路径
@@ -109,6 +114,21 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 		}
 	}
 
+	// 重新检查模型类型并更新请求路径
+	if strings.Contains(strings.ToLower(testModel), "embedding") ||
+		strings.HasPrefix(testModel, "m3e") ||
+		strings.Contains(testModel, "bge-") ||
+		strings.Contains(testModel, "embed") ||
+		channel.Type == constant.ChannelTypeMokaAI {
+		requestPath = "/v1/embeddings"
+		c.Request.URL.Path = requestPath
+	}
+
+	if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
+		requestPath = "/v1/images/generations"
+		c.Request.URL.Path = requestPath
+	}
+
 	cache, err := model.GetUserCache(1)
 	if err != nil {
 		return testResult{
@@ -140,6 +160,9 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 	if c.Request.URL.Path == "/v1/embeddings" {
 		relayFormat = types.RelayFormatEmbedding
 	}
+	if c.Request.URL.Path == "/v1/images/generations" {
+		relayFormat = types.RelayFormatOpenAIImage
+	}
 
 	info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
 
@@ -201,6 +224,22 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 		}
 		// 调用专门用于 Embedding 的转换函数
 		convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, embeddingRequest)
+	} else if info.RelayMode == relayconstant.RelayModeImagesGenerations {
+		// 创建一个 ImageRequest
+		prompt := "cat"
+		if request.Prompt != nil {
+			if promptStr, ok := request.Prompt.(string); ok && promptStr != "" {
+				prompt = promptStr
+			}
+		}
+		imageRequest := dto.ImageRequest{
+			Prompt: prompt,
+			Model:  request.Model,
+			N:      uint(request.N),
+			Size:   request.Size,
+		}
+		// 调用专门用于图像生成的转换函数
+		convertedRequest, err = adaptor.ConvertImageRequest(c, info, imageRequest)
 	} else {
 		// 对其他所有请求类型(如 Chat),保持原有逻辑
 		convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, request)

+ 2 - 0
relay/channel/volcengine/adaptor.go

@@ -41,6 +41,8 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf
 
 func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
 	switch info.RelayMode {
+	case constant.RelayModeImagesGenerations:
+		return request, nil
 	case constant.RelayModeImagesEdits:
 
 		var requestBody bytes.Buffer

+ 1 - 0
relay/channel/volcengine/constants.go

@@ -8,6 +8,7 @@ var ModelList = []string{
 	"Doubao-lite-32k",
 	"Doubao-lite-4k",
 	"Doubao-embedding",
+	"doubao-seedream-4-0-250828",
 }
 
 var ChannelName = "volcengine"