Browse Source

feat: suno api 支持

feat: 调试 suno

feat: 补充suno 文档
Xiangyuan Liu 1 year ago
parent
commit
c993ab2746

+ 9 - 0
README.md

@@ -47,6 +47,11 @@
     2. 对[@Botfather](https://t.me/botfather)输入指令/setdomain
     3. 选择你的bot,然后输入http(s)://你的网站地址/login
     4. Telegram Bot 名称是bot username 去掉@后的字符串
+13. 添加 [Suno API](https://github.com/Suno-API/Suno-API)接口的支持,[对接文档](Suno.md),支持的接口如下:
+    + [x] /suno/submit/music
+    + [x] /suno/submit/lyrics
+    + [x] /suno/fetch
+    + [x] /suno/fetch/:id
 
 ## 模型支持
 此版本额外支持以下模型:
@@ -57,6 +62,7 @@
 5. [Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口,[对接文档](Midjourney.md)
 6. [零一万物](https://platform.lingyiwanwu.com/)
 7. 自定义渠道,支持填入完整调用地址
+8. [Suno API](https://github.com/Suno-API/Suno-API) 接口,[对接文档](Suno.md)
 
 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。
 
@@ -105,6 +111,9 @@ docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:1234
 ## Midjourney接口设置文档
 [对接文档](Midjourney.md)
 
+## Suno接口设置文档
+[对接文档](Suno.md)
+
 ## 交流群
 <img src="https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266" width="300">
 

+ 37 - 0
Suno.md

@@ -0,0 +1,37 @@
+# Suno API文档
+
+**简介**:Suno API文档
+
+## 模型列表
+
+### Suno API支持
+
+- suno_music (自定义模式、灵感模式、续写)
+- suno_lyrics (生成歌词)
+
+
+## 模型价格设置(在设置-运营设置-模型固定价格设置中设置)
+```json
+{
+  "suno_music": 0.3,
+  "suno_lyrics": 0.01
+}
+```
+
+## 渠道设置
+
+### 对接 Suno API
+
+1.
+部署 Suno API,并配置好suno账号等(强烈建议设置密钥),[项目地址](https://github.com/Suno-API/Suno-API)
+
+2. 在渠道管理中添加渠道,渠道类型选择**Suno API**
+   ,模型请参考上方模型列表
+3. **代理**填写 Suno API 部署的地址,例如:http://localhost:8080
+4. 密钥填写 Suno API 的密钥,如果没有设置密钥,可以随便填
+
+### 对接上游new api
+
+1. 在渠道管理中添加渠道,渠道类型选择**Suno API**,或任意类型,只需模型包含上方模型列表的模型
+2. **代理**填写上游new api的地址,例如:http://localhost:3000
+3. 密钥填写上游new api的密钥

+ 4 - 0
common/constants.go

@@ -21,6 +21,7 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
 var DisplayInCurrencyEnabled = true
 var DisplayTokenStatEnabled = true
 var DrawingEnabled = true
+var TaskEnabled = true
 var DataExportEnabled = true
 var DataExportInterval = 5         // unit: minute
 var DataExportDefaultTime = "hour" // unit: minute
@@ -208,8 +209,10 @@ const (
 	ChannelTypeAws            = 33
 	ChannelTypeCohere         = 34
 	ChannelTypeMiniMax        = 35
+	ChannelTypeSunoAPI        = 36
 
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
+
 )
 
 var ChannelBaseURLs = []string{
@@ -249,4 +252,5 @@ var ChannelBaseURLs = []string{
 	"",                                          //33
 	"https://api.cohere.ai",                     //34
 	"https://api.minimax.chat",                  //35
+	"",                                          //36
 }

+ 18 - 0
constant/task.go

@@ -0,0 +1,18 @@
+package constant
+
+type TaskPlatform string
+
+const (
+	TaskPlatformSuno       TaskPlatform = "suno"
+	TaskPlatformMidjourney              = "mj"
+)
+
+const (
+	SunoActionMusic  = "MUSIC"
+	SunoActionLyrics = "LYRICS"
+)
+
+var SunoModel2Action = map[string]string{
+	"suno_music":  SunoActionMusic,
+	"suno_lyrics": SunoActionLyrics,
+}

+ 3 - 0
controller/channel-test.go

@@ -27,6 +27,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 	if channel.Type == common.ChannelTypeMidjourney {
 		return errors.New("midjourney channel test is not supported"), nil
 	}
+	if channel.Type == common.ChannelTypeSunoAPI {
+		return errors.New("suno channel test is not supported"), nil
+	}
 	w := httptest.NewRecorder()
 	c, _ := gin.CreateTestContext(w)
 	c.Request = &http.Request{

+ 1 - 0
controller/misc.go

@@ -57,6 +57,7 @@ func GetStatus(c *gin.Context) {
 			"display_in_currency":      common.DisplayInCurrencyEnabled,
 			"enable_batch_update":      common.BatchUpdateEnabled,
 			"enable_drawing":           common.DrawingEnabled,
+			"enable_task":              common.TaskEnabled,
 			"enable_data_export":       common.DataExportEnabled,
 			"data_export_default_time": common.DataExportDefaultTime,
 			"default_collapse_sidebar": common.DefaultCollapseSidebar,

+ 91 - 0
controller/relay.go

@@ -190,3 +190,94 @@ func RelayNotFound(c *gin.Context) {
 		"error": err,
 	})
 }
+
+func RelayTask(c *gin.Context) {
+	retryTimes := common.RetryTimes
+	channelId := c.GetInt("channel_id")
+	relayMode := c.GetInt("relay_mode")
+	group := c.GetString("group")
+	originalModel := c.GetString("original_model")
+	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
+	taskErr := taskRelayHandler(c, relayMode)
+	if taskErr == nil {
+		retryTimes = 0
+	}
+	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
+		channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+		if err != nil {
+			common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+			break
+		}
+		channelId = channel.Id
+		useChannel := c.GetStringSlice("use_channel")
+		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
+		c.Set("use_channel", useChannel)
+		common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+		middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+
+		requestBody, err := common.GetRequestBody(c)
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+		taskErr = taskRelayHandler(c, relayMode)
+	}
+	useChannel := c.GetStringSlice("use_channel")
+	if len(useChannel) > 1 {
+		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+		common.LogInfo(c.Request.Context(), retryLogStr)
+	}
+	if taskErr != nil {
+		if taskErr.StatusCode == http.StatusTooManyRequests {
+			taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
+		}
+		c.JSON(taskErr.StatusCode, taskErr)
+	}
+}
+
+func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
+	var err *dto.TaskError
+	switch relayMode {
+	case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
+		err = relay.RelayTaskFetch(c, relayMode)
+	default:
+		err = relay.RelayTaskSubmit(c, relayMode)
+	}
+	return err
+}
+
+func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
+	if taskErr == nil {
+		return false
+	}
+	if retryTimes <= 0 {
+		return false
+	}
+	if _, ok := c.Get("specific_channel_id"); ok {
+		return false
+	}
+	if taskErr.StatusCode == http.StatusTooManyRequests {
+		return true
+	}
+	if taskErr.StatusCode == 307 {
+		return true
+	}
+	if taskErr.StatusCode/100 == 5 {
+		// 超时不重试
+		if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
+			return false
+		}
+		return true
+	}
+	if taskErr.StatusCode == http.StatusBadRequest {
+		return false
+	}
+	if taskErr.StatusCode == 408 {
+		// azure处理超时不重试
+		return false
+	}
+	if taskErr.LocalError {
+		return false
+	}
+	if taskErr.StatusCode/100 == 2 {
+		return false
+	}
+	return true
+}

+ 284 - 0
controller/task.go

@@ -0,0 +1,284 @@
+package controller
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"github.com/samber/lo"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
+	"one-api/model"
+	"one-api/relay"
+	"sort"
+	"strconv"
+	"time"
+)
+
+func UpdateTaskBulk() {
+	//revocer
+	//imageModel := "midjourney"
+	for {
+		time.Sleep(time.Duration(15) * time.Second)
+		common.SysLog("任务进度轮询开始")
+		ctx := context.TODO()
+		allTasks := model.GetAllUnFinishSyncTasks(500)
+		platformTask := make(map[constant.TaskPlatform][]*model.Task)
+		for _, t := range allTasks {
+			platformTask[t.Platform] = append(platformTask[t.Platform], t)
+		}
+		for platform, tasks := range platformTask {
+			if len(tasks) == 0 {
+				continue
+			}
+			taskChannelM := make(map[int][]string)
+			taskM := make(map[string]*model.Task)
+			nullTaskIds := make([]int64, 0)
+			for _, task := range tasks {
+				if task.TaskID == "" {
+					// 统计失败的未完成任务
+					nullTaskIds = append(nullTaskIds, task.ID)
+					continue
+				}
+				taskM[task.TaskID] = task
+				taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
+			}
+			if len(nullTaskIds) > 0 {
+				err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
+					"status":   "FAILURE",
+					"progress": "100%",
+				})
+				if err != nil {
+					common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
+				} else {
+					common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
+				}
+			}
+			if len(taskChannelM) == 0 {
+				continue
+			}
+
+			UpdateTaskByPlatform(platform, taskChannelM, taskM)
+		}
+		common.SysLog("任务进度轮询完成")
+	}
+}
+
+func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
+	switch platform {
+	case constant.TaskPlatformMidjourney:
+		//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
+	case constant.TaskPlatformSuno:
+		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
+	default:
+		common.SysLog("未知平台")
+	}
+}
+
+func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+	for channelId, taskIds := range taskChannelM {
+		err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
+		if err != nil {
+			common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	channel, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
+		err = model.TaskBulkUpdate(taskIds, map[string]any{
+			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		})
+		if err != nil {
+			common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
+		}
+		return err
+	}
+	adaptor := relay.GetTaskAdaptor(constant.TaskPlatformSuno)
+	if adaptor == nil {
+		return errors.New("adaptor not found")
+	}
+	resp, err := adaptor.FetchTask(*channel.BaseURL, channel.Key, map[string]any{
+		"ids": taskIds,
+	})
+	if err != nil {
+		common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
+		return err
+	}
+	if resp.StatusCode != http.StatusOK {
+		common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+		return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+	}
+	defer resp.Body.Close()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
+		return err
+	}
+	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
+	err = json.Unmarshal(responseBody, &responseItems)
+	if err != nil {
+		common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
+		return err
+	}
+	if !responseItems.IsSuccess() {
+		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
+		return err
+	}
+
+	for _, responseItem := range responseItems.Data {
+		task := taskM[responseItem.TaskID]
+		if !checkTaskNeedUpdate(task, responseItem) {
+			continue
+		}
+
+		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
+		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
+		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
+		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
+		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
+		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
+			common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
+			task.Progress = "100%"
+			err = model.CacheUpdateUserQuota(task.UserId)
+			if err != nil {
+				common.LogError(ctx, "error update user quota cache: "+err.Error())
+			} else {
+				quota := task.Quota
+				if quota != 0 {
+					err = model.IncreaseUserQuota(task.UserId, quota)
+					if err != nil {
+						common.LogError(ctx, "fail to increase user quota: "+err.Error())
+					}
+					logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
+					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+				}
+			}
+		}
+		if responseItem.Status == model.TaskStatusSuccess {
+			task.Progress = "100%"
+		}
+		task.Data = responseItem.Data
+
+		err = task.Update()
+		if err != nil {
+			common.SysError("UpdateMidjourneyTask task error: " + err.Error())
+		}
+	}
+	return nil
+}
+
+func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
+
+	if oldTask.SubmitTime != newTask.SubmitTime {
+		return true
+	}
+	if oldTask.StartTime != newTask.StartTime {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+	if string(oldTask.Status) != newTask.Status {
+		return true
+	}
+	if oldTask.FailReason != newTask.FailReason {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+
+	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
+		return true
+	}
+
+	oldData, _ := json.Marshal(oldTask.Data)
+	newData, _ := json.Marshal(newTask.Data)
+
+	sort.Slice(oldData, func(i, j int) bool {
+		return oldData[i] < oldData[j]
+	})
+	sort.Slice(newData, func(i, j int) bool {
+		return newData[i] < newData[j]
+	})
+
+	if string(oldData) != string(newData) {
+		return true
+	}
+	return false
+}
+
+func GetAllTask(c *gin.Context) {
+	p, _ := strconv.Atoi(c.Query("p"))
+	if p < 0 {
+		p = 0
+	}
+	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+	// 解析其他查询参数
+	queryParams := model.SyncTaskQueryParams{
+		Platform:       constant.TaskPlatform(c.Query("platform")),
+		TaskID:         c.Query("task_id"),
+		Status:         c.Query("status"),
+		Action:         c.Query("action"),
+		StartTimestamp: startTimestamp,
+		EndTimestamp:   endTimestamp,
+	}
+
+	logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
+	if logs == nil {
+		logs = make([]*model.Task, 0)
+	}
+
+	c.JSON(200, gin.H{
+		"success": true,
+		"message": "",
+		"data":    logs,
+	})
+}
+
+func GetUserTask(c *gin.Context) {
+	p, _ := strconv.Atoi(c.Query("p"))
+	if p < 0 {
+		p = 0
+	}
+
+	userId := c.GetInt("id")
+
+	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+
+	queryParams := model.SyncTaskQueryParams{
+		Platform:       constant.TaskPlatform(c.Query("platform")),
+		TaskID:         c.Query("task_id"),
+		Status:         c.Query("status"),
+		Action:         c.Query("action"),
+		StartTimestamp: startTimestamp,
+		EndTimestamp:   endTimestamp,
+	}
+
+	logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
+	if logs == nil {
+		logs = make([]*model.Task, 0)
+	}
+
+	c.JSON(200, gin.H{
+		"success": true,
+		"message": "",
+		"data":    logs,
+	})
+}

