فهرست منبع

feat: 调试 suno

Xyfacai 1 سال پیش
والد
کامیت
606aa8a4a7

+ 3 - 1
common/constants.go

@@ -21,6 +21,7 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
 var DisplayInCurrencyEnabled = true
 var DisplayTokenStatEnabled = true
 var DrawingEnabled = true
+var TaskEnabled = true
 var DataExportEnabled = true
 var DataExportInterval = 5         // unit: minute
 var DataExportDefaultTime = "hour" // unit: minute
@@ -208,7 +209,7 @@ const (
 	ChannelTypeAws            = 33
 	ChannelTypeCohere         = 34
 	ChannelTypeMiniMax        = 35
-	ChannelTypeSuno           = 36
+	ChannelTypeSunoAPI        = 36
 
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
 
@@ -251,4 +252,5 @@ var ChannelBaseURLs = []string{
 	"",                                          //33
 	"https://api.cohere.ai",                     //34
 	"https://api.minimax.chat",                  //35
+	"",                                          //36
 }

+ 3 - 0
controller/channel-test.go

@@ -27,6 +27,9 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 	if channel.Type == common.ChannelTypeMidjourney {
 		return errors.New("midjourney channel test is not supported"), nil
 	}
+	if channel.Type == common.ChannelTypeSunoAPI {
+		return errors.New("suno channel test is not supported"), nil
+	}
 	w := httptest.NewRecorder()
 	c, _ := gin.CreateTestContext(w)
 	c.Request = &http.Request{

+ 1 - 0
controller/misc.go

@@ -57,6 +57,7 @@ func GetStatus(c *gin.Context) {
 			"display_in_currency":      common.DisplayInCurrencyEnabled,
 			"enable_batch_update":      common.BatchUpdateEnabled,
 			"enable_drawing":           common.DrawingEnabled,
+			"enable_task":              common.TaskEnabled,
 			"enable_data_export":       common.DataExportEnabled,
 			"data_export_default_time": common.DataExportDefaultTime,
 			"default_collapse_sidebar": common.DefaultCollapseSidebar,

+ 237 - 31
controller/task.go

@@ -1,11 +1,21 @@
 package controller
 
 import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
 	"github.com/gin-gonic/gin"
-	"log"
+	"github.com/samber/lo"
+	"io"
+	"net/http"
 	"one-api/common"
 	"one-api/constant"
+	"one-api/dto"
 	"one-api/model"
+	"one-api/service"
+	"sort"
 	"strconv"
 	"time"
 )
@@ -16,42 +26,238 @@ func UpdateTaskBulk() {
 	for {
 		time.Sleep(time.Duration(15) * time.Second)
 		common.SysLog("任务进度轮询开始")
+		ctx := context.TODO()
 		allTasks := model.GetAllUnFinishSyncTasks(500)
 		platformTask := make(map[constant.TaskPlatform][]*model.Task)
 		for _, t := range allTasks {
 			platformTask[t.Platform] = append(platformTask[t.Platform], t)
 		}
 		for platform, tasks := range platformTask {
-			UpdateTaskByPlatform(platform, tasks)
+			if len(tasks) == 0 {
+				continue
+			}
+			taskChannelM := make(map[int][]string)
+			taskM := make(map[string]*model.Task)
+			nullTaskIds := make([]int64, 0)
+			for _, task := range tasks {
+				if task.TaskID == "" {
+					// 统计失败的未完成任务
+					nullTaskIds = append(nullTaskIds, task.ID)
+					continue
+				}
+				taskM[task.TaskID] = task
+				taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], task.TaskID)
+			}
+			if len(nullTaskIds) > 0 {
+				err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{
+					"status":   "FAILURE",
+					"progress": "100%",
+				})
+				if err != nil {
+					common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
+				} else {
+					common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
+				}
+			}
+			if len(taskChannelM) == 0 {
+				continue
+			}
+
+			UpdateTaskByPlatform(platform, taskChannelM, taskM)
 		}
 		common.SysLog("任务进度轮询完成")
 	}
 }
 
