ソースを参照

feat: suno api 支持

Xiangyuan Liu 1 年間 前
コミット
1e8abc7027

+ 2 - 0
common/constants.go

@@ -208,8 +208,10 @@ const (
 	ChannelTypeAws            = 33
 	ChannelTypeCohere         = 34
 	ChannelTypeMiniMax        = 35
+	ChannelTypeSuno           = 36
 
 	ChannelTypeDummy // this one is only for count, do not add any channel after this
+
 )
 
 var ChannelBaseURLs = []string{

+ 18 - 0
constant/task.go

@@ -0,0 +1,18 @@
+package constant
+
+type TaskPlatform string
+
+const (
+	TaskPlatformSuno       TaskPlatform = "suno"
+	TaskPlatformMidjourney              = "mj"
+)
+
+const (
+	SunoActionMusic  = "MUSIC"
+	SunoActionLyrics = "LYRICS"
+)
+
+var SunoModel2Action = map[string]string{
+	"suno_music":  SunoActionMusic,
+	"suno_lyrics": SunoActionLyrics,
+}

+ 91 - 0
controller/relay.go

@@ -190,3 +190,94 @@ func RelayNotFound(c *gin.Context) {
 		"error": err,
 	})
 }
+
+func RelayTask(c *gin.Context) {
+	retryTimes := common.RetryTimes
+	channelId := c.GetInt("channel_id")
+	relayMode := c.GetInt("relay_mode")
+	group := c.GetString("group")
+	originalModel := c.GetString("original_model")
+	c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
+	taskErr := taskRelayHandler(c, relayMode)
+	if taskErr == nil {
+		retryTimes = 0
+	}
+	for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
+		channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
+		if err != nil {
+			common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
+			break
+		}
+		channelId = channel.Id
+		useChannel := c.GetStringSlice("use_channel")
+		useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
+		c.Set("use_channel", useChannel)
+		common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
+		middleware.SetupContextForSelectedChannel(c, channel, originalModel)
+
+		requestBody, err := common.GetRequestBody(c)
+		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+		taskErr = taskRelayHandler(c, relayMode)
+	}
+	useChannel := c.GetStringSlice("use_channel")
+	if len(useChannel) > 1 {
+		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+		common.LogInfo(c.Request.Context(), retryLogStr)
+	}
+	if taskErr != nil {
+		if taskErr.StatusCode == http.StatusTooManyRequests {
+			taskErr.Message = "当前分组上游负载已饱和,请稍后再试"
+		}
+		c.JSON(taskErr.StatusCode, taskErr)
+	}
+}
+
+func taskRelayHandler(c *gin.Context, relayMode int) *dto.TaskError {
+	var err *dto.TaskError
+	switch relayMode {
+	case relayconstant.RelayModeSunoFetch, relayconstant.RelayModeSunoFetchByID:
+		err = relay.RelayTaskFetch(c, relayMode)
+	default:
+		err = relay.RelayTaskSubmit(c, relayMode)
+	}
+	return err
+}
+
+func shouldRetryTaskRelay(c *gin.Context, channelId int, taskErr *dto.TaskError, retryTimes int) bool {
+	if taskErr == nil {
+		return false
+	}
+	if retryTimes <= 0 {
+		return false
+	}
+	if _, ok := c.Get("specific_channel_id"); ok {
+		return false
+	}
+	if taskErr.StatusCode == http.StatusTooManyRequests {
+		return true
+	}
+	if taskErr.StatusCode == 307 {
+		return true
+	}
+	if taskErr.StatusCode/100 == 5 {
+		// 超时不重试
+		if taskErr.StatusCode == 504 || taskErr.StatusCode == 524 {
+			return false
+		}
+		return true
+	}
+	if taskErr.StatusCode == http.StatusBadRequest {
+		return false
+	}
+	if taskErr.StatusCode == 408 {
+		// azure处理超时不重试
+		return false
+	}
+	if taskErr.LocalError {
+		return false
+	}
+	if taskErr.StatusCode/100 == 2 {
+		return false
+	}
+	return true
+}

+ 92 - 0
controller/task.go