+ 129 - 0
dto/suno.go

@@ -0,0 +1,129 @@
+package dto
+
+import (
+	"encoding/json"
+)
+
+type TaskData interface {
+	SunoDataResponse | []SunoDataResponse | string | any
+}
+
+type SunoSubmitReq struct {
+	GptDescriptionPrompt string  `json:"gpt_description_prompt,omitempty"`
+	Prompt               string  `json:"prompt,omitempty"`
+	Mv                   string  `json:"mv,omitempty"`
+	Title                string  `json:"title,omitempty"`
+	Tags                 string  `json:"tags,omitempty"`
+	ContinueAt           float64 `json:"continue_at,omitempty"`
+	TaskID               string  `json:"task_id,omitempty"`
+	ContinueClipId       string  `json:"continue_clip_id,omitempty"`
+	MakeInstrumental     bool    `json:"make_instrumental"`
+}
+
+type FetchReq struct {
+	IDs []string `json:"ids"`
+}
+
+type SunoDataResponse struct {
+	TaskID     string          `json:"task_id" gorm:"type:varchar(50);index"`
+	Action     string          `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
+	Status     string          `json:"status" gorm:"type:varchar(20);index"` // 任务状态, submitted, queueing, processing, success, failed
+	FailReason string          `json:"fail_reason"`
+	SubmitTime int64           `json:"submit_time" gorm:"index"`
+	StartTime  int64           `json:"start_time" gorm:"index"`
+	FinishTime int64           `json:"finish_time" gorm:"index"`
+	Data       json.RawMessage `json:"data" gorm:"type:json"`
+}
+
+type SunoSong struct {
+	ID                string       `json:"id"`
+	VideoURL          string       `json:"video_url"`
+	AudioURL          string       `json:"audio_url"`
+	ImageURL          string       `json:"image_url"`
+	ImageLargeURL     string       `json:"image_large_url"`
+	MajorModelVersion string       `json:"major_model_version"`
+	ModelName         string       `json:"model_name"`
+	Status            string       `json:"status"`
+	Title             string       `json:"title"`
+	Text              string       `json:"text"`
+	Metadata          SunoMetadata `json:"metadata"`
+}
+
+type SunoMetadata struct {
+	Tags                 string      `json:"tags"`
+	Prompt               string      `json:"prompt"`
+	GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"`
+	AudioPromptID        interface{} `json:"audio_prompt_id"`
+	Duration             interface{} `json:"duration"`
+	ErrorType            interface{} `json:"error_type"`
+	ErrorMessage         interface{} `json:"error_message"`
+}
+
+type SunoLyrics struct {
+	ID     string `json:"id"`
+	Status string `json:"status"`
+	Title  string `json:"title"`
+	Text   string `json:"text"`
+}
+
+const TaskSuccessCode = "success"
+
+type TaskResponse[T TaskData] struct {
+	Code    string `json:"code"`
+	Message string `json:"message"`
+	Data    T      `json:"data"`
+}
+
+func (t *TaskResponse[T]) IsSuccess() bool {
+	return t.Code == TaskSuccessCode
+}
+
+type TaskDto struct {
+	TaskID     string          `json:"task_id"` // 第三方id,不一定有/ song id\ Task id
+	Action     string          `json:"action"`  // 任务类型, song, lyrics, description-mode
+	Status     string          `json:"status"`  // 任务状态, submitted, queueing, processing, success, failed
+	FailReason string          `json:"fail_reason"`
+	SubmitTime int64           `json:"submit_time"`
+	StartTime  int64           `json:"start_time"`
+	FinishTime int64           `json:"finish_time"`
+	Progress   string          `json:"progress"`
+	Data       json.RawMessage `json:"data"`
+}
+
+type SunoGoAPISubmitReq struct {
+	CustomMode bool `json:"custom_mode"`
+
+	Input SunoGoAPISubmitReqInput `json:"input"`
+
+	NotifyHook string `json:"notify_hook,omitempty"`
+}
+
+type SunoGoAPISubmitReqInput struct {
+	GptDescriptionPrompt string  `json:"gpt_description_prompt"`
+	Prompt               string  `json:"prompt"`
+	Mv                   string  `json:"mv"`
+	Title                string  `json:"title"`
+	Tags                 string  `json:"tags"`
+	ContinueAt           float64 `json:"continue_at"`
+	TaskID               string  `json:"task_id"`
+	ContinueClipId       string  `json:"continue_clip_id"`
+	MakeInstrumental     bool    `json:"make_instrumental"`
+}
+
+type GoAPITaskResponse[T any] struct {
+	Code         int    `json:"code"`
+	Message      string `json:"message"`
+	Data         T      `json:"data"`
+	ErrorMessage string `json:"error_message,omitempty"`
+}
+
+type GoAPITaskResponseData struct {
+	TaskID string `json:"task_id"`
+}
+
+type GoAPIFetchResponseData struct {
+	TaskID string              `json:"task_id"`
+	Status string              `json:"status"`
+	Input  string              `json:"input"`
+	Clips  map[string]SunoSong `json:"clips"`
+}

