package vertex import ( "bytes" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "regexp" "strings" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/model" "github.com/gin-gonic/gin" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex" relaycommon "github.com/QuantumNous/new-api/relay/common" "github.com/QuantumNous/new-api/service" ) // ============================ // Request / Response structures // ============================ type requestPayload struct { Instances []map[string]any `json:"instances"` Parameters map[string]any `json:"parameters,omitempty"` } type submitResponse struct { Name string `json:"name"` } type operationVideo struct { MimeType string `json:"mimeType"` BytesBase64Encoded string `json:"bytesBase64Encoded"` Encoding string `json:"encoding"` } type operationResponse struct { Name string `json:"name"` Done bool `json:"done"` Response struct { Type string `json:"@type"` RaiMediaFilteredCount int `json:"raiMediaFilteredCount"` Videos []operationVideo `json:"videos"` BytesBase64Encoded string `json:"bytesBase64Encoded"` Encoding string `json:"encoding"` Video string `json:"video"` } `json:"response"` Error struct { Message string `json:"message"` } `json:"error"` } // ============================ // Adaptor implementation // ============================ type TaskAdaptor struct { ChannelType int apiKey string baseURL string } func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType a.baseURL = info.ChannelBaseUrl a.apiKey = info.ApiKey } // ValidateRequestAndSetAction parses body, validates fields and sets default action. func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { // Use the standard validation method for TaskSubmitReq return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate) } // BuildRequestURL constructs the upstream URL. func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { adc := &vertexcore.Credentials{} if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials: %w", err) } modelName := info.OriginModelName if modelName == "" { modelName = "veo-3.0-generate-001" } region := vertexcore.GetModelRegion(info.ApiVersion, modelName) if strings.TrimSpace(region) == "" { region = "global" } if region == "global" { return fmt.Sprintf( "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning", adc.ProjectID, modelName, ), nil } return fmt.Sprintf( "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning", region, adc.ProjectID, region, modelName, ), nil } // BuildRequestHeader sets required headers. func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") adc := &vertexcore.Credentials{} if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { return fmt.Errorf("failed to decode credentials: %w", err) } token, err := vertexcore.AcquireAccessToken(*adc, "") if err != nil { return fmt.Errorf("failed to acquire access token: %w", err) } req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("x-goog-user-project", adc.ProjectID) return nil } // BuildRequestBody converts request into Vertex specific format. func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, ok := c.Get("task_request") if !ok { return nil, fmt.Errorf("request not found in context") } req := v.(relaycommon.TaskSubmitReq) body := requestPayload{ Instances: []map[string]any{{"prompt": req.Prompt}}, Parameters: map[string]any{}, } if req.Metadata != nil { if v, ok := req.Metadata["storageUri"]; ok { body.Parameters["storageUri"] = v } if v, ok := req.Metadata["sampleCount"]; ok { body.Parameters["sampleCount"] = v } } if _, ok := body.Parameters["sampleCount"]; !ok { body.Parameters["sampleCount"] = 1 } data, err := json.Marshal(body) if err != nil { return nil, err } return bytes.NewReader(data), nil } // DoRequest delegates to common helper. func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) } // DoResponse handles upstream response, returns taskID etc. func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) } _ = resp.Body.Close() var s submitResponse if err := json.Unmarshal(responseBody, &s); err != nil { return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError) } if strings.TrimSpace(s.Name) == "" { return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError) } localID := encodeLocalTaskID(s.Name) c.JSON(http.StatusOK, gin.H{"task_id": localID}) return localID, responseBody, nil } func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} } func (a *TaskAdaptor) GetChannelName() string { return "vertex" } // FetchTask fetch task status func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { return nil, fmt.Errorf("invalid task_id") } upstreamName, err := decodeLocalTaskID(taskID) if err != nil { return nil, fmt.Errorf("decode task_id failed: %w", err) } region := extractRegionFromOperationName(upstreamName) if region == "" { region = "us-central1" } project := extractProjectFromOperationName(upstreamName) modelName := extractModelFromOperationName(upstreamName) if project == "" || modelName == "" { return nil, fmt.Errorf("cannot extract project/model from operation name") } var url string if region == "global" { url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName) } else { url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName) } payload := map[string]string{"operationName": upstreamName} data, err := json.Marshal(payload) if err != nil { return nil, err } adc := &vertexcore.Credentials{} if err := json.Unmarshal([]byte(key), adc); err != nil { return nil, fmt.Errorf("failed to decode credentials: %w", err) } token, err := vertexcore.AcquireAccessToken(*adc, "") if err != nil { return nil, fmt.Errorf("failed to acquire access token: %w", err) } req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data)) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("x-goog-user-project", adc.ProjectID) return service.GetHttpClient().Do(req) } func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) { var op operationResponse if err := json.Unmarshal(respBody, &op); err != nil { return nil, fmt.Errorf("unmarshal operation response failed: %w", err) } ti := &relaycommon.TaskInfo{} if op.Error.Message != "" { ti.Status = model.TaskStatusFailure ti.Reason = op.Error.Message ti.Progress = "100%" return ti, nil } if !op.Done { ti.Status = model.TaskStatusInProgress ti.Progress = "50%" return ti, nil } ti.Status = model.TaskStatusSuccess ti.Progress = "100%" if len(op.Response.Videos) > 0 { v0 := op.Response.Videos[0] if v0.BytesBase64Encoded != "" { mime := strings.TrimSpace(v0.MimeType) if mime == "" { enc := strings.TrimSpace(v0.Encoding) if enc == "" { enc = "mp4" } if strings.Contains(enc, "/") { mime = enc } else { mime = "video/" + enc } } ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded return ti, nil } } if op.Response.BytesBase64Encoded != "" { enc := strings.TrimSpace(op.Response.Encoding) if enc == "" { enc = "mp4" } mime := enc if !strings.Contains(enc, "/") { mime = "video/" + enc } ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded return ti, nil } if op.Response.Video != "" { // some variants use `video` as base64 enc := strings.TrimSpace(op.Response.Encoding) if enc == "" { enc = "mp4" } mime := enc if !strings.Contains(enc, "/") { mime = "video/" + enc } ti.Url = "data:" + mime + ";base64," + op.Response.Video return ti, nil } return ti, nil } func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) { upstreamName, err := decodeLocalTaskID(task.TaskID) if err != nil { upstreamName = "" } modelName := extractModelFromOperationName(upstreamName) if strings.TrimSpace(modelName) == "" { modelName = "veo-3.0-generate-001" } v := dto.NewOpenAIVideo() v.ID = task.TaskID v.Model = modelName v.Status = task.Status.ToVideoStatus() v.SetProgressStr(task.Progress) v.CreatedAt = task.CreatedAt v.CompletedAt = task.UpdatedAt if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 { v.SetMetadata("url", task.FailReason) } return common.Marshal(v) } // ============================ // helpers // ============================ func encodeLocalTaskID(name string) string { return base64.RawURLEncoding.EncodeToString([]byte(name)) } func decodeLocalTaskID(local string) (string, error) { b, err := base64.RawURLEncoding.DecodeString(local) if err != nil { return "", err } return string(b), nil } var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`) func extractRegionFromOperationName(name string) string { m := regionRe.FindStringSubmatch(name) if len(m) == 2 { return m[1] } return "" } var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`) func extractModelFromOperationName(name string) string { m := modelRe.FindStringSubmatch(name) if len(m) == 2 { return m[1] } idx := strings.Index(name, "models/") if idx >= 0 { s := name[idx+len("models/"):] if p := strings.Index(s, "/operations/"); p > 0 { return s[:p] } } return "" } var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`) func extractProjectFromOperationName(name string) string { m := projectRe.FindStringSubmatch(name) if len(m) == 2 { return m[1] } return "" }