@@ -0,0 +1,92 @@
+package controller
+
+import (
+	"github.com/gin-gonic/gin"
+	"log"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/model"
+	"strconv"
+	"time"
+)
+
+func UpdateTaskBulk() {
+	//revocer
+	//imageModel := "midjourney"
+	for {
+		time.Sleep(time.Duration(15) * time.Second)
+		common.SysLog("任务进度轮询开始")
+		allTasks := model.GetAllUnFinishSyncTasks(500)
+		platformTask := make(map[constant.TaskPlatform][]*model.Task)
+		for _, t := range allTasks {
+			platformTask[t.Platform] = append(platformTask[t.Platform], t)
+		}
+		for platform, tasks := range platformTask {
+			UpdateTaskByPlatform(platform, tasks)
+		}
+		common.SysLog("任务进度轮询完成")
+	}
+}
+
+func GetAllMidjourney(c *gin.Context) {
+	p, _ := strconv.Atoi(c.Query("p"))
+	if p < 0 {
+		p = 0
+	}
+
+	// 解析其他查询参数
+	queryParams := model.TaskQueryParams{
+		ChannelID:      c.Query("channel_id"),
+		MjID:           c.Query("mj_id"),
+		StartTimestamp: c.Query("start_timestamp"),
+		EndTimestamp:   c.Query("end_timestamp"),
+	}
+
+	logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
+	if logs == nil {
+		logs = make([]*model.Midjourney, 0)
+	}
+	if constant.MjForwardUrlEnabled {
+		for i, midjourney := range logs {
+			midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
+			logs[i] = midjourney
+		}
+	}
+	c.JSON(200, gin.H{
+		"success": true,
+		"message": "",
+		"data":    logs,
+	})
+}
+
+func GetUserMidjourney(c *gin.Context) {
+	p, _ := strconv.Atoi(c.Query("p"))
+	if p < 0 {
+		p = 0
+	}
+
+	userId := c.GetInt("id")
+	log.Printf("userId = %d \n", userId)
+
+	queryParams := model.TaskQueryParams{
+		MjID:           c.Query("mj_id"),
+		StartTimestamp: c.Query("start_timestamp"),
+		EndTimestamp:   c.Query("end_timestamp"),
+	}
+
+	logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage, queryParams)
+	if logs == nil {
+		logs = make([]*model.Midjourney, 0)
+	}
+	if constant.MjForwardUrlEnabled {
+		for i, midjourney := range logs {
+			midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId
+			logs[i] = midjourney
+		}
+	}
+	c.JSON(200, gin.H{
+		"success": true,
+		"message": "",
+		"data":    logs,
+	})
+}

+ 129 - 0
dto/suno.go