+ 10 - 0
dto/task.go

@@ -0,0 +1,10 @@
+package dto
+
+type TaskError struct {
+	Code       string `json:"code"`
+	Message    string `json:"message"`
+	Data       any    `json:"data"`
+	StatusCode int    `json:"-"`
+	LocalError bool   `json:"-"`
+	Error      error  `json:"-"`
+}

+ 5 - 2
main.go

@@ -20,10 +20,10 @@ import (
 	_ "net/http/pprof"
 )
 
-//go:embed web/dist
+// /go:embed web/dist
 var buildFS embed.FS
 
-//go:embed web/dist/index.html
+// /go:embed web/dist/index.html
 var indexPage []byte
 
 func main() {
@@ -92,6 +92,9 @@ func main() {
 	common.SafeGoroutine(func() {
 		controller.UpdateMidjourneyTaskBulk()
 	})
+	common.SafeGoroutine(func() {
+		controller.UpdateTaskBulk()
+	})
 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
 		common.BatchUpdateEnabled = true
 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")

+ 11 - 0
middleware/distributor.go

@@ -125,6 +125,17 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 			modelRequest.Model = midjourneyModel
 		}
 		c.Set("relay_mode", relayMode)
+	} else if strings.Contains(c.Request.URL.Path, "/suno/") {
+		relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path)
+		if relayMode == relayconstant.RelayModeSunoFetch ||
+			relayMode == relayconstant.RelayModeSunoFetchByID {
+			shouldSelectChannel = false
+		} else {
+			modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
+			modelRequest.Model = modelName
+		}
+		c.Set("platform", string(constant.TaskPlatformSuno))
+		c.Set("relay_mode", relayMode)
 	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 		err = common.UnmarshalBodyReusable(c, &modelRequest)
 	}

+ 4 - 0
model/main.go

