package replicate import ( "bytes" "encoding/json" "errors" "fmt" "io" "mime/multipart" "net/http" "net/textproto" "strconv" "strings" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/relay/channel" relaycommon "github.com/QuantumNous/new-api/relay/common" relayconstant "github.com/QuantumNous/new-api/relay/constant" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/types" "github.com/gin-gonic/gin" ) type Adaptor struct { } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { if info == nil { return "", errors.New("replicate adaptor: relay info is nil") } if info.ChannelBaseUrl == "" { info.ChannelBaseUrl = constant.ChannelBaseURLs[constant.ChannelTypeReplicate] } requestPath := info.RequestURLPath if requestPath == "" { return info.ChannelBaseUrl, nil } return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestPath, info.ChannelType), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { if info == nil { return errors.New("replicate adaptor: relay info is nil") } if info.ApiKey == "" { return errors.New("replicate adaptor: api key is required") } channel.SetupApiRequestHeader(info, c, req) req.Set("Authorization", "Bearer "+info.ApiKey) req.Set("Prefer", "wait") if req.Get("Content-Type") == "" { req.Set("Content-Type", "application/json") } if req.Get("Accept") == "" { req.Set("Accept", "application/json") } return nil } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { if info == nil { return nil, errors.New("replicate adaptor: relay info is nil") } if strings.TrimSpace(request.Prompt) == "" { if v := c.PostForm("prompt"); strings.TrimSpace(v) != "" { request.Prompt = v } } if strings.TrimSpace(request.Prompt) == "" { return nil, errors.New("replicate adaptor: prompt is required") } modelName := strings.TrimSpace(info.UpstreamModelName) if modelName == "" { modelName = strings.TrimSpace(request.Model) } if modelName == "" { modelName = ModelFlux11Pro } info.UpstreamModelName = modelName info.RequestURLPath = fmt.Sprintf("/v1/models/%s/predictions", modelName) inputPayload := make(map[string]any) inputPayload["prompt"] = request.Prompt if size := strings.TrimSpace(request.Size); size != "" { if aspect, width, height, ok := mapOpenAISizeToFlux(size); ok { if aspect != "" { if aspect == "custom" { inputPayload["aspect_ratio"] = "custom" if width > 0 { inputPayload["width"] = width } if height > 0 { inputPayload["height"] = height } } else { inputPayload["aspect_ratio"] = aspect } } } } if len(request.OutputFormat) > 0 { var outputFormat string if err := json.Unmarshal(request.OutputFormat, &outputFormat); err == nil && strings.TrimSpace(outputFormat) != "" { inputPayload["output_format"] = outputFormat } } if request.N > 0 { inputPayload["num_outputs"] = int(request.N) } if strings.EqualFold(request.Quality, "hd") || strings.EqualFold(request.Quality, "high") { inputPayload["prompt_upsampling"] = true } if info.RelayMode == relayconstant.RelayModeImagesEdits { imageURL, err := uploadFileFromForm(c, info, "image", "image[]", "image_prompt") if err != nil { return nil, err } if imageURL == "" { return nil, errors.New("replicate adaptor: image file is required for edits") } inputPayload["image_prompt"] = imageURL } if len(request.ExtraFields) > 0 { var extra map[string]any if err := common.Unmarshal(request.ExtraFields, &extra); err != nil { return nil, fmt.Errorf("replicate adaptor: failed to decode extra_fields: %w", err) } for key, val := range extra { inputPayload[key] = val } } for key, raw := range request.Extra { if strings.EqualFold(key, "input") { var extraInput map[string]any if err := common.Unmarshal(raw, &extraInput); err != nil { return nil, fmt.Errorf("replicate adaptor: failed to decode extra input: %w", err) } for k, v := range extraInput { inputPayload[k] = v } continue } if raw == nil { continue } var val any if err := common.Unmarshal(raw, &val); err != nil { return nil, fmt.Errorf("replicate adaptor: failed to decode extra field %s: %w", key, err) } inputPayload[key] = val } return map[string]any{ "input": inputPayload, }, nil } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (any, *types.NewAPIError) { if resp == nil { return nil, types.NewError(errors.New("replicate adaptor: empty response"), types.ErrorCodeBadResponse) } responseBody, err := io.ReadAll(resp.Body) if err != nil { return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) } _ = resp.Body.Close() var prediction PredictionResponse if err := common.Unmarshal(responseBody, &prediction); err != nil { return nil, types.NewError(fmt.Errorf("replicate adaptor: failed to decode response: %w", err), types.ErrorCodeBadResponseBody) } if prediction.Error != nil { errMsg := prediction.Error.Message if errMsg == "" { errMsg = prediction.Error.Detail } if errMsg == "" { errMsg = prediction.Error.Code } if errMsg == "" { errMsg = "replicate adaptor: prediction error" } return nil, types.NewError(errors.New(errMsg), types.ErrorCodeBadResponse) } if prediction.Status != "" && !strings.EqualFold(prediction.Status, "succeeded") { return nil, types.NewError(fmt.Errorf("replicate adaptor: prediction status %q", prediction.Status), types.ErrorCodeBadResponse) } var urls []string appendOutput := func(value string) { value = strings.TrimSpace(value) if value == "" { return } urls = append(urls, value) } switch output := prediction.Output.(type) { case string: appendOutput(output) case []any: for _, item := range output { if str, ok := item.(string); ok { appendOutput(str) } } case nil: // no output default: if str, ok := output.(fmt.Stringer); ok { appendOutput(str.String()) } } if len(urls) == 0 { return nil, types.NewError(errors.New("replicate adaptor: empty prediction output"), types.ErrorCodeBadResponseBody) } var imageReq *dto.ImageRequest if info != nil { if req, ok := info.Request.(*dto.ImageRequest); ok { imageReq = req } } wantsBase64 := imageReq != nil && strings.EqualFold(imageReq.ResponseFormat, "b64_json") imageResponse := dto.ImageResponse{ Created: common.GetTimestamp(), Data: make([]dto.ImageData, 0), } if wantsBase64 { converted, convErr := downloadImagesToBase64(urls) if convErr != nil { return nil, types.NewError(convErr, types.ErrorCodeBadResponse) } for _, content := range converted { if content == "" { continue } imageResponse.Data = append(imageResponse.Data, dto.ImageData{B64Json: content}) } } else { for _, url := range urls { if url == "" { continue } imageResponse.Data = append(imageResponse.Data, dto.ImageData{Url: url}) } } if len(imageResponse.Data) == 0 { return nil, types.NewError(errors.New("replicate adaptor: no usable image data"), types.ErrorCodeBadResponse) } responseBytes, err := common.Marshal(imageResponse) if err != nil { return nil, types.NewError(fmt.Errorf("replicate adaptor: encode response failed: %w", err), types.ErrorCodeBadResponseBody) } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(http.StatusOK) _, _ = c.Writer.Write(responseBytes) usage := &dto.Usage{} return usage, nil } func (a *Adaptor) GetModelList() []string { return ModelList } func (a *Adaptor) GetChannelName() string { return ChannelName } func downloadImagesToBase64(urls []string) ([]string, error) { results := make([]string, 0, len(urls)) for _, url := range urls { if strings.TrimSpace(url) == "" { continue } _, data, err := service.GetImageFromUrl(url) if err != nil { return nil, fmt.Errorf("replicate adaptor: failed to download image from %s: %w", url, err) } results = append(results, data) } return results, nil } func mapOpenAISizeToFlux(size string) (aspect string, width int, height int, ok bool) { parts := strings.Split(size, "x") if len(parts) != 2 { return "", 0, 0, false } w, err1 := strconv.Atoi(strings.TrimSpace(parts[0])) h, err2 := strconv.Atoi(strings.TrimSpace(parts[1])) if err1 != nil || err2 != nil || w <= 0 || h <= 0 { return "", 0, 0, false } switch { case w == h: return "1:1", 0, 0, true case w == 1792 && h == 1024: return "16:9", 0, 0, true case w == 1024 && h == 1792: return "9:16", 0, 0, true case w == 1536 && h == 1024: return "3:2", 0, 0, true case w == 1024 && h == 1536: return "2:3", 0, 0, true } rw, rh := reduceRatio(w, h) ratioStr := fmt.Sprintf("%d:%d", rw, rh) switch ratioStr { case "1:1", "16:9", "9:16", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3": return ratioStr, 0, 0, true } width = normalizeFluxDimension(w) height = normalizeFluxDimension(h) return "custom", width, height, true } func reduceRatio(w, h int) (int, int) { g := gcd(w, h) if g == 0 { return w, h } return w / g, h / g } func gcd(a, b int) int { for b != 0 { a, b = b, a%b } if a < 0 { return -a } return a } func normalizeFluxDimension(value int) int { const ( minDim = 256 maxDim = 1440 step = 32 ) if value < minDim { value = minDim } if value > maxDim { value = maxDim } remainder := value % step if remainder != 0 { if remainder >= step/2 { value += step - remainder } else { value -= remainder } } if value < minDim { value = minDim } if value > maxDim { value = maxDim } return value } func uploadFileFromForm(c *gin.Context, info *relaycommon.RelayInfo, fieldCandidates ...string) (string, error) { if info == nil { return "", errors.New("replicate adaptor: relay info is nil") } mf := c.Request.MultipartForm if mf == nil { if _, err := c.MultipartForm(); err != nil { return "", fmt.Errorf("replicate adaptor: parse multipart form failed: %w", err) } mf = c.Request.MultipartForm } if mf == nil || len(mf.File) == 0 { return "", nil } if len(fieldCandidates) == 0 { fieldCandidates = []string{"image", "image[]", "image_prompt"} } var fileHeader *multipart.FileHeader for _, key := range fieldCandidates { if files := mf.File[key]; len(files) > 0 { fileHeader = files[0] break } } if fileHeader == nil { for _, files := range mf.File { if len(files) > 0 { fileHeader = files[0] break } } } if fileHeader == nil { return "", nil } file, err := fileHeader.Open() if err != nil { return "", fmt.Errorf("replicate adaptor: failed to open image file: %w", err) } defer file.Close() var body bytes.Buffer writer := multipart.NewWriter(&body) hdr := make(textproto.MIMEHeader) hdr.Set("Content-Disposition", fmt.Sprintf("form-data; name=\"content\"; filename=\"%s\"", fileHeader.Filename)) contentType := fileHeader.Header.Get("Content-Type") if contentType == "" { contentType = "application/octet-stream" } hdr.Set("Content-Type", contentType) part, err := writer.CreatePart(hdr) if err != nil { writer.Close() return "", fmt.Errorf("replicate adaptor: create upload form failed: %w", err) } if _, err := io.Copy(part, file); err != nil { writer.Close() return "", fmt.Errorf("replicate adaptor: copy image content failed: %w", err) } formContentType := writer.FormDataContentType() writer.Close() baseURL := info.ChannelBaseUrl if baseURL == "" { baseURL = constant.ChannelBaseURLs[constant.ChannelTypeReplicate] } uploadURL := relaycommon.GetFullRequestURL(baseURL, "/v1/files", info.ChannelType) req, err := http.NewRequest(http.MethodPost, uploadURL, &body) if err != nil { return "", fmt.Errorf("replicate adaptor: create upload request failed: %w", err) } req.Header.Set("Content-Type", formContentType) req.Header.Set("Authorization", "Bearer "+info.ApiKey) resp, err := service.GetHttpClient().Do(req) if err != nil { return "", fmt.Errorf("replicate adaptor: upload image failed: %w", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { return "", fmt.Errorf("replicate adaptor: read upload response failed: %w", err) } if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { return "", fmt.Errorf("replicate adaptor: upload image failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) } var uploadResp FileUploadResponse if err := common.Unmarshal(respBody, &uploadResp); err != nil { return "", fmt.Errorf("replicate adaptor: decode upload response failed: %w", err) } if uploadResp.Urls.Get == "" { return "", errors.New("replicate adaptor: upload response missing url") } return uploadResp.Urls.Get, nil } func (a *Adaptor) ConvertOpenAIRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeneralOpenAIRequest) (any, error) { return nil, errors.New("replicate adaptor: ConvertOpenAIRequest is not implemented") } func (a *Adaptor) ConvertRerankRequest(*gin.Context, int, dto.RerankRequest) (any, error) { return nil, errors.New("replicate adaptor: ConvertRerankRequest is not implemented") } func (a *Adaptor) ConvertEmbeddingRequest(*gin.Context, *relaycommon.RelayInfo, dto.EmbeddingRequest) (any, error) { return nil, errors.New("replicate adaptor: ConvertEmbeddingRequest is not implemented") } func (a *Adaptor) ConvertAudioRequest(*gin.Context, *relaycommon.RelayInfo, dto.AudioRequest) (io.Reader, error) { return nil, errors.New("replicate adaptor: ConvertAudioRequest is not implemented") } func (a *Adaptor) ConvertOpenAIResponsesRequest(*gin.Context, *relaycommon.RelayInfo, dto.OpenAIResponsesRequest) (any, error) { return nil, errors.New("replicate adaptor: ConvertOpenAIResponsesRequest is not implemented") } func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { return nil, errors.New("replicate adaptor: ConvertClaudeRequest is not implemented") } func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { return nil, errors.New("replicate adaptor: ConvertGeminiRequest is not implemented") }