-func GetAllMidjourney(c *gin.Context) {
+func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) {
+	switch platform {
+	case constant.TaskPlatformMidjourney:
+		//_ = UpdateMidjourneyTaskAll(context.Background(), tasks)
+	case constant.TaskPlatformSuno:
+		_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
+	default:
+		common.SysLog("未知平台")
+	}
+}
+
+func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error {
+	for channelId, taskIds := range taskChannelM {
+		err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
+		if err != nil {
+			common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
+		}
+	}
+	return nil
+}
+
+func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
+	common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
+	if len(taskIds) == 0 {
+		return nil
+	}
+	channel, err := model.CacheGetChannel(channelId)
+	if err != nil {
+		common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
+		err = model.TaskBulkUpdate(taskIds, map[string]any{
+			"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
+			"status":      "FAILURE",
+			"progress":    "100%",
+		})
+		if err != nil {
+			common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
+		}
+		return err
+	}
+	requestUrl := fmt.Sprintf("%s/fetch", *channel.BaseURL)
+
+	body, _ := json.Marshal(map[string]any{
+		"ids": taskIds,
+	})
+	req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(body))
+	if err != nil {
+		common.SysError(fmt.Sprintf("Get Task error: %v", err))
+		return err
+	}
+	defer req.Body.Close()
+	// 设置超时时间
+	timeout := time.Second * 15
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+	// 使用带有超时的 context 创建新的请求
+	req = req.WithContext(ctx)
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Authorization", "Bearer "+channel.Key)
+	resp, err := service.GetHttpClient().Do(req)
+	if err != nil {
+		common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
+		return err
+	}
+	if resp.StatusCode != http.StatusOK {
+		common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+		return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
+	}
+	defer resp.Body.Close()
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
+		return err
+	}
+	var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
+	err = json.Unmarshal(responseBody, &responseItems)
+	if err != nil {
+		common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, req: %s, body: %s", err, string(body), string(responseBody)))
+		return err
+	}
+	if !responseItems.IsSuccess() {
+		common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
+		return err
+	}
+
+	for _, responseItem := range responseItems.Data {
+		task := taskM[responseItem.TaskID]
+		if !checkTaskNeedUpdate(task, responseItem) {
+			continue
+		}
+
+		task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status)
+		task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason)
+		task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime)
+		task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
+		task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
+		if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
+			common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
+			task.Progress = "100%"
+			err = model.CacheUpdateUserQuota(task.UserId)
+			if err != nil {
+				common.LogError(ctx, "error update user quota cache: "+err.Error())
+			} else {
+				quota := task.Quota
+				if quota != 0 {
+					err = model.IncreaseUserQuota(task.UserId, quota)
+					if err != nil {
+						common.LogError(ctx, "fail to increase user quota: "+err.Error())
+					}
+					logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
+					model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
+				}
+			}
+		}
+		if responseItem.Status == model.TaskStatusSuccess {
+			task.Progress = "100%"
+		}
+		task.Data = responseItem.Data
+
+		err = task.Update()
+		if err != nil {
+			common.SysError("UpdateMidjourneyTask task error: " + err.Error())
+		}
+	}
+	return nil
+}
+
+func checkTaskNeedUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool {
+
+	if oldTask.SubmitTime != newTask.SubmitTime {
+		return true
+	}
+	if oldTask.StartTime != newTask.StartTime {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+	if string(oldTask.Status) != newTask.Status {
+		return true
+	}
+	if oldTask.FailReason != newTask.FailReason {
+		return true
+	}
+	if oldTask.FinishTime != newTask.FinishTime {
+		return true
+	}
+
+	if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" {
+		return true
+	}
+
+	oldData, _ := json.Marshal(oldTask.Data)
+	newData, _ := json.Marshal(newTask.Data)
+
+	sort.Slice(oldData, func(i, j int) bool {
+		return oldData[i] < oldData[j]
+	})
+	sort.Slice(newData, func(i, j int) bool {
+		return newData[i] < newData[j]
+	})
+
+	if string(oldData) != string(newData) {
+		return true
+	}
+	return false
+}
+
+func GetAllTask(c *gin.Context) {
 	p, _ := strconv.Atoi(c.Query("p"))
 	if p < 0 {
 		p = 0
 	}
-
+	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 	// 解析其他查询参数
-	queryParams := model.TaskQueryParams{
-		ChannelID:      c.Query("channel_id"),
-		MjID:           c.Query("mj_id"),
-		StartTimestamp: c.Query("start_timestamp"),
-		EndTimestamp:   c.Query("end_timestamp"),
+	queryParams := model.SyncTaskQueryParams{
+		Platform:       constant.TaskPlatform(c.Query("platform")),
+		TaskID:         c.Query("task_id"),
+		Status:         c.Query("status"),
+		Action:         c.Query("action"),
+		StartTimestamp: startTimestamp,
+		EndTimestamp:   endTimestamp,
 	}
 
-	logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
+	logs := model.TaskGetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
 	if logs == nil {
-		logs = make([]*model.Midjourney, 0)
-	}
-	if constant.MjForwardUrlEnabled {
-		for i, midjourney := range logs {
-			midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
-			logs[i] = midjourney
-		}
+		logs = make([]*model.Task, 0)
 	}
+
 	c.JSON(200, gin.H{
 		"success": true,
 		"message": "",
@@ -59,31 +265,31 @@ func GetAllMidjourney(c *gin.Context) {
 	})
 }
 
-func GetUserMidjourney(c *gin.Context) {
+func GetUserTask(c *gin.Context) {
 	p, _ := strconv.Atoi(c.Query("p"))
 	if p < 0 {
 		p = 0
 	}
 
 	userId := c.GetInt("id")
-	log.Printf("userId = %d \n", userId)
 
-	queryParams := model.TaskQueryParams{
-		MjID:           c.Query("mj_id"),
-		StartTimestamp: c.Query("start_timestamp"),
-		EndTimestamp:   c.Query("end_timestamp"),
+	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
+	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
+
+	queryParams := model.SyncTaskQueryParams{
+		Platform:       constant.TaskPlatform(c.Query("platform")),
+		TaskID:         c.Query("task_id"),
+		Status:         c.Query("status"),
+		Action:         c.Query("action"),
+		StartTimestamp: startTimestamp,
+		EndTimestamp:   endTimestamp,
 	}
 
-	logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
+	logs := model.TaskGetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
 	if logs == nil {
-		logs = make([]*model.Midjourney, 0)
-	}
-	if constant.MjForwardUrlEnabled {
-		for i, midjourney := range logs {
-			midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
-			logs[i] = midjourney
-		}
+		logs = make([]*model.Task, 0)
 	}
+
 	c.JSON(200, gin.H{
 		"success": true,
 		"message": "",

+ 3 - 0
main.go

@@ -92,6 +92,9 @@ func main() {
 	common.SafeGoroutine(func() {
 		controller.UpdateMidjourneyTaskBulk()
 	})
+	common.SafeGoroutine(func() {
+		controller.UpdateTaskBulk()
+	})
 	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
 		common.BatchUpdateEnabled = true
 		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")

+ 1 - 1
middleware/distributor.go

@@ -134,7 +134,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 			modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
 			modelRequest.Model = modelName
 		}
-		c.Set("platform", constant.TaskPlatformSuno)
+		c.Set("platform", string(constant.TaskPlatformSuno))
 		c.Set("relay_mode", relayMode)
 	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 		err = common.UnmarshalBodyReusable(c, &modelRequest)

+ 4 - 0
model/main.go

@@ -140,6 +140,10 @@ func InitDB() (err error) {
 		if err != nil {
 			return err
 		}
+		err = db.AutoMigrate(&Task{})
+		if err != nil {
+			return err
+		}
 		common.SysLog("database migrated")
 		err = createRootAccountIfNeed()
 		return err

+ 3 - 0
model/option.go

@@ -41,6 +41,7 @@ func InitOptionMap() {
 	common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
 	common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
 	common.OptionMap["DrawingEnabled"] = strconv.FormatBool(common.DrawingEnabled)
+	common.OptionMap["TaskEnabled"] = strconv.FormatBool(common.TaskEnabled)
 	common.OptionMap["DataExportEnabled"] = strconv.FormatBool(common.DataExportEnabled)
 	common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
 	common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled)
@@ -195,6 +196,8 @@ func updateOptionMap(key string, value string) (err error) {
 			common.DisplayTokenStatEnabled = boolValue
 		case "DrawingEnabled":
 			common.DrawingEnabled = boolValue
+		case "TaskEnabled":
+			common.TaskEnabled = boolValue
 		case "DataExportEnabled":
 			common.DataExportEnabled = boolValue
 		case "DefaultCollapseSidebar":

+ 2 - 7
relay/channel/task/suno/adaptor.go

@@ -18,12 +18,10 @@ import (
 
 type TaskAdaptor struct {
 	ChannelType int
-	Action      string
 }
 
 func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
 	a.ChannelType = info.ChannelType
-
 }
 
 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
@@ -49,16 +47,13 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
 		info.OriginTaskID = sunoRequest.TaskID
 	}
 
-	a.Action = info.Action
+	info.Action = action
 	c.Set("task_request", sunoRequest)
 	return nil
 }
 
 func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
-	baseURL := common.ChannelBaseURLs[info.ChannelType]
-	if info.BaseUrl != "" {
-		baseURL = info.BaseUrl
-	}
+	baseURL := info.BaseUrl
 	fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/submit/"+info.Action)
 	return fullRequestURL, nil
 }

+ 6 - 0
router/api-router.go

@@ -140,5 +140,11 @@ func SetApiRouter(router *gin.Engine) {
 		mjRoute := apiRouter.Group("/mj")
 		mjRoute.GET("/self", middleware.UserAuth(), controller.GetUserMidjourney)
 		mjRoute.GET("/", middleware.AdminAuth(), controller.GetAllMidjourney)
+
+		taskRoute := apiRouter.Group("/task")
+		{
+			taskRoute.GET("/self", middleware.UserAuth(), controller.GetUserTask)
+			taskRoute.GET("/", middleware.AdminAuth(), controller.GetAllTask)
+		}
 	}
 }

+ 11 - 0
web/src/App.js

@@ -23,6 +23,7 @@ import Chat from './pages/Chat';
 import { Layout } from '@douyinfe/semi-ui';
 import Midjourney from './pages/Midjourney';
 import Pricing from './pages/Pricing/index.js';
+import Task from "./pages/Task/index.js";
 // import Detail from './pages/Detail';
 
 const Home = lazy(() => import('./pages/Home'));
@@ -220,6 +221,16 @@ function App() {
               </PrivateRoute>
             }
           />
+          <Route
+            path='/task'
+            element={
+                <PrivateRoute>
+                    <Suspense fallback={<Loading></Loading>}>
+                        <Task />
+                    </Suspense>
+                </PrivateRoute>
+            }
+          />
           <Route
             path='/pricing'
             element={

+ 13 - 1
web/src/components/SiderBar.js

@@ -14,7 +14,7 @@ import {
 import '../index.css';
 
 import {
-  IconCalendarClock,
+  IconCalendarClock, IconChecklistStroked,
   IconComment,
   IconCreditCard,
   IconGift,
@@ -58,6 +58,7 @@ const SiderBar = () => {
     chat: '/chat',
     detail: '/detail',
     pricing: '/pricing',
+    task: '/task',
   };
 
   const headerButtons = useMemo(
@@ -142,6 +143,16 @@ const SiderBar = () => {
             ? 'semi-navigation-item-normal'
             : 'tableHiddle',
       },
+      {
+        text: '异步任务',
+        itemKey: 'task',
+        to: '/task',
+        icon: <IconChecklistStroked />,
+        className:
+            localStorage.getItem('enable_task') === 'true'
+                ? 'semi-navigation-item-normal'
+                : 'tableHiddle',
+      },
       {
         text: '设置',
         itemKey: 'setting',
@@ -158,6 +169,7 @@ const SiderBar = () => {
     [
       localStorage.getItem('enable_data_export'),
       localStorage.getItem('enable_drawing'),
+      localStorage.getItem('enable_task'),
       localStorage.getItem('chat_link'),
       isAdmin(),
     ],

+ 400 - 0
web/src/components/TaskLogsTable.js

@@ -0,0 +1,400 @@
+import React, { useEffect, useState } from 'react';
+import { Label } from 'semantic-ui-react';
+import { API, copy, isAdmin, showError, showSuccess, timestamp2string } from '../helpers';
+
+import {
+    Table,
+    Tag,
+    Form,
+    Button,
+    Layout,
+    Modal,
+    Typography, Progress, Card
+} from '@douyinfe/semi-ui';
+import { ITEMS_PER_PAGE } from '../constants';
+
+const colors = ['amber', 'blue', 'cyan', 'green', 'grey', 'indigo',
+    'light-blue', 'lime', 'orange', 'pink',
+    'purple', 'red', 'teal', 'violet', 'yellow'
+]
+
+
+const renderTimestamp = (timestampInSeconds) => {
+    const date = new Date(timestampInSeconds * 1000); // 从秒转换为毫秒
+
+    const year = date.getFullYear(); // 获取年份
+    const month = ('0' + (date.getMonth() + 1)).slice(-2); // 获取月份,从0开始需要+1,并保证两位数
+    const day = ('0' + date.getDate()).slice(-2); // 获取日期,并保证两位数
+    const hours = ('0' + date.getHours()).slice(-2); // 获取小时,并保证两位数
+    const minutes = ('0' + date.getMinutes()).slice(-2); // 获取分钟,并保证两位数
+    const seconds = ('0' + date.getSeconds()).slice(-2); // 获取秒钟,并保证两位数
+
+    return `${year}-${month}-${day} ${hours}:${minutes}:${seconds}`; // 格式化输出
+};
+
+function renderDuration(submit_time, finishTime) {
+    // 确保startTime和finishTime都是有效的时间戳
+    if (!submit_time || !finishTime) return 'N/A';
+
+    // 将时间戳转换为Date对象
+    const start = new Date(submit_time);
+    const finish = new Date(finishTime);
+
+    // 计算时间差(毫秒)
+    const durationMs = finish - start;
+
+    // 将时间差转换为秒,并保留一位小数
+    const durationSec = (durationMs / 1000).toFixed(1);
+
+    // 设置颜色:大于60秒则为红色,小于等于60秒则为绿色
+    const color = durationSec > 60 ? 'red' : 'green';
+
+    // 返回带有样式的颜色标签
+    return (
+        <Tag color={color} size="large">
+            {durationSec} 秒
+        </Tag>
+    );
+}
+
+const LogsTable = () => {
+    const [isModalOpen, setIsModalOpen] = useState(false);
+    const [modalContent, setModalContent] = useState('');
+    const isAdminUser = isAdmin();
+    const columns = [
+        {
+            title: "提交时间",
+            dataIndex: 'submit_time',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {text ? renderTimestamp(text) : "-"}
+                    </div>
+                );
+            },
+        },
+        {
+            title: "结束时间",
+            dataIndex: 'finish_time',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {text ? renderTimestamp(text) : "-"}
+                    </div>
+                );
+            },
+        },
+        {
+            title: '进度',
+            dataIndex: 'progress',
+            width: 50,
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {
+                            // 转换例如100%为数字100,如果text未定义,返回0
+                            isNaN(text.replace('%', '')) ? text : <Progress width={42} type="circle" showInfo={true} percent={Number(text.replace('%', '') || 0)} aria-label="drawing progress" />
+                        }
+                    </div>
+                );
+            },
+        },
+        {
+            title: '花费时间',
+            dataIndex: 'finish_time', // 以finish_time作为dataIndex
+            key: 'finish_time',
+            render: (finish, record) => {
+                // 假设record.start_time是存在的,并且finish是完成时间的时间戳
+                return <>
+                    {
+                        finish ? renderDuration(record.submit_time, finish) : "-"
+                    }
+                </>
+            },
+        },
+        {
+            title: "渠道",
+            dataIndex: 'channel_id',
+            className: isAdminUser ? 'tableShow' : 'tableHiddle',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        <Tag
+                            color={colors[parseInt(text) % colors.length]}
+                            size='large'
+                            onClick={() => {
+                                copyText(text); // 假设copyText是用于文本复制的函数
+                            }}
+                        >
+                            {' '}
+                            {text}{' '}
+                        </Tag>
+                    </div>
+                );
+            },
+        },
+        {
+            title: "平台",
+            dataIndex: 'platform',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {renderPlatform(text)}
+                    </div>
+                );
+            },
+        },
+        {
+            title: '类型',
+            dataIndex: 'action',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {renderType(text)}
+                    </div>
+                );
+            },
+        },
+        {
+            title: '任务ID(点击查看详情)',
+            dataIndex: 'task_id',
+            render: (text, record, index) => {
+                return (<Typography.Text
+                    ellipsis={{ showTooltip: true }}
+                    //style={{width: 100}}
+                    onClick={() => {
+                        setModalContent(JSON.stringify(record, null, 2));
+                        setIsModalOpen(true);
+                    }}
+                >
+                    <div>
+                        {text}
+                    </div>
+                </Typography.Text>);
+            },
+        },
+        {
+            title: '任务状态',
+            dataIndex: 'status',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        {renderStatus(text)}
+                    </div>
+                );
+            },
+        },
+
+        {
+            title: '失败原因',
+            dataIndex: 'fail_reason',
+            render: (text, record, index) => {
+                // 如果text未定义,返回替代文本,例如空字符串''或其他
+                if (!text) {
+                    return '无';
+                }
+
+                return (
+                    <Typography.Text
+                        ellipsis={{ showTooltip: true }}
+                        style={{ width: 100 }}
+                        onClick={() => {
+                            setModalContent(text);
+                            setIsModalOpen(true);
+                        }}
+                    >
+                        {text}
+                    </Typography.Text>
+                );
+            }
+        }
+    ];
+
+    const [logs, setLogs] = useState([]);
+    const [loading, setLoading] = useState(true);
+    const [activePage, setActivePage] = useState(1);
+    const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
+    const [logType] = useState(0);
+
+    let now = new Date();
+    // 初始化start_timestamp为前一天
+    let zeroNow = new Date(now.getFullYear(), now.getMonth(), now.getDate());
+    const [inputs, setInputs] = useState({
+        channel_id: '',
+        task_id: '',
+        start_timestamp: timestamp2string(zeroNow.getTime() /1000),
+        end_timestamp: '',
+    });
+    const { channel_id, task_id, start_timestamp, end_timestamp } = inputs;
+
+    const handleInputChange = (value, name) => {
+        setInputs((inputs) => ({ ...inputs, [name]: value }));
+    };
+
+
+    const setLogsFormat = (logs) => {
+        for (let i = 0; i < logs.length; i++) {
+            logs[i].timestamp2string = timestamp2string(logs[i].created_at);
+            logs[i].key = '' + logs[i].id;
+        }
+        // data.key = '' + data.id
+        setLogs(logs);
+        setLogCount(logs.length + ITEMS_PER_PAGE);
+        // console.log(logCount);
+    }
+
+    const loadLogs = async (startIdx) => {
+        setLoading(true);
+
+        let url = '';
+        let localStartTimestamp = parseInt(Date.parse(start_timestamp) / 1000);
+        let localEndTimestamp = parseInt(Date.parse(end_timestamp) / 1000 );
+        if (isAdminUser) {
+            url = `/api/task/?p=${startIdx}&channel_id=${channel_id}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
+        } else {
+            url = `/api/task/self?p=${startIdx}&task_id=${task_id}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
+        }
+        const res = await API.get(url);
+        let { success, message, data } = res.data;
+        if (success) {
+            if (startIdx === 0) {
+                setLogsFormat(data);
+            } else {
+                let newLogs = [...logs];
+                newLogs.splice(startIdx * ITEMS_PER_PAGE, data.length, ...data);
+                setLogsFormat(newLogs);
+            }
+        } else {
+            showError(message);
+        }
+        setLoading(false);
+    };
+
+    const pageData = logs.slice((activePage - 1) * ITEMS_PER_PAGE, activePage * ITEMS_PER_PAGE);
+
+    const handlePageChange = page => {
+        setActivePage(page);
+        if (page === Math.ceil(logs.length / ITEMS_PER_PAGE) + 1) {
+            // In this case we have to load more data and then append them.
+            loadLogs(page - 1).then(r => {
+            });
+        }
+    };
+
+    const refresh = async () => {
+        // setLoading(true);
+        setActivePage(1);
+        await loadLogs(0);
+    };
+
+    const copyText = async (text) => {
+        if (await copy(text)) {
+            showSuccess('已复制:' + text);
+        } else {
+            // setSearchKeyword(text);
+            Modal.error({ title: "无法复制到剪贴板,请手动复制", content: text });
+        }
+    }
+
+    useEffect(() => {
+        refresh().then();
+    }, [logType]);
+
+    const renderType = (type) => {
+        switch (type) {
+            case 'MUSIC':
+                return <Label basic color='grey'> 生成音乐 </Label>;
+            case 'LYRICS':
+                return <Label basic color='pink'> 生成歌词 </Label>;
+
+            default:
+                return <Label basic color='black'> 未知 </Label>;
+        }
+    }
+
+    const renderPlatform = (type) => {
+        switch (type) {
+            case "suno":
+                return <Label basic color='green'> Suno </Label>;
+            default:
+                return <Label basic color='black'> 未知 </Label>;
+        }
+    }
+
+    const renderStatus = (type) => {
+        switch (type) {
+            case 'SUCCESS':
+                return <Label basic color='green'> 成功 </Label>;
+            case 'NOT_START':
+                return <Label basic color='black'> 未启动 </Label>;
+            case 'SUBMITTED':
+                return <Label basic color='yellow'> 队列中 </Label>;
+            case 'IN_PROGRESS':
+                return <Label basic color='blue'> 执行中 </Label>;
+            case 'FAILURE':
+                return <Label basic color='red'> 失败 </Label>;
+            case 'QUEUED':
+                return <Label basic color='red'> 排队中 </Label>;
+            case 'UNKNOWN':
+                return <Label basic color='red'> 未知 </Label>;
+            case '':
+                return <Label basic color='black'> 正在提交 </Label>;
+            default:
+                return <Label basic color='black'> 未知 </Label>;
+        }
+    }
+
+    return (
+        <>
+
+            <Layout>
+                <Form layout='horizontal' labelPosition='inset'>
+                    <>
+                        {isAdminUser && <Form.Input field="channel_id" label='渠道 ID' style={{ width: '236px', marginBottom: '10px' }} value={channel_id}
+                                                    placeholder={'可选值'} name='channel_id'
+                                                    onChange={value => handleInputChange(value, 'channel_id')} />
+                        }
+                        <Form.Input field="task_id" label={"任务 ID"} style={{ width: '236px', marginBottom: '10px' }} value={task_id}
+                            placeholder={"可选值"}
+                            name='task_id'
+                            onChange={value => handleInputChange(value, 'task_id')} />
+
+                        <Form.DatePicker field="start_timestamp" label={"起始时间"} style={{ width: '236px', marginBottom: '10px' }}
+                            initValue={start_timestamp}
+                            value={start_timestamp} type='dateTime'
+                            name='start_timestamp'
+                            onChange={value => handleInputChange(value, 'start_timestamp')} />
+                        <Form.DatePicker field="end_timestamp" fluid label={"结束时间"} style={{ width: '236px', marginBottom: '10px' }}
+                            initValue={end_timestamp}
+                            value={end_timestamp} type='dateTime'
+                            name='end_timestamp'
+                            onChange={value => handleInputChange(value, 'end_timestamp')} />
+                        <Button label={"查询"} type="primary" htmlType="submit" className="btn-margin-right"
+                            onClick={refresh}>查询</Button>
+                    </>
+                </Form>
+                <Card>
+                    <Table columns={columns} dataSource={pageData} pagination={{
+                        currentPage: activePage,
+                        pageSize: ITEMS_PER_PAGE,
+                        total: logCount,
+                        pageSizeOpts: [10, 20, 50, 100],
+                        onPageChange: handlePageChange,
+                    }} loading={loading} />
+                </Card>
+                <Modal
+                    visible={isModalOpen}
+                    onOk={() => setIsModalOpen(false)}
+                    onCancel={() => setIsModalOpen(false)}
+                    closable={null}
+                    bodyStyle={{ height: '400px', overflow: 'auto' }} // 设置模态框内容区域样式
+                    width={800} // 设置模态框宽度
+                >
+                    <p style={{ whiteSpace: 'pre-line' }}>{modalContent}</p>
+                </Modal>
+            </Layout>
+        </>
+    );
+};
+
+export default LogsTable;