@@ -140,6 +140,10 @@ func InitDB() (err error) {
 		if err != nil {
 			return err
 		}
+		err = db.AutoMigrate(&Task{})
+		if err != nil {
+			return err
+		}
 		common.SysLog("database migrated")
 		err = createRootAccountIfNeed()
 		return err

+ 3 - 0
model/option.go

@@ -41,6 +41,7 @@ func InitOptionMap() {
 	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
 	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
 	common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
+	common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled)
 	common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
 	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
 	common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
@@ -195,6 +196,8 @@ func updateOptionMap(key string, value string) (err error) {
 			common.DisplayTokenStatEnabled = boolValue
 		case "DrawingEnabled":
 			common.DrawingEnabled = boolValue
+		case "TaskEnabled":
+			common.TaskEnabled = boolValue
 		case "DataExportEnabled":
 			common.DataExportEnabled = boolValue
 		case "DefaultCollapseSidebar":

+ 304 - 0
model/task.go

@@ -0,0 +1,304 @@
+package model
+
+import (
+	"database/sql/driver"
+	"encoding/json"
+	"one-api/constant"
+	commonRelay "one-api/relay/common"
+	"time"
+)
+
+type TaskStatus string
+
+const (
+	TaskStatusNotStart   TaskStatus = "NOT_START"
+	TaskStatusSubmitted             = "SUBMITTED"
+	TaskStatusQueued                = "QUEUED"
+	TaskStatusInProgress            = "IN_PROGRESS"
+	TaskStatusFailure               = "FAILURE"
+	TaskStatusSuccess               = "SUCCESS"
+	TaskStatusUnknown               = "UNKNOWN"
+)
+
+type Task struct {
+	ID         int64                 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
+	CreatedAt  int64                 `json:"created_at" gorm:"index"`
+	UpdatedAt  int64                 `json:"updated_at"`
+	TaskID     string                `json:"task_id" gorm:"type:varchar(50);index"`  // 第三方id,不一定有/ song id\ Task id
+	Platform   constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
+	UserId     int                   `json:"user_id" gorm:"index"`
+	ChannelId  int                   `json:"channel_id" gorm:"index"`
+	Quota      int                   `json:"quota"`
+	Action     string                `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
+	Status     TaskStatus            `json:"status" gorm:"type:varchar(20);index"` // 任务状态
+	FailReason string                `json:"fail_reason"`
+	SubmitTime int64                 `json:"submit_time" gorm:"index"`
+	StartTime  int64                 `json:"start_time" gorm:"index"`
+	FinishTime int64                 `json:"finish_time" gorm:"index"`
+	Progress   string                `json:"progress" gorm:"type:varchar(20);index"`
+	Properties Properties            `json:"properties" gorm:"type:json"`
+
+	Data json.RawMessage `json:"data" gorm:"type:json"`
+}
+
+func (t *Task) SetData(data any) {
+	b, _ := json.Marshal(data)
+	t.Data = json.RawMessage(b)
+}
+
+func (t *Task) GetData(v any) error {
+	err := json.Unmarshal(t.Data, &v)
+	return err
+}
+
+type Properties struct {
+	Input string `json:"input"`
+}
+
+func (m *Properties) Scan(val interface{}) error {
+	bytesValue, _ := val.([]byte)
+	return json.Unmarshal(bytesValue, m)
+}
+
+func (m Properties) Value() (driver.Value, error) {
+	return json.Marshal(m)
+}
+
+// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
+type SyncTaskQueryParams struct {
+	Platform       constant.TaskPlatform
+	ChannelID      string
+	TaskID         string
+	UserID         string
+	Action         string
+	Status         string
+	StartTimestamp int64
+	EndTimestamp   int64
+	UserIDs        []int
+}
+
+func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
+	t := &Task{
+		UserId:     relayInfo.UserId,
+		SubmitTime: time.Now().Unix(),
+		Status:     TaskStatusNotStart,
+		Progress:   "0%",
+		ChannelId:  relayInfo.ChannelId,
+		Platform:   platform,
+	}
+	return t
+}
+
+func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
+	var tasks []*Task
+	var err error
+
+	// 初始化查询构建器
+	query := DB.Where("user_id = ?", userId)
+
+	if queryParams.TaskID != "" {
+		query = query.Where("task_id = ?", queryParams.TaskID)
+	}
+	if queryParams.Action != "" {
+		query = query.Where("action = ?", queryParams.Action)
+	}
+	if queryParams.Status != "" {
+		query = query.Where("status = ?", queryParams.Status)
+	}
+	if queryParams.Platform != "" {
+		query = query.Where("platform = ?", queryParams.Platform)
+	}
+	if queryParams.StartTimestamp != 0 {
+		// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
+		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+	}
+	if queryParams.EndTimestamp != 0 {
+		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+	}
+
+	// 获取数据
+	err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
+	if err != nil {
+		return nil
+	}
+
+	return tasks
+}
+
+func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
+	var tasks []*Task
+	var err error
+
+	// 初始化查询构建器
+	query := DB
+
+	// 添加过滤条件
+	if queryParams.ChannelID != "" {
+		query = query.Where("channel_id = ?", queryParams.ChannelID)
+	}
+	if queryParams.Platform != "" {
+		query = query.Where("platform = ?", queryParams.Platform)
+	}
+	if queryParams.UserID != "" {
+		query = query.Where("user_id = ?", queryParams.UserID)
+	}
+	if len(queryParams.UserIDs) != 0 {
+		query = query.Where("user_id in (?)", queryParams.UserIDs)
+	}
+	if queryParams.TaskID != "" {
+		query = query.Where("task_id = ?", queryParams.TaskID)
+	}
+	if queryParams.Action != "" {
+		query = query.Where("action = ?", queryParams.Action)
+	}
+	if queryParams.Status != "" {
+		query = query.Where("status = ?", queryParams.Status)
+	}
+	if queryParams.StartTimestamp != 0 {
+		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+	}
+	if queryParams.EndTimestamp != 0 {
+		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+	}
+
+	// 获取数据
+	err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
+	if err != nil {
+		return nil
+	}
+
+	return tasks
+}
+
+func GetAllUnFinishSyncTasks(limit int) []*Task {
+	var tasks []*Task
+	var err error
+	// get all tasks progress is not 100%
+	err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error
+	if err != nil {
+		return nil
+	}
+	return tasks
+}
+
+func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
+	if taskId == "" {
+		return nil, false, nil
+	}
+	var task *Task
+	var err error
+	err = DB.Where("task_id = ?", taskId).First(&task).Error
+	exist, err := RecordExist(err)
+	if err != nil {
+		return nil, false, err
+	}
+	return task, exist, err
+}
+
+func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
+	if taskId == "" {
+		return nil, false, nil
+	}
+	var task *Task
+	var err error
+	err = DB.Where("user_id = ? and task_id = ?", userId, taskId).
+		First(&task).Error
+	exist, err := RecordExist(err)
+	if err != nil {
+		return nil, false, err
+	}
+	return task, exist, err
+}
+
+func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
+	if len(taskIds) == 0 {
+		return nil, nil
+	}
+	var task []*Task
+	var err error
+	err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
+		Find(&task).Error
+	if err != nil {
+		return nil, err
+	}
+	return task, nil
+}
+
+func TaskUpdateProgress(id int64, progress string) error {
+	return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
+}
+
+func (Task *Task) Insert() error {
+	var err error
+	err = DB.Create(Task).Error
+	return err
+}
+
+func (Task *Task) Update() error {
+	var err error
+	err = DB.Save(Task).Error
+	return err
+}
+
+func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
+	if len(TaskIds) == 0 {
+		return nil
+	}
+	return DB.Model(&Task{}).
+		Where("task_id in (?)", TaskIds).
+		Updates(params).Error
+}
+
+func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
+	if len(taskIDs) == 0 {
+		return nil
+	}
+	return DB.Model(&Task{}).
+		Where("id in (?)", taskIDs).
+		Updates(params).Error
+}
+
+func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
+	if len(ids) == 0 {
+		return nil
+	}
+	return DB.Model(&Task{}).
+		Where("id in (?)", ids).
+		Updates(params).Error
+}
+
+type TaskQuotaUsage struct {
+	Mode  string  `json:"mode"`
+	Count float64 `json:"count"`
+}
+
+func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
+	query := DB.Model(Task{})
+	// 添加过滤条件
+	if queryParams.ChannelID != "" {
+		query = query.Where("channel_id = ?", queryParams.ChannelID)
+	}
+	if queryParams.UserID != "" {
+		query = query.Where("user_id = ?", queryParams.UserID)
+	}
+	if len(queryParams.UserIDs) != 0 {
+		query = query.Where("user_id in (?)", queryParams.UserIDs)
+	}
+	if queryParams.TaskID != "" {
+		query = query.Where("task_id = ?", queryParams.TaskID)
+	}
+	if queryParams.Action != "" {
+		query = query.Where("action = ?", queryParams.Action)
+	}
+	if queryParams.Status != "" {
+		query = query.Where("status = ?", queryParams.Status)
+	}
+	if queryParams.StartTimestamp != 0 {
+		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+	}
+	if queryParams.EndTimestamp != 0 {
+		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+	}
+	err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
+	return stat, err
+}

+ 12 - 0
model/utils.go

@@ -1,6 +1,8 @@
 package model
 
 import (
+	"errors"
+	"gorm.io/gorm"
 	"one-api/common"
 	"sync"
 	"time"
@@ -75,3 +77,13 @@ func batchUpdate() {
 	}
 	common.SysLog("batch update finished")
 }
+
+func RecordExist(err error) (bool, error) {
+	if err == nil {
+		return true, nil
+	}
+	if errors.Is(err, gorm.ErrRecordNotFound) {
+		return false, nil
+	}
+	return false, err
+}

+ 19 - 0
relay/channel/adapter.go

@@ -19,3 +19,22 @@ type Adaptor interface {
 	GetModelList() []string
 	GetChannelName() string
 }
+
+type TaskAdaptor interface {
+	Init(info *relaycommon.TaskRelayInfo)
+
+	ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
+
+	BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
+	BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
+	BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
+
+	DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
+	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
+
+	GetModelList() []string
+	GetChannelName() string
+
+	// FetchTask
+	FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
+}

+ 24 - 0
relay/channel/api_request.go

@@ -50,3 +50,27 @@ func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
 	_ = c.Request.Body.Close()
 	return resp, nil
 }
+
+func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+	fullRequestURL, err := a.BuildRequestURL(info)
+	if err != nil {
+		return nil, err
+	}
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return nil, fmt.Errorf("new request failed: %w", err)
+	}
+	req.GetBody = func() (io.ReadCloser, error) {
+		return io.NopCloser(requestBody), nil
+	}
+
+	err = a.BuildRequestHeader(c, req, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	resp, err := doRequest(c, req)
+	if err != nil {
+		return nil, fmt.Errorf("do request failed: %w", err)
+	}
+	return resp, nil
+}

+ 172 - 0
relay/channel/task/suno/adaptor.go

@@ -0,0 +1,172 @@
+package suno
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+	"time"
+)
+
+type TaskAdaptor struct {
+	ChannelType int
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+	a.ChannelType = info.ChannelType
+}
+
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
+	action := strings.ToUpper(c.Param("action"))
+
+	var sunoRequest *dto.SunoSubmitReq
+	err := common.UnmarshalBodyReusable(c, &sunoRequest)
+	if err != nil {
+		taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+		return
+	}
+	err = actionValidate(c, sunoRequest, action)
+	if err != nil {
+		taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+		return
+	}
+
+	if sunoRequest.ContinueClipId != "" {
+		if sunoRequest.TaskID == "" {
+			taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
+			return
+		}
+		info.OriginTaskID = sunoRequest.TaskID
+	}
+
+	info.Action = action
+	c.Set("task_request", sunoRequest)
+	return nil
+}
+
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+	baseURL := info.BaseUrl
+	fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
+	return fullRequestURL, nil
+}
+
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}
+
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
+	sunoRequest, ok := c.Get("task_request")
+	if !ok {
+		err := common.UnmarshalBodyReusable(c, &sunoRequest)
+		if err != nil {
+			return nil, err
+		}
+	}
+	data, err := json.Marshal(sunoRequest)
+	if err != nil {
+		return nil, err
+	}
+	return bytes.NewReader(data), nil
+}
+
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+	var sunoResponse dto.TaskResponse[string]
+	err = json.Unmarshal(responseBody, &sunoResponse)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+	if !sunoResponse.IsSuccess() {
+		taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
+		return
+	}
+
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+
+	_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+
+	return sunoResponse.Data, nil, nil
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+	return ChannelName
+}
+
+func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
+	requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
+	byteBody, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
+	if err != nil {
+		common.SysError(fmt.Sprintf("Get Task error: %v", err))
+		return nil, err
+	}
+	defer req.Body.Close()
+	// 设置超时时间
+	timeout := time.Second * 15
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+	// 使用带有超时的 context 创建新的请求
+	req = req.WithContext(ctx)
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+key)
+	resp, err := service.GetHttpClient().Do(req)
+	if err != nil {
+		return nil, err
+	}
+	return resp, nil
+}
+
+func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
+	switch action {
+	case constant.SunoActionMusic:
+		if sunoRequest.Mv == "" {
+			sunoRequest.Mv = "chirp-v3-0"
+		}
+	case constant.SunoActionLyrics:
+		if sunoRequest.Prompt == "" {
+			err = fmt.Errorf("prompt_empty")
+			return
+		}
+	default:
+		err = fmt.Errorf("invalid_action")
+	}
+	return
+}

+ 7 - 0
relay/channel/task/suno/models.go

@@ -0,0 +1,7 @@
+package suno
+
+var ModelList = []string{
+	"suno_music", "suno_lyrics",
+}
+
+var ChannelName = "suno"

+ 50 - 0
relay/common/relay_info.go

@@ -72,3 +72,53 @@ func (info *RelayInfo) SetPromptTokens(promptTokens int) {
 func (info *RelayInfo) SetIsStream(isStream bool) {
 	info.IsStream = isStream
 }
+
+type TaskRelayInfo struct {
+	ChannelType       int
+	ChannelId         int
+	TokenId           int
+	UserId            int
+	Group             string
+	StartTime         time.Time
+	ApiType           int
+	RelayMode         int
+	UpstreamModelName string
+	RequestURLPath    string
+	ApiKey            string
+	BaseUrl           string
+
+	Action       string
+	OriginTaskID string
+
+	ConsumeQuota bool
+}
+
+func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
+	channelType := c.GetInt("channel")
+	channelId := c.GetInt("channel_id")
+
+	tokenId := c.GetInt("token_id")
+	userId := c.GetInt("id")
+	group := c.GetString("group")
+	startTime := time.Now()
+
+	apiType, _ := constant.ChannelType2APIType(channelType)
+
+	info := &TaskRelayInfo{
+		RelayMode:      constant.Path2RelayMode(c.Request.URL.Path),
+		BaseUrl:        c.GetString("base_url"),
+		RequestURLPath: c.Request.URL.String(),
+		ChannelType:    channelType,
+		ChannelId:      channelId,
+		TokenId:        tokenId,
+		UserId:         userId,
+		Group:          group,
+		StartTime:      startTime,
+		ApiType:        apiType,
+		ApiKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+	}
+	if info.BaseUrl == "" {
+		info.BaseUrl = common.ChannelBaseURLs[channelType]
+	}
+	return info
+}

+ 19 - 1
relay/constant/relay_mode.go

@@ -1,6 +1,9 @@
 package constant
 
-import "strings"
+import (
+	"net/http"
+	"strings"
+)
 
 const (
 	RelayModeUnknown = iota
@@ -26,6 +29,9 @@ const (
 	RelayModeMidjourneyModal
 	RelayModeMidjourneyShorten
 	RelayModeSwapFace
+	RelayModeSunoFetch
+	RelayModeSunoFetchByID
+	RelayModeSunoSubmit
 )
 
 func Path2RelayMode(path string) int {
@@ -89,3 +95,15 @@ func Path2RelayModeMidjourney(path string) int {
 	}
 	return relayMode
 }
+
+func Path2RelaySuno(method, path string) int {
+	relayMode := RelayModeUnknown
+	if method == http.MethodPost && strings.HasSuffix(path, "/fetch") {
+		relayMode = RelayModeSunoFetch
+	} else if method == http.MethodGet && strings.Contains(path, "/fetch/") {
+		relayMode = RelayModeSunoFetchByID
+	} else if strings.Contains(path, "/submit/") {
+		relayMode = RelayModeSunoSubmit
+	}
+	return relayMode
+}

+ 12 - 0
relay/relay_adaptor.go

@@ -1,6 +1,7 @@
 package relay
 
 import (
+	commonconstant "one-api/constant"
 	"one-api/relay/channel"
 	"one-api/relay/channel/ali"
 	"one-api/relay/channel/aws"
@@ -12,6 +13,7 @@ import (
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/palm"
 	"one-api/relay/channel/perplexity"
+	"one-api/relay/channel/task/suno"
 	"one-api/relay/channel/tencent"
 	"one-api/relay/channel/xunfei"
 	"one-api/relay/channel/zhipu"
@@ -54,3 +56,13 @@ func GetAdaptor(apiType int) channel.Adaptor {
 	}
 	return nil
 }
+
+func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
+	switch platform {
+	//case constant.APITypeAIProxyLibrary:
+	//	return &aiproxy.Adaptor{}
+	case commonconstant.TaskPlatformSuno:
+		return &suno.TaskAdaptor{}
+	}
+	return nil
+}

+ 242 - 0
relay/relay_task.go

@@ -0,0 +1,242 @@
+package relay
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
+	"one-api/model"
+	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
+)
+
+/*
+Task 任务通过平台、Action 区分任务
+*/
+func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
+	platform := constant.TaskPlatform(c.GetString("platform"))
+	relayInfo := relaycommon.GenTaskRelayInfo(c)
+
+	adaptor := GetTaskAdaptor(platform)
+	if adaptor == nil {
+		return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
+	}
+	adaptor.Init(relayInfo)
+	// get & validate taskRequest 获取并验证文本请求
+	taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
+	if taskErr != nil {
+		return
+	}
+
+	modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
+	modelPrice, success := common.GetModelPrice(modelName, true)
+	if !success {
+		defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
+		if !ok {
+			modelPrice = 0.1
+		} else {
+			modelPrice = defaultPrice
+		}
+	}
+
+	// 预扣
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	ratio := modelPrice * groupRatio
+	userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+		return
+	}
+	quota := int(ratio * common.QuotaPerUnit)
+	if userQuota-quota < 0 {
+		taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
+		return
+	}
+
+	if relayInfo.OriginTaskID != "" {
+		originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
+		if err != nil {
+			taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
+			return
+		}
+		if !exist {
+			taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
+			return
+		}
+		if originTask.ChannelId != relayInfo.ChannelId {
+			channel, err := model.GetChannelById(originTask.ChannelId, true)
+			if err != nil {
+				taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
+				return
+			}
+			if channel.Status != common.ChannelStatusEnabled {
+				return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
+			}
+			c.Set("base_url", channel.GetBaseURL())
+			c.Set("channel_id", originTask.ChannelId)
+			c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+
+			relayInfo.BaseUrl = channel.GetBaseURL()
+			relayInfo.ChannelId = originTask.ChannelId
+		}
+	}
+
+	// build body
+	requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
+		return
+	}
+	// do request
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+		return
+	}
+	// handle response
+	if resp != nil && resp.StatusCode != http.StatusOK {
+		responseBody, _ := io.ReadAll(resp.Body)
+		taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
+		return
+	}
+
+	defer func(ctx context.Context) {
+		// release quota
+		if relayInfo.ConsumeQuota && taskErr == nil {
+			err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true)
+			if err != nil {
+				common.SysError("error consuming token remain quota: " + err.Error())
+			}
+			err = model.CacheUpdateUserQuota(relayInfo.UserId)
+			if err != nil {
+				common.SysError("error update user quota cache: " + err.Error())
+			}
+			if quota != 0 {
+				tokenName := c.GetString("token_name")
+				logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
+				other := make(map[string]interface{})
+				other["model_price"] = modelPrice
+				other["group_ratio"] = groupRatio
+				model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other)
+				model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
+				model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
+			}
+		}
+	}(c.Request.Context())
+
+	taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
+	if taskErr != nil {
+		return
+	}
+	relayInfo.ConsumeQuota = true
+	// insert task
+	task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
+	task.TaskID = taskID
+	task.Quota = quota
+	task.Data = taskData
+	err = task.Insert()
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
+		return
+	}
+	return nil
+}
+
+var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
+	relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
+	relayconstant.RelayModeSunoFetch:     sunoFetchRespBodyBuilder,
+}
+
+func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
+	respBuilder, ok := fetchRespBuilders[relayMode]
+	if !ok {
+		taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
+	}
+
+	respBody, taskErr := respBuilder(c)
+	if taskErr != nil {
+		return taskErr
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
+	if err != nil {
+		taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+	return
+}
+
+func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
+	userId := c.GetInt("id")
+	var condition = struct {
+		IDs    []any  `json:"ids"`
+		Action string `json:"action"`
+	}{}
+	err := c.BindJSON(&condition)
+	if err != nil {
+		taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
+		return
+	}
+	var tasks []any
+	if len(condition.IDs) > 0 {
+		taskModels, err := model.GetByTaskIds(userId, condition.IDs)
+		if err != nil {
+			taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
+			return
+		}
+		for _, task := range taskModels {
+			tasks = append(tasks, TaskModel2Dto(task))
+		}
+	} else {
+		tasks = make([]any, 0)
+	}
+	respBody, err = json.Marshal(dto.TaskResponse[[]any]{
+		Code: "success",
+		Data: tasks,
+	})
+	return
+}
+
+func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
+	taskId := c.Param("id")
+	userId := c.GetInt("id")
+
+	originTask, exist, err := model.GetByTaskId(userId, taskId)
+	if err != nil {
+		taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
+		return
+	}
+	if !exist {
+		taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
+		return
+	}
+
+	respBody, err = json.Marshal(dto.TaskResponse[any]{
+		Code: "success",
+		Data: TaskModel2Dto(originTask),
+	})
+	return
+}
+
+func TaskModel2Dto(task *model.Task) *dto.TaskDto {
+	return &dto.TaskDto{
+		TaskID:     task.TaskID,
+		Action:     task.Action,
+		Status:     string(task.Status),
+		FailReason: task.FailReason,
+		SubmitTime: task.SubmitTime,
+		StartTime:  task.StartTime,
+		FinishTime: task.FinishTime,
+		Progress:   task.Progress,
+		Data:       task.Data,
+	}
+}

+ 6 - 0
router/api-router.go

@@ -140,5 +140,11 @@ func SetApiRouter(router *gin.Engine) {
 		mjRoute := apiRouter.Group("/mj")
 		mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
 		mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
+
+		taskRoute := apiRouter.Group("/task")
+		{
+			taskRoute.GET("/self", middleware.UserAuth(), controller.GetUserTask)
+			taskRoute.GET("/", middleware.AdminAuth(), controller.GetAllTask)
+		}
 	}
 }

+ 9 - 0
router/relay-router.go

@@ -50,6 +50,15 @@ func SetRelayRouter(router *gin.Engine) {
 	relayMjModeRouter := router.Group("/:mode/mj")
 	registerMjRouterGroup(relayMjModeRouter)
 	//relayMjRouter.Use()
+
+	relaySunoRouter := router.Group("/suno")
+	relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute())
+	{
+		relaySunoRouter.POST("/submit/:action", controller.RelayTask)
+		relaySunoRouter.POST("/fetch", controller.RelayTask)
+		relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
+	}
+
 }
 
 func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {

+ 26 - 0
service/error.go

@@ -105,3 +105,29 @@ func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMapping
 		openaiErr.StatusCode = intCode
 	}
 }
+
+func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError {
+	openaiErr := TaskErrorWrapper(err, code, statusCode)
+	openaiErr.LocalError = true
+	return openaiErr
+}
+
+func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
+	text := err.Error()
+
+	// 定义一个正则表达式匹配URL
+	if strings.Contains(text, "Post") || strings.Contains(text, "dial") {
+		common.SysLog(fmt.Sprintf("error: %s", text))
+		text = "请求上游地址失败"
+	}
+	//避免暴露内部错误
+
+	taskError := &dto.TaskError{
+		Code:       code,
+		Message:    text,
+		StatusCode: statusCode,
+		Error:      err,
+	}
+
+	return taskError
+}

+ 10 - 0
service/task.go

@@ -0,0 +1,10 @@
+package service
+
+import (
+	"one-api/constant"
+	"strings"
+)
+
+func CoverTaskActionToModelName(platform constant.TaskPlatform, action string) string {
+	return strings.ToLower(string(platform)) + "_" + strings.ToLower(action)
+}

+ 11 - 0
web/src/App.js

@@ -23,6 +23,7 @@ import Chat from './pages/Chat';
 import { Layout } from '@douyinfe/semi-ui';
 import Midjourney from './pages/Midjourney';
 import Pricing from './pages/Pricing/index.js';
+import Task from "./pages/Task/index.js";
 // import Detail from './pages/Detail';
 
 const Home = lazy(() => import('./pages/Home'));
@@ -220,6 +221,16 @@ function App() {
               </PrivateRoute>
             }
           />
+          <Route
+            path='/task'
+            element={
+                <PrivateRoute>
+                    <Suspense fallback={<Loading></Loading>}>
+                        <Task />
+                    </Suspense>
+                </PrivateRoute>
+            }
+          />
           <Route
             path='/pricing'
             element={

+ 13 - 1
web/src/components/SiderBar.js

@@ -14,7 +14,7 @@ import {
 import '../index.css';
 
 import {
-  IconCalendarClock,
+  IconCalendarClock, IconChecklistStroked,
   IconComment,
   IconCreditCard,
   IconGift,
@@ -58,6 +58,7 @@ const SiderBar = () => {
     chat: '/chat',
     detail: '/detail',
     pricing: '/pricing',
+    task: '/task',
   };
 
   const headerButtons = useMemo(
@@ -142,6 +143,16 @@ const SiderBar = () => {
             ? 'semi-navigation-item-normal'
             : 'tableHiddle',
       },
+      {
+        text: '异步任务',
+        itemKey: 'task',
+        to: '/task',
+        icon: <IconChecklistStroked />,
+        className:
+            localStorage.getItem('enable_task') === 'true'
+                ? 'semi-navigation-item-normal'
+                : 'tableHiddle',
+      },
       {
         text: '设置',
         itemKey: 'setting',
@@ -158,6 +169,7 @@ const SiderBar = () => {
     [
       localStorage.getItem('enable_data_export'),
       localStorage.getItem('enable_drawing'),
+      localStorage.getItem('enable_task'),
       localStorage.getItem('chat_link'),
       isAdmin(),
     ],

+ 400 - 0
web/src/components/TaskLogsTable.js

@@ -0,0 +1,400 @@
+import React, { useEffect, useState } from 'react';
+import { Label } from 'semantic-ui-react';
+import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers';
+
+import {
+    Table,
+    Tag,
+    Form,
+    Button,
+    Layout,
+    Modal,
+    Typography, Progress, Card
+} from '@douyinfe/semi-ui';
+import { ITEMS_PER_PAGE } from '../constants';
+
+const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
+    'light-blue', 'lime', 'orange', 'pink',
+    'purple', 'red', 'teal', 'violet', 'yellow'
+]
+
+
+const renderTimestamp = (timestampInSeconds) => {
+    const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒
+
+    const year = date.getFullYear(); // 获取年份
+    const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份,从0开始需要+1,并保证两位数
+    const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数
+    const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数
+    const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数
+    const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数
+
+    return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
+};
+
+function renderDuration(submit_time, finishTime) {
+    // 确保startTime和finishTime都是有效的时间戳
+    if (!submit_time || !finishTime) return 'N/A';
+
+    // 将时间戳转换为Date对象
+    const start = new Date(submit_time);
+    const finish = new Date(finishTime);
+
+    // 计算时间差(毫秒)
+    const durationMs = finish - start;
+
+    // 将时间差转换为秒,并保留一位小数
+    const durationSec = (durationMs / 1000).toFixed(1);
+
+    // 设置颜色:大于60秒则为红色,小于等于60秒则为绿色
+    const color = durationSec > 60 ? 'red' : 'green';
+
+    // 返回带有样式的颜色标签
+    return (
+        <Tag color={color} size="large">
+            {durationSec} 秒
+        </Tag>
+    );
+}
+
+const LogsTable = () => {
+    const [isModalOpen, setIsModalOpen] = useState(false);
+    const [modalContent, setModalContent] = useState('');
+    const isAdminUser = isAdmin();
+    const columns = [
+        {
+            title: "提交时间",
+            dataIndex: 'submit_time',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {text ? renderTimestamp(text) : "-"}
+                    </div>
+                );
+            },
+        },
+        {
+            title: "结束时间",
+            dataIndex: 'finish_time',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {text ? renderTimestamp(text) : "-"}
+                    </div>
+                );
+            },
+        },
+        {
+            title: '进度',
+            dataIndex: 'progress',
+            width: 50,
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {
+                            // 转换例如100%为数字100,如果text未定义,返回0
+                            isNaN(text.replace('%', '')) ? text : <Progress width={42} type="circle" showInfo={true} percent={Number(text.replace('%', '') || 0)} aria-label="drawing progress" />
+                        }
+                    </div>
+                );
+            },
+        },
+        {
+            title: '花费时间',
+            dataIndex: 'finish_time', // 以finish_time作为dataIndex
+            key: 'finish_time',
+            render: (finish, record) => {
+                // 假设record.start_time是存在的,并且finish是完成时间的时间戳
+                return <>
+                    {
+                        finish ? renderDuration(record.submit_time, finish) : "-"
+                    }
+                </>
+            },
+        },
+        {
+            title: "渠道",
+            dataIndex: 'channel_id',
+            className: isAdminUser ? 'tableShow' : 'tableHiddle',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        <Tag
+                            color={colors[parseInt(text) % colors.length]}
+                            size='large'
+                            onClick={() => {
+                                copyText(text); // 假设copyText是用于文本复制的函数
+                            }}
+                        >
+                            {' '}
+                            {text}{' '}
+                        </Tag>
+                    </div>
+                );
+            },
+        },
+        {
+            title: "平台",
+            dataIndex: 'platform',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {renderPlatform(text)}
+                    </div>
+                );
+            },
+        },
+        {
+            title: '类型',
+            dataIndex: 'action',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {renderType(text)}
+                    </div>
+                );
+            },
+        },
+        {
+            title: '任务ID(点击查看详情)',
+            dataIndex: 'task_id',
+            render: (text, record, index) => {
+                return (<Typography.Text
+                    ellipsis={{ showTooltip: true }}
+                    //style={{width: 100}}
+                    onClick={() => {
+                        setModalContent(JSON.stringify(record, null, 2));
+                        setIsModalOpen(true);
+                    }}
+                >
+                    <div>
+                        {text}
+                    </div>
+                </Typography.Text>);
+            },
+        },
+        {
+            title: '任务状态',
+            dataIndex: 'status',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {renderStatus(text)}
+                    </div>
+                );
+            },
+        },
+
+        {
+            title: '失败原因',
+            dataIndex: 'fail_reason',
+            render: (text, record, index) => {
+                // 如果text未定义,返回替代文本,例如空字符串''或其他
+                if (!text) {
+                    return '无';
+                }
+
+                return (
+                    <Typography.Text
+                        ellipsis={{ showTooltip: true }}
+                        style={{ width: 100 }}
+                        onClick={() => {
+                            setModalContent(text);
+                            setIsModalOpen(true);
+                        }}
+                    >
+                        {text}
+                    </Typography.Text>
+                );
+            }
+        }
+    ];
+
+    const [logs, setLogs] = useState([]);
+    const [loading, setLoading] = useState(true);
+    const [activePage, setActivePage] = useState(1);
+    const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
+    const [logType] = useState(0);
+
+    let now = new Date();
+    // 初始化start_timestamp为前一天
+    let zeroNow = new Date(now.getFullYear(), now.getMonth(), now.getDate());
+    const [inputs, setInputs] = useState({
+        channel_id: '',
+        task_id: '',
+        start_timestamp: timestamp2string(zeroNow.getTime() /1000),
+        end_timestamp: '',
+    });
+    const { channel_id, task_id, start_timestamp, end_timestamp } = inputs;
+
+    const handleInputChange = (value, name) => {
+        setInputs((inputs) => ({ ...inputs, [name]: value }));
+    };
+
+
+    const setLogsFormat = (logs) => {
+        for (let i = 0; i < logs.length; i++) {
+            logs[i].timestamp2string = timestamp2string(logs[i].created_at);
+            logs[i].key = '' + logs[i].id;
+        }
+        // data.key = '' + data.id
+        setLogs(logs);
+        setLogCount(logs.length + ITEMS_PER_PAGE);
+        // console.log(logCount);
+    }
+
+    const loadLogs = async (startIdx) => {
+        setLoading(true);
+
+        let url = '';
+        let localStartTimestamp = parseInt(Date.parse(start_timestamp) / 1000);
+        let localEndTimestamp = parseInt(Date.parse(end_timestamp) / 1000 );
+        if (isAdminUser) {
+            url = `/api/task/?p=${startIdx}&channel_id=${channel_id}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
+        } else {
+            url = `/api/task/self?p=${startIdx}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
+        }
+        const res = await API.get(url);
+        let { success, message, data } = res.data;
+        if (success) {
+            if (startIdx === 0) {
+                setLogsFormat(data);
+            } else {
+                let newLogs = [...logs];
+                newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
+                setLogsFormat(newLogs);
+            }
+        } else {
+            showError(message);
+        }
+        setLoading(false);
+    };
+
+    const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE);
+
+    const handlePageChange = page => {
+        setActivePage(page);
+        if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
+            // In this case we have to load more data and then append them.
+            loadLogs(page - 1).then(r => {
+            });
+        }
+    };
+
+    const refresh = async () => {
+        // setLoading(true);
+        setActivePage(1);
+        await loadLogs(0);
+    };
+
+    const copyText = async (text) => {
+        if (await copy(text)) {
+            showSuccess('已复制:' + text);
+        } else {
+            // setSearchKeyword(text);
+            Modal.error({ title: "无法复制到剪贴板,请手动复制", content: text });
+        }
+    }
+
+    useEffect(() => {
+        refresh().then();
+    }, [logType]);
+
+    const renderType = (type) => {
+        switch (type) {
+            case 'MUSIC':
+                return <Label basic color='grey'> 生成音乐 </Label>;
+            case 'LYRICS':
+                return <Label basic color='pink'> 生成歌词 </Label>;
+
+            default:
+                return <Label basic color='black'> 未知 </Label>;
+        }
+    }
+
+    const renderPlatform = (type) => {
+        switch (type) {
+            case "suno":
+                return <Label basic color='green'> Suno </Label>;
+            default:
+                return <Label basic color='black'> 未知 </Label>;
+        }
+    }
+
+    const renderStatus = (type) => {
+        switch (type) {
+            case 'SUCCESS':
+                return <Label basic color='green'> 成功 </Label>;
+            case 'NOT_START':
+                return <Label basic color='black'> 未启动 </Label>;
+            case 'SUBMITTED':
+                return <Label basic color='yellow'> 队列中 </Label>;
+            case 'IN_PROGRESS':
+                return <Label basic color='blue'> 执行中 </Label>;
+            case 'FAILURE':
+                return <Label basic color='red'> 失败 </Label>;
+            case 'QUEUED':
+                return <Label basic color='red'> 排队中 </Label>;
+            case 'UNKNOWN':
+                return <Label basic color='red'> 未知 </Label>;
+            case '':
+                return <Label basic color='black'> 正在提交 </Label>;
+            default:
+                return <Label basic color='black'> 未知 </Label>;
+        }
+    }
+
+    return (
+        <>
+
+            <Layout>
+                <Form layout='horizontal' labelPosition='inset'>
+                    <>
+                        {isAdminUser && <Form.Input field="channel_id" label='渠道 ID' style={{ width: '236px', marginBottom: '10px' }} value={channel_id}
+                                                    placeholder={'可选值'} name='channel_id'
+                                                    onChange={value => handleInputChange(value, 'channel_id')} />
+                        }
+                        <Form.Input field="task_id" label={"任务 ID"} style={{ width: '236px', marginBottom: '10px' }} value={task_id}
+                            placeholder={"可选值"}
+                            name='task_id'
+                            onChange={value => handleInputChange(value, 'task_id')} />
+
+                        <Form.DatePicker field="start_timestamp" label={"起始时间"} style={{ width: '236px', marginBottom: '10px' }}
+                            initValue={start_timestamp}
+                            value={start_timestamp} type='dateTime'
+                            name='start_timestamp'
+                            onChange={value => handleInputChange(value, 'start_timestamp')} />
+                        <Form.DatePicker field="end_timestamp" fluid label={"结束时间"} style={{ width: '236px', marginBottom: '10px' }}
+                            initValue={end_timestamp}
+                            value={end_timestamp} type='dateTime'
+                            name='end_timestamp'
+                            onChange={value => handleInputChange(value, 'end_timestamp')} />
+                        <Button label={"查询"} type="primary" htmlType="submit" className="btn-margin-right"
+                            onClick={refresh}>查询</Button>
+                    </>
+                </Form>
+                <Card>
+                    <Table columns={columns} dataSource={pageData} pagination={{
+                        currentPage: activePage,
+                        pageSize: ITEMS_PER_PAGE,
+                        total: logCount,
+                        pageSizeOpts: [10, 20, 50, 100],
+                        onPageChange: handlePageChange,
+                    }} loading={loading} />
+                </Card>
+                <Modal
+                    visible={isModalOpen}
+                    onOk={() => setIsModalOpen(false)}
+                    onCancel={() => setIsModalOpen(false)}
+                    closable={null}
+                    bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式
+                    width={800} // 设置模态框宽度
+                >
+                    <p style={{ whiteSpace: 'pre-line' }}>{modalContent}</p>
+                </Modal>
+            </Layout>
+        </>
+    );
+};
+
+export default LogsTable;

+ 7 - 0
web/src/constants/channel.constants.js

@@ -14,6 +14,13 @@ export const CHANNEL_OPTIONS = [
     color: 'blue',
     label: 'Midjourney Proxy Plus',
   },
+  {
+    key: 36,
+    text: 'Suno API',
+    value: 36,
+    color: 'purple',
+    label: 'Suno API',
+  },
   { key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama' },
   {
     key: 14,

+ 1 - 0
web/src/helpers/data.js

@@ -6,6 +6,7 @@ export function setStatusData(data) {
   localStorage.setItem('quota_per_unit', data.quota_per_unit);
   localStorage.setItem('display_in_currency', data.display_in_currency);
   localStorage.setItem('enable_drawing', data.enable_drawing);
+  localStorage.setItem('enable_task', data.enable_task);
   localStorage.setItem('enable_data_export', data.enable_data_export);
   localStorage.setItem(
     'data_export_default_time',

+ 30 - 4
web/src/pages/Channel/EditChannel.js

@@ -126,6 +126,12 @@ const EditChannel = (props) => {
             'mj_uploads',
           ];
           break;
+        case 36:
+          localModels = [
+            'suno_music',
+            'suno_lyrics',
+          ];
+          break;
         default:
           localModels = getChannelModels(value);
           break;
@@ -513,12 +519,32 @@ const EditChannel = (props) => {
               />
             </>
           )}
-          <div style={{ marginTop: 10 }}>
+          {inputs.type === 36 && (
+              <>
+                <div style={{marginTop: 10}}>
+                  <Typography.Text strong>
+                    注意非Chat API,请务必填写正确的API地址,否则可能导致无法使用
+                  </Typography.Text>
+                </div>
+                <Input
+                    name='base_url'
+                    placeholder={
+                      '请输入到 /suno 前的路径,通常就是域名,例如:https://api.example.com '
+                    }
+                    onChange={(value) => {
+                      handleInputChange('base_url', value);
+                    }}
+                    value={inputs.base_url}
+                    autoComplete='new-password'
+                />
+              </>
+          )}
+          <div style={{marginTop: 10}}>
             <Typography.Text strong>名称:</Typography.Text>
           </div>
           <Input
-            required
-            name='name'
+              required
+              name='name'
             placeholder={'请为渠道命名'}
             onChange={(value) => {
               handleInputChange('name', value);
@@ -758,7 +784,7 @@ const EditChannel = (props) => {
               </Space>
             </div>
           )}
-          {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
+          {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && (
             <>
               <div style={{ marginTop: 10 }}>
                 <Typography.Text strong>代理:</Typography.Text>

+ 10 - 0
web/src/pages/Task/index.js

@@ -0,0 +1,10 @@
+import React from 'react';
+import TaskLogsTable from "../../components/TaskLogsTable.js";
+
+const Task = () => (
+  <>
+    <TaskLogsTable />
+  </>
+);
+
+export default Task;