@@ -0,0 +1,129 @@
+package dto
+
+import (
+	"encoding/json"
+)
+
+type TaskData interface {
+	SunoDataResponse | []SunoDataResponse | string | any
+}
+
+type SunoSubmitReq struct {
+	GptDescriptionPrompt string  `json:"gpt_description_prompt,omitempty"`
+	Prompt               string  `json:"prompt,omitempty"`
+	Mv                   string  `json:"mv,omitempty"`
+	Title                string  `json:"title,omitempty"`
+	Tags                 string  `json:"tags,omitempty"`
+	ContinueAt           float64 `json:"continue_at,omitempty"`
+	TaskID               string  `json:"task_id,omitempty"`
+	ContinueClipId       string  `json:"continue_clip_id,omitempty"`
+	MakeInstrumental     bool    `json:"make_instrumental"`
+}
+
+type FetchReq struct {
+	IDs []string `json:"ids"`
+}
+
+type SunoDataResponse struct {
+	TaskID     string          `json:"task_id" gorm:"type:varchar(50);index"`
+	Action     string          `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
+	Status     string          `json:"status" gorm:"type:varchar(20);index"` // 任务状态, submitted, queueing, processing, success, failed
+	FailReason string          `json:"fail_reason"`
+	SubmitTime int64           `json:"submit_time" gorm:"index"`
+	StartTime  int64           `json:"start_time" gorm:"index"`
+	FinishTime int64           `json:"finish_time" gorm:"index"`
+	Data       json.RawMessage `json:"data" gorm:"type:json"`
+}
+
+type SunoSong struct {
+	ID                string       `json:"id"`
+	VideoURL          string       `json:"video_url"`
+	AudioURL          string       `json:"audio_url"`
+	ImageURL          string       `json:"image_url"`
+	ImageLargeURL     string       `json:"image_large_url"`
+	MajorModelVersion string       `json:"major_model_version"`
+	ModelName         string       `json:"model_name"`
+	Status            string       `json:"status"`
+	Title             string       `json:"title"`
+	Text              string       `json:"text"`
+	Metadata          SunoMetadata `json:"metadata"`
+}
+
+type SunoMetadata struct {
+	Tags                 string      `json:"tags"`
+	Prompt               string      `json:"prompt"`
+	GPTDescriptionPrompt interface{} `json:"gpt_description_prompt"`
+	AudioPromptID        interface{} `json:"audio_prompt_id"`
+	Duration             interface{} `json:"duration"`
+	ErrorType            interface{} `json:"error_type"`
+	ErrorMessage         interface{} `json:"error_message"`
+}
+
+type SunoLyrics struct {
+	ID     string `json:"id"`
+	Status string `json:"status"`
+	Title  string `json:"title"`
+	Text   string `json:"text"`
+}
+
+const TaskSuccessCode = "success"
+
+type TaskResponse[T TaskData] struct {
+	Code    string `json:"code"`
+	Message string `json:"message"`
+	Data    T      `json:"data"`
+}
+
+func (t *TaskResponse[T]) IsSuccess() bool {
+	return t.Code == TaskSuccessCode
+}
+
+type TaskDto struct {
+	TaskID     string          `json:"task_id"` // 第三方id,不一定有/ song id\ Task id
+	Action     string          `json:"action"`  // 任务类型, song, lyrics, description-mode
+	Status     string          `json:"status"`  // 任务状态, submitted, queueing, processing, success, failed
+	FailReason string          `json:"fail_reason"`
+	SubmitTime int64           `json:"submit_time"`
+	StartTime  int64           `json:"start_time"`
+	FinishTime int64           `json:"finish_time"`
+	Progress   string          `json:"progress"`
+	Data       json.RawMessage `json:"data"`
+}
+
+type SunoGoAPISubmitReq struct {
+	CustomMode bool `json:"custom_mode"`
+
+	Input SunoGoAPISubmitReqInput `json:"input"`
+
+	NotifyHook string `json:"notify_hook,omitempty"`
+}
+
+type SunoGoAPISubmitReqInput struct {
+	GptDescriptionPrompt string  `json:"gpt_description_prompt"`
+	Prompt               string  `json:"prompt"`
+	Mv                   string  `json:"mv"`
+	Title                string  `json:"title"`
+	Tags                 string  `json:"tags"`
+	ContinueAt           float64 `json:"continue_at"`
+	TaskID               string  `json:"task_id"`
+	ContinueClipId       string  `json:"continue_clip_id"`
+	MakeInstrumental     bool    `json:"make_instrumental"`
+}
+
+type GoAPITaskResponse[T any] struct {
+	Code         int    `json:"code"`
+	Message      string `json:"message"`
+	Data         T      `json:"data"`
+	ErrorMessage string `json:"error_message,omitempty"`
+}
+
+type GoAPITaskResponseData struct {
+	TaskID string `json:"task_id"`
+}
+
+type GoAPIFetchResponseData struct {
+	TaskID string              `json:"task_id"`
+	Status string              `json:"status"`
+	Input  string              `json:"input"`
+	Clips  map[string]SunoSong `json:"clips"`
+}

+ 10 - 0
dto/task.go

@@ -0,0 +1,10 @@
+package dto
+
+type TaskError struct {
+	Code       string `json:"code"`
+	Message    string `json:"message"`
+	Data       any    `json:"data"`
+	StatusCode int    `json:"-"`
+	LocalError bool   `json:"-"`
+	Error      error  `json:"-"`
+}

+ 2 - 2
main.go

@@ -20,10 +20,10 @@ import (
 	_ "net/http/pprof"
 )
 
-//go:embed web/dist
+// /go:embed web/dist
 var buildFS embed.FS
 
-//go:embed web/dist/index.html
+// /go:embed web/dist/index.html
 var indexPage []byte
 
 func main() {

+ 11 - 0
middleware/distributor.go

@@ -125,6 +125,17 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 			modelRequest.Model = midjourneyModel
 		}
 		c.Set("relay_mode", relayMode)
+	} else if strings.Contains(c.Request.URL.Path, "/suno/") {
+		relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path)
+		if relayMode == relayconstant.RelayModeSunoFetch ||
+			relayMode == relayconstant.RelayModeSunoFetchByID {
+			shouldSelectChannel = false
+		} else {
+			modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
+			modelRequest.Model = modelName
+		}
+		c.Set("platform", constant.TaskPlatformSuno)
+		c.Set("relay_mode", relayMode)
 	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
 		err = common.UnmarshalBodyReusable(c, &modelRequest)
 	}

+ 304 - 0
model/task.go

@@ -0,0 +1,304 @@
+package model
+
+import (
+	"database/sql/driver"
+	"encoding/json"
+	"one-api/constant"
+	commonRelay "one-api/relay/common"
+	"time"
+)
+
+type TaskStatus string
+
+const (
+	TaskStatusNotStart   TaskStatus = "NOT_START"
+	TaskStatusSubmitted             = "SUBMITTED"
+	TaskStatusQueued                = "QUEUED"
+	TaskStatusInProgress            = "IN_PROGRESS"
+	TaskStatusFailure               = "FAILURE"
+	TaskStatusSuccess               = "SUCCESS"
+	TaskStatusUnknown               = "UNKNOWN"
+)
+
+type Task struct {
+	ID         int64                 `json:"id" gorm:"primary_key;AUTO_INCREMENT"`
+	CreatedAt  int64                 `json:"created_at" gorm:"index"`
+	UpdatedAt  int64                 `json:"updated_at"`
+	TaskID     string                `json:"task_id" gorm:"type:varchar(50);index"`  // 第三方id,不一定有/ song id\ Task id
+	Platform   constant.TaskPlatform `json:"platform" gorm:"type:varchar(30);index"` // 平台
+	UserId     int                   `json:"user_id" gorm:"index"`
+	ChannelId  int                   `json:"channel_id" gorm:"index"`
+	Quota      int                   `json:"quota"`
+	Action     string                `json:"action" gorm:"type:varchar(40);index"` // 任务类型, song, lyrics, description-mode
+	Status     TaskStatus            `json:"status" gorm:"type:varchar(20);index"` // 任务状态
+	FailReason string                `json:"fail_reason"`
+	SubmitTime int64                 `json:"submit_time" gorm:"index"`
+	StartTime  int64                 `json:"start_time" gorm:"index"`
+	FinishTime int64                 `json:"finish_time" gorm:"index"`
+	Progress   string                `json:"progress" gorm:"type:varchar(20);index"`
+	Properties Properties            `json:"properties" gorm:"type:json"`
+
+	Data json.RawMessage `json:"data" gorm:"type:json"`
+}
+
+func (t *Task) SetData(data any) {
+	b, _ := json.Marshal(data)
+	t.Data = json.RawMessage(b)
+}
+
+func (t *Task) GetData(v any) error {
+	err := json.Unmarshal(t.Data, &v)
+	return err
+}
+
+type Properties struct {
+	Input string `json:"input"`
+}
+
+func (m *Properties) Scan(val interface{}) error {
+	bytesValue, _ := val.([]byte)
+	return json.Unmarshal(bytesValue, m)
+}
+
+func (m Properties) Value() (driver.Value, error) {
+	return json.Marshal(m)
+}
+
+// SyncTaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
+type SyncTaskQueryParams struct {
+	Platform       constant.TaskPlatform
+	ChannelID      string
+	TaskID         string
+	UserID         string
+	Action         string
+	Status         string
+	StartTimestamp int64
+	EndTimestamp   int64
+	UserIDs        []int
+}
+
+func InitTask(platform constant.TaskPlatform, relayInfo *commonRelay.TaskRelayInfo) *Task {
+	t := &Task{
+		UserId:     relayInfo.UserId,
+		SubmitTime: time.Now().Unix(),
+		Status:     TaskStatusNotStart,
+		Progress:   "0%",
+		ChannelId:  relayInfo.ChannelId,
+		Platform:   platform,
+	}
+	return t
+}
+
+func TaskGetAllUserTask(userId int, startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
+	var tasks []*Task
+	var err error
+
+	// 初始化查询构建器
+	query := DB.Where("user_id = ?", userId)
+
+	if queryParams.TaskID != "" {
+		query = query.Where("task_id = ?", queryParams.TaskID)
+	}
+	if queryParams.Action != "" {
+		query = query.Where("action = ?", queryParams.Action)
+	}
+	if queryParams.Status != "" {
+		query = query.Where("status = ?", queryParams.Status)
+	}
+	if queryParams.Platform != "" {
+		query = query.Where("platform = ?", queryParams.Platform)
+	}
+	if queryParams.StartTimestamp != 0 {
+		// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
+		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+	}
+	if queryParams.EndTimestamp != 0 {
+		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+	}
+
+	// 获取数据
+	err = query.Omit("channel_id").Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
+	if err != nil {
+		return nil
+	}
+
+	return tasks
+}
+
+func TaskGetAllTasks(startIdx int, num int, queryParams SyncTaskQueryParams) []*Task {
+	var tasks []*Task
+	var err error
+
+	// 初始化查询构建器
+	query := DB
+
+	// 添加过滤条件
+	if queryParams.ChannelID != "" {
+		query = query.Where("channel_id = ?", queryParams.ChannelID)
+	}
+	if queryParams.Platform != "" {
+		query = query.Where("platform = ?", queryParams.Platform)
+	}
+	if queryParams.UserID != "" {
+		query = query.Where("user_id = ?", queryParams.UserID)
+	}
+	if len(queryParams.UserIDs) != 0 {
+		query = query.Where("user_id in (?)", queryParams.UserIDs)
+	}
+	if queryParams.TaskID != "" {
+		query = query.Where("task_id = ?", queryParams.TaskID)
+	}
+	if queryParams.Action != "" {
+		query = query.Where("action = ?", queryParams.Action)
+	}
+	if queryParams.Status != "" {
+		query = query.Where("status = ?", queryParams.Status)
+	}
+	if queryParams.StartTimestamp != 0 {
+		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+	}
+	if queryParams.EndTimestamp != 0 {
+		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+	}
+
+	// 获取数据
+	err = query.Order("id desc").Limit(num).Offset(startIdx).Find(&tasks).Error
+	if err != nil {
+		return nil
+	}
+
+	return tasks
+}
+
+func GetAllUnFinishSyncTasks(limit int) []*Task {
+	var tasks []*Task
+	var err error
+	// get all tasks progress is not 100%
+	err = DB.Where("progress != ?", "100%").Limit(limit).Order("id").Find(&tasks).Error
+	if err != nil {
+		return nil
+	}
+	return tasks
+}
+
+func GetByOnlyTaskId(taskId string) (*Task, bool, error) {
+	if taskId == "" {
+		return nil, false, nil
+	}
+	var task *Task
+	var err error
+	err = DB.Where("task_id = ?", taskId).First(&task).Error
+	exist, err := RecordExist(err)
+	if err != nil {
+		return nil, false, err
+	}
+	return task, exist, err
+}
+
+func GetByTaskId(userId int, taskId string) (*Task, bool, error) {
+	if taskId == "" {
+		return nil, false, nil
+	}
+	var task *Task
+	var err error
+	err = DB.Where("user_id = ? and task_id = ?", userId, taskId).
+		First(&task).Error
+	exist, err := RecordExist(err)
+	if err != nil {
+		return nil, false, err
+	}
+	return task, exist, err
+}
+
+func GetByTaskIds(userId int, taskIds []any) ([]*Task, error) {
+	if len(taskIds) == 0 {
+		return nil, nil
+	}
+	var task []*Task
+	var err error
+	err = DB.Where("user_id = ? and task_id in (?)", userId, taskIds).
+		Find(&task).Error
+	if err != nil {
+		return nil, err
+	}
+	return task, nil
+}
+
+func TaskUpdateProgress(id int64, progress string) error {
+	return DB.Model(&Task{}).Where("id = ?", id).Update("progress", progress).Error
+}
+
+func (Task *Task) Insert() error {
+	var err error
+	err = DB.Create(Task).Error
+	return err
+}
+
+func (Task *Task) Update() error {
+	var err error
+	err = DB.Save(Task).Error
+	return err
+}
+
+func TaskBulkUpdate(TaskIds []string, params map[string]any) error {
+	if len(TaskIds) == 0 {
+		return nil
+	}
+	return DB.Model(&Task{}).
+		Where("task_id in (?)", TaskIds).
+		Updates(params).Error
+}
+
+func TaskBulkUpdateByTaskIds(taskIDs []int64, params map[string]any) error {
+	if len(taskIDs) == 0 {
+		return nil
+	}
+	return DB.Model(&Task{}).
+		Where("id in (?)", taskIDs).
+		Updates(params).Error
+}
+
+func TaskBulkUpdateByID(ids []int64, params map[string]any) error {
+	if len(ids) == 0 {
+		return nil
+	}
+	return DB.Model(&Task{}).
+		Where("id in (?)", ids).
+		Updates(params).Error
+}
+
+type TaskQuotaUsage struct {
+	Mode  string  `json:"mode"`
+	Count float64 `json:"count"`
+}
+
+func SumUsedTaskQuota(queryParams SyncTaskQueryParams) (stat []TaskQuotaUsage, err error) {
+	query := DB.Model(Task{})
+	// 添加过滤条件
+	if queryParams.ChannelID != "" {
+		query = query.Where("channel_id = ?", queryParams.ChannelID)
+	}
+	if queryParams.UserID != "" {
+		query = query.Where("user_id = ?", queryParams.UserID)
+	}
+	if len(queryParams.UserIDs) != 0 {
+		query = query.Where("user_id in (?)", queryParams.UserIDs)
+	}
+	if queryParams.TaskID != "" {
+		query = query.Where("task_id = ?", queryParams.TaskID)
+	}
+	if queryParams.Action != "" {
+		query = query.Where("action = ?", queryParams.Action)
+	}
+	if queryParams.Status != "" {
+		query = query.Where("status = ?", queryParams.Status)
+	}
+	if queryParams.StartTimestamp != 0 {
+		query = query.Where("submit_time >= ?", queryParams.StartTimestamp)
+	}
+	if queryParams.EndTimestamp != 0 {
+		query = query.Where("submit_time <= ?", queryParams.EndTimestamp)
+	}
+	err = query.Select("mode, sum(quota) as count").Group("mode").Find(&stat).Error
+	return stat, err
+}

+ 12 - 0
model/utils.go

@@ -1,6 +1,8 @@
 package model
 
 import (
+	"errors"
+	"gorm.io/gorm"
 	"one-api/common"
 	"sync"
 	"time"
@@ -75,3 +77,13 @@ func batchUpdate() {
 	}
 	common.SysLog("batch update finished")
 }
+
+func RecordExist(err error) (bool, error) {
+	if err == nil {
+		return true, nil
+	}
+	if errors.Is(err, gorm.ErrRecordNotFound) {
+		return false, nil
+	}
+	return false, err
+}

+ 18 - 0
relay/channel/adapter.go

@@ -19,3 +19,21 @@ type Adaptor interface {
 	GetModelList() []string
 	GetChannelName() string
 }
+
+type TaskAdaptor interface {
+	Init(info *relaycommon.TaskRelayInfo)
+
+	ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
+
+	BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
+	BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
+	BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
+
+	DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
+	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
+
+	GetModelList() []string
+	GetChannelName() string
+
+	// FetchTask
+}

+ 24 - 0
relay/channel/api_request.go

@@ -50,3 +50,27 @@ func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
 	_ = c.Request.Body.Close()
 	return resp, nil
 }
+
+func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+	fullRequestURL, err := a.BuildRequestURL(info)
+	if err != nil {
+		return nil, err
+	}
+	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	if err != nil {
+		return nil, fmt.Errorf("new request failed: %w", err)
+	}
+	req.GetBody = func() (io.ReadCloser, error) {
+		return io.NopCloser(requestBody), nil
+	}
+
+	err = a.BuildRequestHeader(c, req, info)
+	if err != nil {
+		return nil, fmt.Errorf("setup request header failed: %w", err)
+	}
+	resp, err := doRequest(c, req)
+	if err != nil {
+		return nil, fmt.Errorf("do request failed: %w", err)
+	}
+	return resp, nil
+}

+ 147 - 0
relay/channel/task/suno/adaptor.go

@@ -0,0 +1,147 @@
+package suno
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
+	"one-api/relay/channel"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+	"strings"
+)
+
+type TaskAdaptor struct {
+	ChannelType int
+	Action      string
+}
+
+func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
+	a.ChannelType = info.ChannelType
+
+}
+
+func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
+	action := strings.ToUpper(c.Param("action"))
+
+	var sunoRequest *dto.SunoSubmitReq
+	err := common.UnmarshalBodyReusable(c, &sunoRequest)
+	if err != nil {
+		taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+		return
+	}
+	err = actionValidate(c, sunoRequest, action)
+	if err != nil {
+		taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
+		return
+	}
+
+	if sunoRequest.ContinueClipId != "" {
+		if sunoRequest.TaskID == "" {
+			taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
+			return
+		}
+		info.OriginTaskID = sunoRequest.TaskID
+	}
+
+	a.Action = info.Action
+	c.Set("task_request", sunoRequest)
+	return nil
+}
+
+func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
+	baseURL := common.ChannelBaseURLs[info.ChannelType]
+	if info.BaseUrl != "" {
+		baseURL = info.BaseUrl
+	}
+	fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/submit/"+info.Action)
+	return fullRequestURL, nil
+}
+
+func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
+	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
+	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
+	req.Header.Set("Authorization", "Bearer "+info.ApiKey)
+	return nil
+}
+
+func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
+	sunoRequest, ok := c.Get("task_request")
+	if !ok {
+		err := common.UnmarshalBodyReusable(c, &sunoRequest)
+		if err != nil {
+			return nil, err
+		}
+	}
+	data, err := json.Marshal(sunoRequest)
+	if err != nil {
+		return nil, err
+	}
+	return bytes.NewReader(data), nil
+}
+
+func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
+	return channel.DoTaskApiRequest(a, c, info, requestBody)
+}
+
+func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+	var sunoResponse dto.TaskResponse[string]
+	err = json.Unmarshal(responseBody, &sunoResponse)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+	if !sunoResponse.IsSuccess() {
+		taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
+		return
+	}
+
+	for k, v := range resp.Header {
+		c.Writer.Header().Set(k, v[0])
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+
+	_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+
+	return sunoResponse.Data, nil, nil
+}
+
+func (a *TaskAdaptor) GetModelList() []string {
+	return ModelList
+}
+
+func (a *TaskAdaptor) GetChannelName() string {
+	return ChannelName
+}
+
+func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
+	switch action {
+	case constant.SunoActionMusic:
+		if sunoRequest.Mv == "" {
+			sunoRequest.Mv = "chirp-v3-0"
+		}
+	case constant.SunoActionLyrics:
+		if sunoRequest.Prompt == "" {
+			err = fmt.Errorf("prompt_empty")
+			return
+		}
+	default:
+		err = fmt.Errorf("invalid_action")
+	}
+	return
+}

+ 7 - 0
relay/channel/task/suno/models.go

@@ -0,0 +1,7 @@
+package suno
+
+var ModelList = []string{
+	"suno_music", "suno_lyrics",
+}
+
+var ChannelName = "suno"

+ 50 - 0
relay/common/relay_info.go

@@ -72,3 +72,53 @@ func (info *RelayInfo) SetPromptTokens(promptTokens int) {
 func (info *RelayInfo) SetIsStream(isStream bool) {
 	info.IsStream = isStream
 }
+
+type TaskRelayInfo struct {
+	ChannelType       int
+	ChannelId         int
+	TokenId           int
+	UserId            int
+	Group             string
+	StartTime         time.Time
+	ApiType           int
+	RelayMode         int
+	UpstreamModelName string
+	RequestURLPath    string
+	ApiKey            string
+	BaseUrl           string
+
+	Action       string
+	OriginTaskID string
+
+	ConsumeQuota bool
+}
+
+func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
+	channelType := c.GetInt("channel")
+	channelId := c.GetInt("channel_id")
+
+	tokenId := c.GetInt("token_id")
+	userId := c.GetInt("id")
+	group := c.GetString("group")
+	startTime := time.Now()
+
+	apiType, _ := constant.ChannelType2APIType(channelType)
+
+	info := &TaskRelayInfo{
+		RelayMode:      constant.Path2RelayMode(c.Request.URL.Path),
+		BaseUrl:        c.GetString("base_url"),
+		RequestURLPath: c.Request.URL.String(),
+		ChannelType:    channelType,
+		ChannelId:      channelId,
+		TokenId:        tokenId,
+		UserId:         userId,
+		Group:          group,
+		StartTime:      startTime,
+		ApiType:        apiType,
+		ApiKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
+	}
+	if info.BaseUrl == "" {
+		info.BaseUrl = common.ChannelBaseURLs[channelType]
+	}
+	return info
+}

+ 19 - 1
relay/constant/relay_mode.go

@@ -1,6 +1,9 @@
 package constant
 
-import "strings"
+import (
+	"net/http"
+	"strings"
+)
 
 const (
 	RelayModeUnknown = iota
@@ -26,6 +29,9 @@ const (
 	RelayModeMidjourneyModal
 	RelayModeMidjourneyShorten
 	RelayModeSwapFace
+	RelayModeSunoFetch
+	RelayModeSunoFetchByID
+	RelayModeSunoSubmit
 )
 
 func Path2RelayMode(path string) int {
@@ -89,3 +95,15 @@ func Path2RelayModeMidjourney(path string) int {
 	}
 	return relayMode
 }
+
+func Path2RelaySuno(method, path string) int {
+	relayMode := RelayModeUnknown
+	if method == http.MethodPost && strings.HasSuffix(path, "/fetch") {
+		relayMode = RelayModeSunoFetch
+	} else if method == http.MethodGet && strings.Contains(path, "/fetch/") {
+		relayMode = RelayModeSunoFetchByID
+	} else if strings.Contains(path, "/submit/") {
+		relayMode = RelayModeSunoSubmit
+	}
+	return relayMode
+}

+ 12 - 0
relay/relay_adaptor.go

@@ -1,6 +1,7 @@
 package relay
 
 import (
+	commonconstant "one-api/constant"
 	"one-api/relay/channel"
 	"one-api/relay/channel/ali"
 	"one-api/relay/channel/aws"
@@ -12,6 +13,7 @@ import (
 	"one-api/relay/channel/openai"
 	"one-api/relay/channel/palm"
 	"one-api/relay/channel/perplexity"
+	"one-api/relay/channel/task/suno"
 	"one-api/relay/channel/tencent"
 	"one-api/relay/channel/xunfei"
 	"one-api/relay/channel/zhipu"
@@ -54,3 +56,13 @@ func GetAdaptor(apiType int) channel.Adaptor {
 	}
 	return nil
 }
+
+func GetTaskAdaptor(platform commonconstant.TaskPlatform) channel.TaskAdaptor {
+	switch platform {
+	//case constant.APITypeAIProxyLibrary:
+	//	return &aiproxy.Adaptor{}
+	case commonconstant.TaskPlatformSuno:
+		return &suno.TaskAdaptor{}
+	}
+	return nil
+}

+ 242 - 0
relay/relay_task.go

@@ -0,0 +1,242 @@
+package relay
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/constant"
+	"one-api/dto"
+	"one-api/model"
+	relaycommon "one-api/relay/common"
+	relayconstant "one-api/relay/constant"
+	"one-api/service"
+)
+
+/*
+Task 任务通过平台、Action 区分任务
+*/
+func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
+	platform := constant.TaskPlatform(c.GetString("platform"))
+	relayInfo := relaycommon.GenTaskRelayInfo(c)
+
+	adaptor := GetTaskAdaptor(platform)
+	if adaptor == nil {
+		return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
+	}
+	adaptor.Init(relayInfo)
+	// get & validate taskRequest 获取并验证文本请求
+	taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
+	if taskErr != nil {
+		return
+	}
+
+	modelName := service.CoverTaskActionToModelName(platform, relayInfo.Action)
+	modelPrice, success := common.GetModelPrice(modelName, true)
+	if !success {
+		defaultPrice, ok := common.GetDefaultModelRatioMap()[modelName]
+		if !ok {
+			modelPrice = 0.1
+		} else {
+			modelPrice = defaultPrice
+		}
+	}
+
+	// 预扣
+	groupRatio := common.GetGroupRatio(relayInfo.Group)
+	ratio := modelPrice * groupRatio
+	userQuota, err := model.CacheGetUserQuota(relayInfo.UserId)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
+		return
+	}
+	quota := int(ratio * common.QuotaPerUnit)
+	if userQuota-quota < 0 {
+		taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
+		return
+	}
+
+	if relayInfo.OriginTaskID != "" {
+		originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
+		if err != nil {
+			taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
+			return
+		}
+		if !exist {
+			taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
+			return
+		}
+		if originTask.ChannelId != relayInfo.ChannelId {
+			channel, err := model.GetChannelById(originTask.ChannelId, true)
+			if err != nil {
+				taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
+				return
+			}
+			if channel.Status != common.ChannelStatusEnabled {
+				return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
+			}
+			c.Set("base_url", channel.GetBaseURL())
+			c.Set("channel_id", originTask.ChannelId)
+			c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
+
+			relayInfo.BaseUrl = channel.GetBaseURL()
+			relayInfo.ChannelId = originTask.ChannelId
+		}
+	}
+
+	// build body
+	requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
+		return
+	}
+	// do request
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+		return
+	}
+	// handle response
+	if resp != nil && resp.StatusCode != http.StatusOK {
+		responseBody, _ := io.ReadAll(resp.Body)
+		taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
+		return
+	}
+
+	defer func(ctx context.Context) {
+		// release quota
+		if relayInfo.ConsumeQuota && taskErr == nil {
+			err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quota, 0, true)
+			if err != nil {
+				common.SysError("error consuming token remain quota: " + err.Error())
+			}
+			err = model.CacheUpdateUserQuota(relayInfo.UserId)
+			if err != nil {
+				common.SysError("error update user quota cache: " + err.Error())
+			}
+			if quota != 0 {
+				tokenName := c.GetString("token_name")
+				logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, groupRatio, relayInfo.Action)
+				other := make(map[string]interface{})
+				other["model_price"] = modelPrice
+				other["group_ratio"] = groupRatio
+				model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, 0, 0, modelName, tokenName, quota, logContent, relayInfo.TokenId, userQuota, 0, false, other)
+				model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
+				model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
+			}
+		}
+	}(c.Request.Context())
+
+	taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
+	if taskErr != nil {
+		return
+	}
+	relayInfo.ConsumeQuota = true
+	// insert task
+	task := model.InitTask(constant.TaskPlatformSuno, relayInfo)
+	task.TaskID = taskID
+	task.Quota = quota
+	task.Data = taskData
+	err = task.Insert()
+	if err != nil {
+		taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
+		return
+	}
+	return nil
+}
+
+var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
+	relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
+	relayconstant.RelayModeSunoFetch:     sunoFetchRespBodyBuilder,
+}
+
+func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
+	respBuilder, ok := fetchRespBuilders[relayMode]
+	if !ok {
+		taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
+	}
+
+	respBody, taskErr := respBuilder(c)
+	if taskErr != nil {
+		return taskErr
+	}
+
+	c.Writer.Header().Set("Content-Type", "application/json")
+	_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
+	if err != nil {
+		taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
+		return
+	}
+	return
+}
+
+func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
+	userId := c.GetInt("id")
+	var condition = struct {
+		IDs    []any  `json:"ids"`
+		Action string `json:"action"`
+	}{}
+	err := c.BindJSON(&condition)
+	if err != nil {
+		taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
+		return
+	}
+	var tasks []any
+	if len(condition.IDs) > 0 {
+		taskModels, err := model.GetByTaskIds(userId, condition.IDs)
+		if err != nil {
+			taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
+			return
+		}
+		for _, task := range taskModels {
+			tasks = append(tasks, TaskModel2Dto(task))
+		}
+	} else {
+		tasks = make([]any, 0)
+	}
+	respBody, err = json.Marshal(dto.TaskResponse[[]any]{
+		Code: "success",
+		Data: tasks,
+	})
+	return
+}
+
+func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
+	taskId := c.Param("id")
+	userId := c.GetInt("id")
+
+	originTask, exist, err := model.GetByTaskId(userId, taskId)
+	if err != nil {
+		taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
+		return
+	}
+	if !exist {
+		taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
+		return
+	}
+
+	respBody, err = json.Marshal(dto.TaskResponse[any]{
+		Code: "success",
+		Data: TaskModel2Dto(originTask),
+	})
+	return
+}
+
+func TaskModel2Dto(task *model.Task) *dto.TaskDto {
+	return &dto.TaskDto{
+		TaskID:     task.TaskID,
+		Action:     task.Action,
+		Status:     string(task.Status),
+		FailReason: task.FailReason,
+		SubmitTime: task.SubmitTime,
+		StartTime:  task.StartTime,
+		FinishTime: task.FinishTime,
+		Progress:   task.Progress,
+		Data:       task.Data,
+	}
+}

+ 9 - 0
router/relay-router.go

@@ -50,6 +50,15 @@ func SetRelayRouter(router *gin.Engine) {
 	relayMjModeRouter := router.Group("/:mode/mj")
 	registerMjRouterGroup(relayMjModeRouter)
 	//relayMjRouter.Use()
+
+	relaySunoRouter := router.Group("/suno")
+	relaySunoRouter.Use(middleware.TokenAuth(), middleware.Distribute())
+	{
+		relaySunoRouter.POST("/submit/:action", controller.RelayTask)
+		relaySunoRouter.POST("/fetch", controller.RelayTask)
+		relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
+	}
+
 }
 
 func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {

+ 26 - 0
service/error.go

@@ -105,3 +105,29 @@ func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMapping
 		openaiErr.StatusCode = intCode
 	}
 }
+
+func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError {
+	openaiErr := TaskErrorWrapper(err, code, statusCode)
+	openaiErr.LocalError = true
+	return openaiErr
+}
+
+func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError {
+	text := err.Error()
+
+	// 定义一个正则表达式匹配URL
+	if strings.Contains(text, "Post") || strings.Contains(text, "dial") {
+		common.SysLog(fmt.Sprintf("error: %s", text))
+		text = "请求上游地址失败"
+	}
+	//避免暴露内部错误
+
+	taskError := &dto.TaskError{
+		Code:       code,
+		Message:    text,
+		StatusCode: statusCode,
+		Error:      err,
+	}
+
+	return taskError
+}

+ 10 - 0
service/task.go

@@ -0,0 +1,10 @@
+package service
+
+import (
+	"one-api/constant"
+	"strings"
+)
+
+func CoverTaskActionToModelName(platform constant.TaskPlatform, action string) string {
+	return strings.ToLower(string(platform)) + "_" + strings.ToLower(action)
+}