+ 7 - 0
web/src/constants/channel.constants.js

@@ -14,6 +14,13 @@ export const CHANNEL_OPTIONS = [
     color: 'blue',
     label: 'Midjourney Proxy Plus',
   },
+  {
+    key: 36,
+    text: 'Suno API',
+    value: 36,
+    color: 'purple',
+    label: 'Suno API',
+  },
   { key: 4, text: 'Ollama', value: 4, color: 'grey', label: 'Ollama' },
   {
     key: 14,

+ 1 - 0
web/src/helpers/data.js

@@ -6,6 +6,7 @@ export function setStatusData(data) {
   localStorage.setItem('quota_per_unit', data.quota_per_unit);
   localStorage.setItem('display_in_currency', data.display_in_currency);
   localStorage.setItem('enable_drawing', data.enable_drawing);
+  localStorage.setItem('enable_task', data.enable_task);
   localStorage.setItem('enable_data_export', data.enable_data_export);
   localStorage.setItem(
     'data_export_default_time',

+ 32 - 1
web/src/pages/Channel/EditChannel.js

@@ -126,6 +126,12 @@ const EditChannel = (props) => {
             'mj_uploads',
           ];
           break;
+        case 36:
+          localModels = [
+            'suno_music',
+            'suno_lyrics',
+          ];
+          break;
         default:
           localModels = getChannelModels(value);
           break;
@@ -513,6 +519,31 @@ const EditChannel = (props) => {
               />
             </>
           )}
+          {inputs.type === 36 && (
+              <>
+                <div style={{ marginTop: 10 }}>
+                  <Banner
+                      type={'info'}
+                      description={
+                        <>
+                          Suno 非官方 API,https://github.com/Suno-API/Suno-API
+                        </>
+                      }
+                  ></Banner>
+                </div>
+                <Input
+                    name='base_url'
+                    placeholder={
+                      '需要输入到 /submit 前的路径,通常就是域名 + /suno,例如:https://sunoapi.com/suno '
+                    }
+                    onChange={(value) => {
+                      handleInputChange('base_url', value);
+                    }}
+                    value={inputs.base_url}
+                    autoComplete='new-password'
+                />
+              </>
+          )}
           <div style={{ marginTop: 10 }}>
             <Typography.Text strong>名称:</Typography.Text>
           </div>
@@ -758,7 +789,7 @@ const EditChannel = (props) => {
               </Space>
             </div>
           )}
-          {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
+          {inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && inputs.type !== 36 && (
             <>
               <div style={{ marginTop: 10 }}>
                 <Typography.Text strong>代理:</Typography.Text>

+ 10 - 0
web/src/pages/Task/index.js

@@ -0,0 +1,10 @@
+import React from 'react';
+import TaskLogsTable from "../../components/TaskLogsTable.js";
+
+const Task = () => (
+  <>
+    <TaskLogsTable />
+  </>
+);
+
+export default Task;