|
|
@@ -2,9 +2,12 @@ package sora
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
+ "encoding/json"
|
|
|
"fmt"
|
|
|
"io"
|
|
|
+ "mime/multipart"
|
|
|
"net/http"
|
|
|
+ "strings"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
|
"github.com/QuantumNous/new-api/dto"
|
|
|
@@ -87,9 +90,96 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
|
|
if err != nil {
|
|
|
return nil, errors.Wrap(err, "get_request_body_failed")
|
|
|
}
|
|
|
+
|
|
|
+ // 检查是否需要模型重定向
|
|
|
+ if !info.IsModelMapped {
|
|
|
+ // 如果不需要重定向,直接返回原始请求体
|
|
|
+ return bytes.NewReader(cachedBody), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ contentType := c.Request.Header.Get("Content-Type")
|
|
|
+
|
|
|
+ // 处理multipart/form-data请求
|
|
|
+ if strings.Contains(contentType, "multipart/form-data") {
|
|
|
+ return buildRequestBodyWithMappedModel(cachedBody, contentType, info.UpstreamModelName)
|
|
|
+ }
|
|
|
+ // 处理JSON请求
|
|
|
+ if strings.Contains(contentType, "application/json") {
|
|
|
+ var jsonData map[string]interface{}
|
|
|
+ if err := json.Unmarshal(cachedBody, &jsonData); err != nil {
|
|
|
+ return nil, errors.Wrap(err, "unmarshal_json_failed")
|
|
|
+ }
|
|
|
+
|
|
|
+ // 替换model字段为映射后的模型名
|
|
|
+ jsonData["model"] = info.UpstreamModelName
|
|
|
+
|
|
|
+ // 重新编码为JSON
|
|
|
+ newBody, err := json.Marshal(jsonData)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Wrap(err, "marshal_json_failed")
|
|
|
+ }
|
|
|
+
|
|
|
+ return bytes.NewReader(newBody), nil
|
|
|
+ }
|
|
|
+
|
|
|
return bytes.NewReader(cachedBody), nil
|
|
|
}
|
|
|
|
|
|
+func buildRequestBodyWithMappedModel(originalBody []byte, contentType, redirectedModel string) (io.Reader, error) {
|
|
|
+ newBuffer := &bytes.Buffer{}
|
|
|
+ writer := multipart.NewWriter(newBuffer)
|
|
|
+
|
|
|
+ r := multipart.NewReader(bytes.NewReader(originalBody), strings.TrimPrefix(contentType, "multipart/form-data; boundary="))
|
|
|
+
|
|
|
+ for {
|
|
|
+ part, err := r.NextPart()
|
|
|
+ if err == io.EOF {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Wrap(err, "read_multipart_part_failed")
|
|
|
+ }
|
|
|
+
|
|
|
+ fieldName := part.FormName()
|
|
|
+
|
|
|
+ if fieldName == "model" {
|
|
|
+ // 修改 model 字段为映射后的模型名
|
|
|
+ if err := writer.WriteField("model", redirectedModel); err != nil {
|
|
|
+ return nil, errors.Wrap(err, "write_model_field_failed")
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // 对于其他字段,保留原始内容
|
|
|
+ if part.FileName() != "" {
|
|
|
+ newPart, err := writer.CreateFormFile(fieldName, part.FileName())
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Wrap(err, "create_form_file_failed")
|
|
|
+ }
|
|
|
+ if _, err := io.Copy(newPart, part); err != nil {
|
|
|
+ return nil, errors.Wrap(err, "copy_file_content_failed")
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ content, err := io.ReadAll(part)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Wrap(err, "read_field_content_failed")
|
|
|
+ }
|
|
|
+ if err := writer.WriteField(fieldName, string(content)); err != nil {
|
|
|
+ return nil, errors.Wrap(err, "write_field_failed")
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := part.Close(); err != nil {
|
|
|
+ return nil, errors.Wrap(err, "close_part_failed")
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := writer.Close(); err != nil {
|
|
|
+ return nil, errors.Wrap(err, "close_multipart_writer_failed")
|
|
|
+ }
|
|
|
+
|
|
|
+ return newBuffer, nil
|
|
|
+}
|
|
|
+
|
|
|
// DoRequest delegates to common helper.
|
|
|
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
|
|
return channel.DoTaskApiRequest(a, c, info, requestBody)
|