Browse Source

feat(relay): 添加视频模型映射功能支持

creamlike1024 2 months ago
parent
commit
7fc25a57cf
3 changed files with 101 additions and 0 deletions
  1. 90 0
      relay/channel/task/sora/adaptor.go
  2. 5 0
      relay/common/relay_utils.go
  3. 6 0
      relay/relay_task.go

+ 90 - 0
relay/channel/task/sora/adaptor.go

@@ -2,9 +2,12 @@ package sora
 
 import (
 	"bytes"
+	"encoding/json"
 	"fmt"
 	"io"
+	"mime/multipart"
 	"net/http"
+	"strings"
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/dto"
@@ -87,9 +90,96 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
 	if err != nil {
 		return nil, errors.Wrap(err, "get_request_body_failed")
 	}
+
+	// 检查是否需要模型重定向
+	if !info.IsModelMapped {
+		// 如果不需要重定向,直接返回原始请求体
+		return bytes.NewReader(cachedBody), nil
+	}
+
+	contentType := c.Request.Header.Get("Content-Type")
+
+	// 处理multipart/form-data请求
+	if strings.Contains(contentType, "multipart/form-data") {
+		return buildRequestBodyWithMappedModel(cachedBody, contentType, info.UpstreamModelName)
+	}
+	// 处理JSON请求
+	if strings.Contains(contentType, "application/json") {
+		var jsonData map[string]interface{}
+		if err := json.Unmarshal(cachedBody, &jsonData); err != nil {
+			return nil, errors.Wrap(err, "unmarshal_json_failed")
+		}
+
+		// 替换model字段为映射后的模型名
+		jsonData["model"] = info.UpstreamModelName
+
+		// 重新编码为JSON
+		newBody, err := json.Marshal(jsonData)
+		if err != nil {
+			return nil, errors.Wrap(err, "marshal_json_failed")
+		}
+
+		return bytes.NewReader(newBody), nil
+	}
+
 	return bytes.NewReader(cachedBody), nil
 }
 
+func buildRequestBodyWithMappedModel(originalBody []byte, contentType, redirectedModel string) (io.Reader, error) {
+	newBuffer := &bytes.Buffer{}
+	writer := multipart.NewWriter(newBuffer)
+
+	r := multipart.NewReader(bytes.NewReader(originalBody), strings.TrimPrefix(contentType, "multipart/form-data; boundary="))
+
+	for {
+		part, err := r.NextPart()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return nil, errors.Wrap(err, "read_multipart_part_failed")
+		}
+
+		fieldName := part.FormName()
+
+		if fieldName == "model" {
+			// 修改 model 字段为映射后的模型名
+			if err := writer.WriteField("model", redirectedModel); err != nil {
+				return nil, errors.Wrap(err, "write_model_field_failed")
+			}
+		} else {
+			// 对于其他字段,保留原始内容
+			if part.FileName() != "" {
+				newPart, err := writer.CreateFormFile(fieldName, part.FileName())
+				if err != nil {
+					return nil, errors.Wrap(err, "create_form_file_failed")
+				}
+				if _, err := io.Copy(newPart, part); err != nil {
+					return nil, errors.Wrap(err, "copy_file_content_failed")
+				}
+			} else {
+				content, err := io.ReadAll(part)
+				if err != nil {
+					return nil, errors.Wrap(err, "read_field_content_failed")
+				}
+				if err := writer.WriteField(fieldName, string(content)); err != nil {
+					return nil, errors.Wrap(err, "write_field_failed")
+				}
+			}
+		}
+
+		if err := part.Close(); err != nil {
+			return nil, errors.Wrap(err, "close_part_failed")
+		}
+	}
+
+	if err := writer.Close(); err != nil {
+		return nil, errors.Wrap(err, "close_multipart_writer_failed")
+	}
+
+	return newBuffer, nil
+}
+
 // DoRequest delegates to common helper.
 func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
 	return channel.DoTaskApiRequest(a, c, info, requestBody)

+ 5 - 0
relay/common/relay_utils.go

@@ -252,6 +252,11 @@ func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *d
 		}
 	}
 
+	// 模型映射
+	if info.IsModelMapped {
+		req.Model = info.UpstreamModelName
+	}
+
 	storeTaskRequest(c, info, action, req)
 	return nil
 }

+ 6 - 0
relay/relay_task.go

@@ -17,6 +17,7 @@ import (
 	"github.com/QuantumNous/new-api/relay/channel"
 	relaycommon "github.com/QuantumNous/new-api/relay/common"
 	relayconstant "github.com/QuantumNous/new-api/relay/constant"
+	"github.com/QuantumNous/new-api/relay/helper"
 	"github.com/QuantumNous/new-api/service"
 	"github.com/QuantumNous/new-api/setting/ratio_setting"
 
@@ -38,6 +39,11 @@ func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.
 	}
 
 	info.InitChannelMeta(c)
+
+	// 模型映射
+	if err := helper.ModelMappedHelper(c, info, nil); err != nil {
+		return service.TaskErrorWrapper(err, "model_mapped_failed", http.StatusBadRequest)
+	}
 	adaptor := GetTaskAdaptor(platform)
 	if adaptor == nil {
 		return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)