Forráskód Böngészése

feat: add sora video submit task

feitianbubu 2 hónapja
szülő
commit
9f4a2d64a3

+ 24 - 0
common/gin.go

@@ -3,6 +3,7 @@ package common
 import (
 	"bytes"
 	"io"
+	"mime/multipart"
 	"net/http"
 	"one-api/constant"
 	"strings"
@@ -113,3 +114,26 @@ func ApiSuccess(c *gin.Context, data any) {
 		"data":    data,
 	})
 }
+
+func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) {
+	requestBody, err := GetRequestBody(c)
+	if err != nil {
+		return nil, err
+	}
+
+	contentType := c.Request.Header.Get("Content-Type")
+	boundary := ""
+	if idx := strings.Index(contentType, "boundary="); idx != -1 {
+		boundary = contentType[idx+9:]
+	}
+
+	reader := multipart.NewReader(bytes.NewReader(requestBody), boundary)
+	form, err := reader.ReadForm(32 << 20) // 32 MB max memory
+	if err != nil {
+		return nil, err
+	}
+
+	// Reset request body
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+	return form, nil
+}

+ 3 - 0
constant/channel.go

@@ -52,6 +52,7 @@ const (
 	ChannelTypeVidu           = 52
 	ChannelTypeSubmodel       = 53
 	ChannelTypeDoubaoVideo    = 54
+	ChannelTypeSora           = 55
 	ChannelTypeDummy          // this one is only for count, do not add any channel after this
 
 )
