| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249 |
- package doubao
- import (
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/model"
- "github.com/QuantumNous/new-api/relay/channel"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/service"
- "github.com/gin-gonic/gin"
- "github.com/pkg/errors"
- )
- // ============================
- // Request / Response structures
- // ============================
// ImageURL wraps the address of an image sent to the upstream model.
type ImageURL struct {
	URL string `json:"url"`
}

// ContentItem is one entry of the request "content" array.
// Type is either "text" (Text is filled) or "image_url" (ImageURL is filled);
// the field that is not used stays zero and is omitted from the JSON.
type ContentItem struct {
	Type     string    `json:"type"`
	Text     string    `json:"text,omitempty"`
	ImageURL *ImageURL `json:"image_url,omitempty"`
}

// requestPayload is the JSON body POSTed upstream to create a generation task.
type requestPayload struct {
	Model   string        `json:"model"`
	Content []ContentItem `json:"content"`
}
// responsePayload is the upstream reply to a task creation request; only the
// task identifier is needed.
type responsePayload struct {
	// ID is the upstream task_id.
	ID string `json:"id"`
}

// responseTask is the upstream reply to a task status query.
type responseTask struct {
	ID     string `json:"id"`
	Model  string `json:"model"`
	Status string `json:"status"`

	// Content holds the generated output; VideoURL is read once the task succeeds.
	Content struct {
		VideoURL string `json:"video_url"`
	} `json:"content"`

	Seed            int    `json:"seed"`
	Resolution      string `json:"resolution"`
	Duration        int    `json:"duration"`
	Ratio           string `json:"ratio"`
	FramesPerSecond int    `json:"framespersecond"`

	// Usage reports token consumption reported by the upstream for the task.
	Usage struct {
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`

	CreatedAt int64 `json:"created_at"`
	UpdatedAt int64 `json:"updated_at"`
}
- // ============================
- // Adaptor implementation
- // ============================
// TaskAdaptor relays video-generation tasks to the Doubao upstream.
// Its fields are populated by Init from the per-request relay info.
type TaskAdaptor struct {
	// ChannelType is copied from the relay info's channel type.
	ChannelType int
	// apiKey is sent as a Bearer token; baseURL prefixes every upstream path.
	apiKey, baseURL string
}
- func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
- a.ChannelType = info.ChannelType
- a.baseURL = info.ChannelBaseUrl
- a.apiKey = info.ApiKey
- }
- // ValidateRequestAndSetAction parses body, validates fields and sets default action.
- func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
- // Accept only POST /v1/video/generations as "generate" action.
- return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
- }
- // BuildRequestURL constructs the upstream URL.
- func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
- return fmt.Sprintf("%s/api/v3/contents/generations/tasks", a.baseURL), nil
- }
// BuildRequestHeader marks the request as JSON in both directions and
// authenticates it with the channel API key as a Bearer token.
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
	h := req.Header
	h.Set("Content-Type", "application/json")
	h.Set("Accept", "application/json")
	h.Set("Authorization", "Bearer "+a.apiKey)
	return nil
}
- // BuildRequestBody converts request into Doubao specific format.
- func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
- v, exists := c.Get("task_request")
- if !exists {
- return nil, fmt.Errorf("request not found in context")
- }
- req := v.(relaycommon.TaskSubmitReq)
- body, err := a.convertToRequestPayload(&req)
- if err != nil {
- return nil, errors.Wrap(err, "convert request payload failed")
- }
- data, err := json.Marshal(body)
- if err != nil {
- return nil, err
- }
- return bytes.NewReader(data), nil
- }
// DoRequest sends requestBody upstream via the shared channel.DoTaskApiRequest
// helper, passing this adaptor along (presumably so the helper can call its
// BuildRequestURL/BuildRequestHeader — confirm in package channel).
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	resp, err := channel.DoTaskApiRequest(a, c, info, requestBody)
	return resp, err
}
// DoResponse reads the upstream reply to a task submission, extracts the
// Doubao task ID and answers the client with {"task_id": <id>}.
//
// On success it returns the task ID and the raw upstream body as taskData.
// If the body cannot be read or decoded, or it carries no "id", it returns a
// TaskError with HTTP 500. The response body is closed on every path.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
	// Deferred so the body is released even when the read itself fails
	// (previously it was only closed after a successful read).
	defer func() { _ = resp.Body.Close() }()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
	}

	var dResp responsePayload
	if err := json.Unmarshal(responseBody, &dResp); err != nil {
		return "", nil, service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
	}
	if dResp.ID == "" {
		// Include the body, as the unmarshal path does, so upstream error
		// payloads without an id remain diagnosable.
		return "", nil, service.TaskErrorWrapper(fmt.Errorf("task_id is empty, body: %s", responseBody), "invalid_response", http.StatusInternalServerError)
	}

	c.JSON(http.StatusOK, gin.H{"task_id": dResp.ID})
	return dResp.ID, responseBody, nil
}
// FetchTask queries the status of a previously submitted Doubao task.
//
// body must contain a non-empty string "task_id". baseUrl and key are the
// channel base URL and API key. The *http.Response is returned unread; the
// caller owns it and must close its body.
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
	taskID, ok := body["task_id"].(string)
	// An empty ID would produce ".../tasks/" (the collection path) rather
	// than a single-task URL, so reject it alongside a missing/non-string ID.
	if !ok || taskID == "" {
		return nil, fmt.Errorf("invalid task_id")
	}

	uri := fmt.Sprintf("%s/api/v3/contents/generations/tasks/%s", baseUrl, taskID)
	req, err := http.NewRequest(http.MethodGet, uri, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+key)
	return service.GetHttpClient().Do(req)
}
- func (a *TaskAdaptor) GetModelList() []string {
- return ModelList
- }
- func (a *TaskAdaptor) GetChannelName() string {
- return ChannelName
- }
- func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
- r := requestPayload{
- Model: req.Model,
- Content: []ContentItem{},
- }
- // Add text prompt
- if req.Prompt != "" {
- r.Content = append(r.Content, ContentItem{
- Type: "text",
- Text: req.Prompt,
- })
- }
- // Add images if present
- if req.HasImage() {
- for _, imgURL := range req.Images {
- r.Content = append(r.Content, ContentItem{
- Type: "image_url",
- ImageURL: &ImageURL{
- URL: imgURL,
- },
- })
- }
- }
- // TODO: Add support for additional parameters from metadata
- // such as ratio, duration, seed, etc.
- // metadata := req.Metadata
- // if metadata != nil {
- // // Parse and apply metadata parameters
- // }
- return &r, nil
- }
- func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
- resTask := responseTask{}
- if err := json.Unmarshal(respBody, &resTask); err != nil {
- return nil, errors.Wrap(err, "unmarshal task result failed")
- }
- taskResult := relaycommon.TaskInfo{
- Code: 0,
- }
- // Map Doubao status to internal status
- switch resTask.Status {
- case "pending", "queued":
- taskResult.Status = model.TaskStatusQueued
- taskResult.Progress = "10%"
- case "processing":
- taskResult.Status = model.TaskStatusInProgress
- taskResult.Progress = "50%"
- case "succeeded":
- taskResult.Status = model.TaskStatusSuccess
- taskResult.Progress = "100%"
- taskResult.Url = resTask.Content.VideoURL
- // 解析 usage 信息用于按倍率计费
- taskResult.CompletionTokens = resTask.Usage.CompletionTokens
- taskResult.TotalTokens = resTask.Usage.TotalTokens
- case "failed":
- taskResult.Status = model.TaskStatusFailure
- taskResult.Progress = "100%"
- taskResult.Reason = "task failed"
- default:
- // Unknown status, treat as processing
- taskResult.Status = model.TaskStatusInProgress
- taskResult.Progress = "30%"
- }
- return &taskResult, nil
- }
|