|
|
@@ -0,0 +1,192 @@
|
|
|
+package sora
|
|
|
+
|
|
|
+import (
|
|
|
+ "bytes"
|
|
|
+ "encoding/json"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "net/http"
|
|
|
+ "one-api/common"
|
|
|
+ "one-api/dto"
|
|
|
+ "one-api/model"
|
|
|
+ "one-api/relay/channel"
|
|
|
+ relaycommon "one-api/relay/common"
|
|
|
+ "one-api/service"
|
|
|
+
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
+ "github.com/pkg/errors"
|
|
|
+)
|
|
|
+
|
|
|
+// ============================
|
|
|
+// Request / Response structures
|
|
|
+// ============================
|
|
|
+
|
|
|
+type ContentItem struct {
|
|
|
+ Type string `json:"type"` // "text" or "image_url"
|
|
|
+ Text string `json:"text,omitempty"` // for text type
|
|
|
+ ImageURL *ImageURL `json:"image_url,omitempty"` // for image_url type
|
|
|
+}
|
|
|
+
|
|
|
+type ImageURL struct {
|
|
|
+ URL string `json:"url"`
|
|
|
+}
|
|
|
+
|
|
|
+type responsePayload struct {
|
|
|
+ ID string `json:"id"` // task_id
|
|
|
+}
|
|
|
+
|
|
|
+type responseTask struct {
|
|
|
+ ID string `json:"id"`
|
|
|
+ Model string `json:"model"`
|
|
|
+ Status string `json:"status"`
|
|
|
+ Content struct {
|
|
|
+ VideoURL string `json:"video_url"`
|
|
|
+ } `json:"content"`
|
|
|
+ Seed int `json:"seed"`
|
|
|
+ Resolution string `json:"resolution"`
|
|
|
+ Duration int `json:"duration"`
|
|
|
+ AspectRatio string `json:"aspect_ratio"`
|
|
|
+ Usage struct {
|
|
|
+ PromptTokens int `json:"prompt_tokens"`
|
|
|
+ CompletionTokens int `json:"completion_tokens"`
|
|
|
+ TotalTokens int `json:"total_tokens"`
|
|
|
+ } `json:"usage"`
|
|
|
+ CreatedAt int64 `json:"created_at"`
|
|
|
+ UpdatedAt int64 `json:"updated_at"`
|
|
|
+}
|
|
|
+
|
|
|
+// ============================
|
|
|
+// Adaptor implementation
|
|
|
+// ============================
|
|
|
+
|
|
|
+type TaskAdaptor struct {
|
|
|
+ ChannelType int
|
|
|
+ apiKey string
|
|
|
+ baseURL string
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
|
|
+ a.ChannelType = info.ChannelType
|
|
|
+ a.baseURL = info.ChannelBaseUrl
|
|
|
+ a.apiKey = info.ApiKey
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
|
|
+ return relaycommon.ValidateMultipartDirect(c, info)
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
|
+ return fmt.Sprintf("%s/v1/videos", a.baseURL), nil
|
|
|
+}
|
|
|
+
|
|
|
+// BuildRequestHeader sets required headers.
|
|
|
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
|
|
+ req.Header.Set("Authorization", "Bearer "+a.apiKey)
|
|
|
+ req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
|
|
+ cachedBody, err := common.GetRequestBody(c)
|
|
|
+ if err != nil {
|
|
|
+ return nil, errors.Wrap(err, "get_request_body_failed")
|
|
|
+ }
|
|
|
+ return bytes.NewReader(cachedBody), nil
|
|
|
+}
|
|
|
+
|
|
|
+// DoRequest delegates to common helper.
|
|
|
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
|
|
+ return channel.DoTaskApiRequest(a, c, info, requestBody)
|
|
|
+}
|
|
|
+
|
|
|
+// DoResponse handles upstream response, returns taskID etc.
|
|
|
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
|
|
+ responseBody, err := io.ReadAll(resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ _ = resp.Body.Close()
|
|
|
+
|
|
|
+ // Parse Sora response
|
|
|
+ var dResp responsePayload
|
|
|
+ if err := json.Unmarshal(responseBody, &dResp); err != nil {
|
|
|
+ taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if dResp.ID == "" {
|
|
|
+ taskErr = service.TaskErrorWrapper(fmt.Errorf("task_id is empty"), "invalid_response", http.StatusInternalServerError)
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ c.JSON(http.StatusOK, gin.H{"task_id": dResp.ID})
|
|
|
+ return dResp.ID, responseBody, nil
|
|
|
+}
|
|
|
+
|
|
|
+// FetchTask fetch task status
|
|
|
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
|
|
+ taskID, ok := body["task_id"].(string)
|
|
|
+ if !ok {
|
|
|
+ return nil, fmt.Errorf("invalid task_id")
|
|
|
+ }
|
|
|
+
|
|
|
+ uri := fmt.Sprintf("%s/v1/videos/generations/%s", baseUrl, taskID)
|
|
|
+
|
|
|
+ req, err := http.NewRequest(http.MethodGet, uri, nil)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ req.Header.Set("Accept", "application/json")
|
|
|
+ req.Header.Set("Content-Type", "application/json")
|
|
|
+ req.Header.Set("Authorization", "Bearer "+key)
|
|
|
+
|
|
|
+ return service.GetHttpClient().Do(req)
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) GetModelList() []string {
|
|
|
+ return ModelList
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) GetChannelName() string {
|
|
|
+ return ChannelName
|
|
|
+}
|
|
|
+
|
|
|
+func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
|
|
+ resTask := responseTask{}
|
|
|
+ if err := json.Unmarshal(respBody, &resTask); err != nil {
|
|
|
+ return nil, errors.Wrap(err, "unmarshal task result failed")
|
|
|
+ }
|
|
|
+
|
|
|
+ taskResult := relaycommon.TaskInfo{
|
|
|
+ Code: 0,
|
|
|
+ }
|
|
|
+
|
|
|
+ // Map Sora status to internal status
|
|
|
+ switch resTask.Status {
|
|
|
+ case "pending", "queued":
|
|
|
+ taskResult.Status = model.TaskStatusQueued
|
|
|
+ taskResult.Progress = "10%"
|
|
|
+ case "processing", "running":
|
|
|
+ taskResult.Status = model.TaskStatusInProgress
|
|
|
+ taskResult.Progress = "50%"
|
|
|
+ case "succeeded", "completed":
|
|
|
+ taskResult.Status = model.TaskStatusSuccess
|
|
|
+ taskResult.Progress = "100%"
|
|
|
+ taskResult.Url = resTask.Content.VideoURL
|
|
|
+ // Parse usage information for billing
|
|
|
+ taskResult.CompletionTokens = resTask.Usage.CompletionTokens
|
|
|
+ taskResult.TotalTokens = resTask.Usage.TotalTokens
|
|
|
+ case "failed", "cancelled":
|
|
|
+ taskResult.Status = model.TaskStatusFailure
|
|
|
+ taskResult.Progress = "100%"
|
|
|
+ taskResult.Reason = "task failed"
|
|
|
+ default:
|
|
|
+ // Unknown status, treat as processing
|
|
|
+ taskResult.Status = model.TaskStatusInProgress
|
|
|
+ taskResult.Progress = "30%"
|
|
|
+ }
|
|
|
+
|
|
|
+ return &taskResult, nil
|
|
|
+}
|