@@ -112,6 +113,7 @@ var ChannelBaseURLs = []string{
 	"https://api.vidu.cn",                       //52
 	"https://llm.submodel.ai",                   //53
 	"https://ark.cn-beijing.volces.com",         //54
+	"https://api.openai.com",                    //55
 }
 
 var ChannelTypeNames = map[int]string{
@@ -166,6 +168,7 @@ var ChannelTypeNames = map[int]string{
 	ChannelTypeVidu:           "Vidu",
 	ChannelTypeSubmodel:       "Submodel",
 	ChannelTypeDoubaoVideo:    "DoubaoVideo",
+	ChannelTypeSora:           "Sora",
 }
 
 func GetChannelTypeName(channelType int) string {

+ 10 - 0
dto/openai_response.go

@@ -233,6 +233,16 @@ type Usage struct {
 	Cost any `json:"cost,omitempty"`
 }
 
+type OpenAIVideoResponse struct {
+	Id        string `json:"id" example:"file-abc123"`
+	Object    string `json:"object" example:"file"`
+	Bytes     int64  `json:"bytes" example:"120000"`
+	CreatedAt int64  `json:"created_at" example:"1677610602"`
+	ExpiresAt int64  `json:"expires_at" example:"1677614202"`
+	Filename  string `json:"filename" example:"mydata.jsonl"`
+	Purpose   string `json:"purpose" example:"fine-tune"`
+}
+
 type InputTokenDetails struct {
 	CachedTokens         int `json:"cached_tokens"`
 	CachedCreationTokens int `json:"-"`

+ 12 - 0
middleware/distributor.go

@@ -165,6 +165,18 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 		c.Set("platform", string(constant.TaskPlatformSuno))
 		c.Set("relay_mode", relayMode)
+	} else if strings.Contains(c.Request.URL.Path, "/v1/videos") {
+		//curl https://api.openai.com/v1/videos \
+		//  -H "Authorization: Bearer $OPENAI_API_KEY" \
+		//  -F "model=sora-2" \
+		//  -F "prompt=A calico cat playing a piano on stage"
+		//	-F input_reference="@image.jpg"
+		relayMode := relayconstant.RelayModeUnknown
+		if c.Request.Method == http.MethodPost {
+			relayMode = relayconstant.RelayModeVideoSubmit
+			modelRequest.Model = c.PostForm("model")
+		}
+		c.Set("relay_mode", relayMode)
 	} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
 		relayMode := relayconstant.RelayModeUnknown
 		if c.Request.Method == http.MethodPost {

+ 192 - 0
relay/channel/task/sora/adaptor.go

@@ -0,0 +1,192 @@
+package sora
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	"one-api/model"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+
+	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
+)
+
+// ============================
+// Request / Response structures
+// ============================
+
+type ContentItem struct {
+	Type     string    `json:"type"`                // "text" or "image_url"
+	Text     string    `json:"text,omitempty"`      // for text type
+	ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type
+}
+
+type ImageURL struct {
+	URL string `json:"url"`
+}
+
+type responsePayload struct {
+	ID string `json:"id"` // task_id
+}
+
+type responseTask struct {
+	ID      string `json:"id"`
+	Model   string `json:"model"`
+	Status  string `json:"status"`
+	Content struct {
+		VideoURL string `json:"video_url"`
+	} `json:"content"`
+	Seed        int    `json:"seed"`
+	Resolution  string `json:"resolution"`
+	Duration    int    `json:"duration"`
+	AspectRatio string `json:"aspect_ratio"`
+	Usage       struct {
+		PromptTokens     int `json:"prompt_tokens"`
+		CompletionTokens int `json:"completion_tokens"`
+		TotalTokens      int `json:"total_tokens"`
+	} `json:"usage"`
+	CreatedAt int64 `json:"created_at"`
+	UpdatedAt int64 `json:"updated_at"`
+}
+
+// ============================
+// Adaptor implementation
+// ============================
+
+type TaskAdaptor struct {
+	ChannelType int
+	apiKey      string
+	baseURL     string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
+	a.ChannelType = info.ChannelType
+	a.baseURL = info.ChannelBaseUrl
+	a.apiKey = info.ApiKey
+}
+
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
+	return relaycommon.ValidateMultipartDirect(c, info)
+}
+
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
+	return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
+}
+
+// BuildRequestHeader sets required headers.
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
+	req.Header.Set("Authorization", "Bearer "+a.apiKey)
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	return nil
+}
+
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
+	cachedBody, err := common.GetRequestBody(c)
+	if err != nil {
+		return nil, errors.Wrap(err, "get_request_body_failed")
+	}
+	return bytes.NewReader(cachedBody), nil
+}
+
+// DoRequest delegates to common helper.
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+// DoResponse handles upstream response, returns taskID etc.
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+	_ = resp.Body.Close()
+
+	// Parse Sora response
+	var dResp responsePayload
+	if err := json.Unmarshal(responseBody, &dResp); err != nil {
+		taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+
+	if dResp.ID == "" {
+		taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{"task_id": dResp.ID})
+	return dResp.ID, responseBody, nil
+}
+
+// FetchTask fetch task status
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+	taskID, ok := body["task_id"].(string)
+	if !ok {
+		return nil, fmt.Errorf("invalid task_id")
+	}
+
+	uri := fmt.Sprintf("%s/v1/videos/generations/%s", baseUrl, taskID)
+
+	req, err := http.NewRequest(http.MethodGet, uri, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+key)
+
+	return service.GetHttpClient().Do(req)
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+	return ChannelName
+}
+
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
+	resTask := responseTask{}
+	if err := json.Unmarshal(respBody, &resTask); err != nil {
+		return nil, errors.Wrap(err, "unmarshal task result failed")
+	}
+
+	taskResult := relaycommon.TaskInfo{
+		Code: 0,
+	}
+
+	// Map Sora status to internal status
+	switch resTask.Status {
+	case "pending", "queued":
+		taskResult.Status = model.TaskStatusQueued
+		taskResult.Progress = "10%"
+	case "processing", "running":
+		taskResult.Status = model.TaskStatusInProgress
+		taskResult.Progress = "50%"
+	case "succeeded", "completed":
+		taskResult.Status = model.TaskStatusSuccess
+		taskResult.Progress = "100%"
+		taskResult.Url = resTask.Content.VideoURL
+		// Parse usage information for billing
+		taskResult.CompletionTokens = resTask.Usage.CompletionTokens
+		taskResult.TotalTokens = resTask.Usage.TotalTokens
+	case "failed", "cancelled":
+		taskResult.Status = model.TaskStatusFailure
+		taskResult.Progress = "100%"
+		taskResult.Reason = "task failed"
+	default:
+		// Unknown status, treat as processing
+		taskResult.Status = model.TaskStatusInProgress
+		taskResult.Progress = "30%"
+	}
+
+	return &taskResult, nil
+}

+ 8 - 0
relay/channel/task/sora/constants.go

@@ -0,0 +1,8 @@
+package sora
+
+var ModelList = []string{
+	"sora-2",
+	"sora-2-pro",
+}
+
+var ChannelName = "sora"

+ 91 - 2
relay/common/relay_utils.go

@@ -6,6 +6,7 @@ import (
 	"one-api/common"
 	"one-api/constant"
 	"one-api/dto"
+	"strconv"
 	"strings"
 
 	"github.com/gin-gonic/gin"
@@ -52,7 +53,7 @@ func createTaskError(err error, code string, statusCode int, localError bool) *d
 	}
 }
 
-func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj interface{}) {
+func storeTaskRequest(c *gin.Context, info *RelayInfo, action string, requestObj TaskSubmitReq) {
 	info.Action = action
 	c.Set("task_request", requestObj)
 }
@@ -64,9 +65,97 @@ func validatePrompt(prompt string) *dto.TaskError {
 	return nil
 }
 
+func validateMultipartTaskRequest(c *gin.Context, info *RelayInfo, action string) (TaskSubmitReq, error) {
+	var req TaskSubmitReq
+	if _, err := c.MultipartForm(); err != nil {
+		return req, err
+	}
+
+	formData := c.Request.PostForm
+	req = TaskSubmitReq{
+		Prompt:   formData.Get("prompt"),
+		Model:    formData.Get("model"),
+		Mode:     formData.Get("mode"),
+		Image:    formData.Get("image"),
+		Size:     formData.Get("size"),
+		Metadata: make(map[string]interface{}),
+	}
+
+	if durationStr := formData.Get("seconds"); durationStr != "" {
+		if duration, err := strconv.Atoi(durationStr); err == nil {
+			req.Duration = duration
+		}
+	}
+
+	if images := formData["images"]; len(images) > 0 {
+		req.Images = images
+	}
+
+	for key, values := range formData {
+		if len(values) > 0 && !isKnownTaskField(key) {
+			if intVal, err := strconv.Atoi(values[0]); err == nil {
+				req.Metadata[key] = intVal
+			} else if floatVal, err := strconv.ParseFloat(values[0], 64); err == nil {
+				req.Metadata[key] = floatVal
+			} else {
+				req.Metadata[key] = values[0]
+			}
+		}
+	}
+	return req, nil
+}
+
+func ValidateMultipartDirect(c *gin.Context, info *RelayInfo) *dto.TaskError {
+	form, err := common.ParseMultipartFormReusable(c)
+	if err != nil {
+		return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
+	}
+	defer form.RemoveAll()
+
+	prompts, ok := form.Value["prompt"]
+	if !ok || len(prompts) == 0 {
+		return createTaskError(fmt.Errorf("prompt field is required"), "missing_prompt", http.StatusBadRequest, true)
+	}
+	if taskErr := validatePrompt(prompts[0]); taskErr != nil {
+		return taskErr
+	}
+
+	if _, ok := form.Value["model"]; !ok {
+		return createTaskError(fmt.Errorf("model field is required"), "missing_model", http.StatusBadRequest, true)
+	}
+	action := constant.TaskActionTextGenerate
+	if _, ok := form.File["input_reference"]; ok {
+		action = constant.TaskActionGenerate
+	}
+	info.Action = action
+
+	return nil
+}
+
+func isKnownTaskField(field string) bool {
+	knownFields := map[string]bool{
+		"prompt":          true,
+		"model":           true,
+		"mode":            true,
+		"image":           true,
+		"images":          true,
+		"size":            true,
+		"duration":        true,
+		"input_reference": true, // Sora 特有字段
+	}
+	return knownFields[field]
+}
+
 func ValidateBasicTaskRequest(c *gin.Context, info *RelayInfo, action string) *dto.TaskError {
+	var err error
+	contentType := c.GetHeader("Content-Type")
 	var req TaskSubmitReq
-	if err := common.UnmarshalBodyReusable(c, &req); err != nil {
+	if strings.HasPrefix(contentType, "multipart/form-data") {
+		req, err = validateMultipartTaskRequest(c, info, action)
+		if err != nil {
+			return createTaskError(err, "invalid_multipart_form", http.StatusBadRequest, true)
+		}
+	} else if err := common.UnmarshalBodyReusable(c, &req); err != nil {
 		return createTaskError(err, "invalid_request", http.StatusBadRequest, true)
 	}
 

+ 3 - 0
relay/relay_adaptor.go

@@ -29,6 +29,7 @@ import (
 	taskdoubao "one-api/relay/channel/task/doubao"
 	taskjimeng "one-api/relay/channel/task/jimeng"
 	"one-api/relay/channel/task/kling"
+	tasksora "one-api/relay/channel/task/sora"
 	"one-api/relay/channel/task/suno"
 	taskvertex "one-api/relay/channel/task/vertex"
 	taskVidu "one-api/relay/channel/task/vidu"
@@ -137,6 +138,8 @@ func GetTaskAdaptor(platform constant.TaskPlatform) channel.TaskAdaptor {
 			return &taskVidu.TaskAdaptor{}
 		case constant.ChannelTypeDoubaoVideo:
 			return &taskdoubao.TaskAdaptor{}
+		case constant.ChannelTypeSora:
+			return &tasksora.TaskAdaptor{}
 		}
 	}
 	return nil

+ 5 - 0
router/video-router.go

@@ -14,6 +14,11 @@ func SetVideoRouter(router *gin.Engine) {
 		videoV1Router.POST("/video/generations", controller.RelayTask)
 		videoV1Router.GET("/video/generations/:task_id", controller.RelayTask)
 	}
+	// openai compatible API video routes
+	// docs: https://platform.openai.com/docs/api-reference/videos/create
+	{
+		videoV1Router.POST("/videos", controller.RelayTask)
+	}
 
 	klingV1Router := router.Group("/kling/v1")
 	klingV1Router.Use(middleware.KlingRequestConvert(), middleware.TokenAuth(), middleware.Distribute())

+ 5 - 0
web/src/constants/channel.constants.js

@@ -169,6 +169,11 @@ export const CHANNEL_OPTIONS = [
     color: 'blue',
     label: '豆包视频',
   },
+  {
+    value: 55,
+    color: 'green',
+    label: 'Sora',
+  },
 ];
 
 export const MODEL_TABLE_PAGE_SIZE = 10;