| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530 |
- package replicate
- import (
- "bytes"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "mime/multipart"
- "net/http"
- "net/textproto"
- "strconv"
- "strings"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/relay/channel"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- relayconstant "github.com/QuantumNous/new-api/relay/constant"
- "github.com/QuantumNous/new-api/service"
- "github.com/QuantumNous/new-api/types"
- "github.com/gin-gonic/gin"
- )
// Adaptor is the channel adaptor for Replicate. It has no fields; every
// method receives the per-request state it needs through its arguments.
type Adaptor struct{}
- func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
- }
- func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
- if info == nil {
- return "", errors.New("replicate adaptor: relay info is nil")
- }
- if info.ChannelBaseUrl == "" {
- info.ChannelBaseUrl = constant.ChannelBaseURLs[constant.ChannelTypeReplicate]
- }
- requestPath := info.RequestURLPath
- if requestPath == "" {
- return info.ChannelBaseUrl, nil
- }
- return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestPath, info.ChannelType), nil
- }
- func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
- if info == nil {
- return errors.New("replicate adaptor: relay info is nil")
- }
- if info.ApiKey == "" {
- return errors.New("replicate adaptor: api key is required")
- }
- channel.SetupApiRequestHeader(info, c, req)
- req.Set("Authorization", "Bearer "+info.ApiKey)
- req.Set("Prefer", "wait")
- if req.Get("Content-Type") == "" {
- req.Set("Content-Type", "application/json")
- }
- if req.Get("Accept") == "" {
- req.Set("Accept", "application/json")
- }
- return nil
- }
- func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
- if info == nil {
- return nil, errors.New("replicate adaptor: relay info is nil")
- }
- if strings.TrimSpace(request.Prompt) == "" {
- if v := c.PostForm("prompt"); strings.TrimSpace(v) != "" {
- request.Prompt = v
- }
- }
- if strings.TrimSpace(request.Prompt) == "" {
- return nil, errors.New("replicate adaptor: prompt is required")
- }
- modelName := strings.TrimSpace(info.UpstreamModelName)
- if modelName == "" {
- modelName = strings.TrimSpace(request.Model)
- }
- if modelName == "" {
- modelName = ModelFlux11Pro
- }
- info.UpstreamModelName = modelName
- info.RequestURLPath = fmt.Sprintf("/v1/models/%s/predictions", modelName)
- inputPayload := make(map[string]any)
- inputPayload["prompt"] = request.Prompt
- if size := strings.TrimSpace(request.Size); size != "" {
- if aspect, width, height, ok := mapOpenAISizeToFlux(size); ok {
- if aspect != "" {
- if aspect == "custom" {
- inputPayload["aspect_ratio"] = "custom"
- if width > 0 {
- inputPayload["width"] = width
- }
- if height > 0 {
- inputPayload["height"] = height
- }
- } else {
- inputPayload["aspect_ratio"] = aspect
- }
- }
- }
- }
- if len(request.OutputFormat) > 0 {
- var outputFormat string
- if err := json.Unmarshal(request.OutputFormat, &outputFormat); err == nil && strings.TrimSpace(outputFormat) != "" {
- inputPayload["output_format"] = outputFormat
- }
- }
- if request.N > 0 {
- inputPayload["num_outputs"] = int(request.N)
- }
- if strings.EqualFold(request.Quality, "hd") || strings.EqualFold(request.Quality, "high") {
- inputPayload["prompt_upsampling"] = true
- }
- if info.RelayMode == relayconstant.RelayModeImagesEdits {
- imageURL, err := uploadFileFromForm(c, info, "image", "image[]", "image_prompt")
- if err != nil {
- return nil, err
- }
- if imageURL == "" {
- return nil, errors.New("replicate adaptor: image file is required for edits")
- }
- inputPayload["image_prompt"] = imageURL
- }
- if len(request.ExtraFields) > 0 {
- var extra map[string]any
- if err := common.Unmarshal(request.ExtraFields, &extra); err != nil {
- return nil, fmt.Errorf("replicate adaptor: failed to decode extra_fields: %w", err)
- }
- for key, val := range extra {
- inputPayload[key] = val
- }
- }
- for key, raw := range request.Extra {
- if strings.EqualFold(key, "input") {
- var extraInput map[string]any
- if err := common.Unmarshal(raw, &extraInput); err != nil {
- return nil, fmt.Errorf("replicate adaptor: failed to decode extra input: %w", err)
- }
- for k, v := range extraInput {
- inputPayload[k] = v
- }
- continue
- }
- if raw == nil {
- continue
- }
- var val any
- if err := common.Unmarshal(raw, &val); err != nil {
- return nil, fmt.Errorf("replicate adaptor: failed to decode extra field %s: %w", key, err)
- }
- inputPayload[key] = val
- }
- return map[string]any{
- "input": inputPayload,
- }, nil
- }
- func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
- return channel.DoApiRequest(a, c, info, requestBody)
- }
- func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (any, *types.NewAPIError) {
- if resp == nil {
- return nil, types.NewError(errors.New("replicate adaptor: empty response"), types.ErrorCodeBadResponse)
- }
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
- }
- _ = resp.Body.Close()
- var prediction PredictionResponse
- if err := common.Unmarshal(responseBody, &prediction); err != nil {
- return nil, types.NewError(fmt.Errorf("replicate adaptor: failed to decode response: %w", err), types.ErrorCodeBadResponseBody)
- }
- if prediction.Error != nil {
- errMsg := prediction.Error.Message
- if errMsg == "" {
- errMsg = prediction.Error.Detail
- }
- if errMsg == "" {
- errMsg = prediction.Error.Code
- }
- if errMsg == "" {
- errMsg = "replicate adaptor: prediction error"
- }
- return nil, types.NewError(errors.New(errMsg), types.ErrorCodeBadResponse)
- }
- if prediction.Status != "" && !strings.EqualFold(prediction.Status, "succeeded") {
- return nil, types.NewError(fmt.Errorf("replicate adaptor: prediction status %q", prediction.Status), types.ErrorCodeBadResponse)
- }
- var urls []string
- appendOutput := func(value string) {
- value = strings.TrimSpace(value)
- if value == "" {
- return
- }
- urls = append(urls, value)
- }
- switch output := prediction.Output.(type) {
- case string:
- appendOutput(output)
- case []any:
- for _, item := range output {
- if str, ok := item.(string); ok {
- appendOutput(str)
- }
- }
- case nil:
- // no output
- default:
- if str, ok := output.(fmt.Stringer); ok {
- appendOutput(str.String())
- }
- }
- if len(urls) == 0 {
- return nil, types.NewError(errors.New("replicate adaptor: empty prediction output"), types.ErrorCodeBadResponseBody)
- }
- var imageReq *dto.ImageRequest
- if info != nil {
- if req, ok := info.Request.(*dto.ImageRequest); ok {
- imageReq = req
- }
- }
- wantsBase64 := imageReq != nil && strings.EqualFold(imageReq.ResponseFormat, "b64_json")
- imageResponse := dto.ImageResponse{
- Created: common.GetTimestamp(),
- Data: make([]dto.ImageData, 0),
- }
- if wantsBase64 {
- converted, convErr := downloadImagesToBase64(urls)
- if convErr != nil {
- return nil, types.NewError(convErr, types.ErrorCodeBadResponse)
- }
- for _, content := range converted {
- if content == "" {
- continue
- }
- imageResponse.Data = append(imageResponse.Data, dto.ImageData{B64Json: content})
- }
- } else {
- for _, url := range urls {
- if url == "" {
- continue
- }
- imageResponse.Data = append(imageResponse.Data, dto.ImageData{Url: url})
- }
- }
- if len(imageResponse.Data) == 0 {
- return nil, types.NewError(errors.New("replicate adaptor: no usable image data"), types.ErrorCodeBadResponse)
- }
- responseBytes, err := common.Marshal(imageResponse)
- if err != nil {
- return nil, types.NewError(fmt.Errorf("replicate adaptor: encode response failed: %w", err), types.ErrorCodeBadResponseBody)
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.Writer.WriteHeader(http.StatusOK)
- _, _ = c.Writer.Write(responseBytes)
- usage := &dto.Usage{}
- return usage, nil
- }
- func (a *Adaptor) GetModelList() []string {
- return ModelList
- }
- func (a *Adaptor) GetChannelName() string {
- return ChannelName
- }
- func downloadImagesToBase64(urls []string) ([]string, error) {
- results := make([]string, 0, len(urls))
- for _, url := range urls {
- if strings.TrimSpace(url) == "" {
- continue
- }
- _, data, err := service.GetImageFromUrl(url)
- if err != nil {
- return nil, fmt.Errorf("replicate adaptor: failed to download image from %s: %w", url, err)
- }
- results = append(results, data)
- }
- return results, nil
- }
- func mapOpenAISizeToFlux(size string) (aspect string, width int, height int, ok bool) {
- parts := strings.Split(size, "x")
- if len(parts) != 2 {
- return "", 0, 0, false
- }
- w, err1 := strconv.Atoi(strings.TrimSpace(parts[0]))
- h, err2 := strconv.Atoi(strings.TrimSpace(parts[1]))
- if err1 != nil || err2 != nil || w <= 0 || h <= 0 {
- return "", 0, 0, false
- }
- switch {
- case w == h:
- return "1:1", 0, 0, true
- case w == 1792 && h == 1024:
- return "16:9", 0, 0, true
- case w == 1024 && h == 1792:
- return "9:16", 0, 0, true
- case w == 1536 && h == 1024:
- return "3:2", 0, 0, true
- case w == 1024 && h == 1536:
- return "2:3", 0, 0, true
- }
- rw, rh := reduceRatio(w, h)
- ratioStr := fmt.Sprintf("%d:%d", rw, rh)
- switch ratioStr {
- case "1:1", "16:9", "9:16", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3":
- return ratioStr, 0, 0, true
- }
- width = normalizeFluxDimension(w)
- height = normalizeFluxDimension(h)
- return "custom", width, height, true
- }
- func reduceRatio(w, h int) (int, int) {
- g := gcd(w, h)
- if g == 0 {
- return w, h
- }
- return w / g, h / g
- }
// gcd returns the non-negative greatest common divisor of a and b using the
// Euclidean algorithm. gcd(0, 0) is 0.
func gcd(a, b int) int {
	for {
		if b == 0 {
			break
		}
		a, b = b, a%b
	}
	// Go's % keeps the sign of the dividend, so the result can be negative.
	if a >= 0 {
		return a
	}
	return -a
}
// normalizeFluxDimension clamps value to the range [256, 1440] and rounds it
// to the nearest multiple of 32, with exact halves rounding up. The result is
// clamped once more after rounding.
func normalizeFluxDimension(value int) int {
	const (
		lowest  = 256
		highest = 1440
		grid    = 32
	)

	clamp := func(v int) int {
		switch {
		case v < lowest:
			return lowest
		case v > highest:
			return highest
		}
		return v
	}

	snapped := clamp(value)
	if offset := snapped % grid; offset != 0 {
		// Drop to the grid line below, then step up when at least halfway.
		snapped -= offset
		if offset >= grid/2 {
			snapped += grid
		}
	}
	return clamp(snapped)
}
- func uploadFileFromForm(c *gin.Context, info *relaycommon.RelayInfo, fieldCandidates ...string) (string, error) {
- if info == nil {
- return "", errors.New("replicate adaptor: relay info is nil")
- }
- mf := c.Request.MultipartForm
- if mf == nil {
- if _, err := c.MultipartForm(); err != nil {
- return "", fmt.Errorf("replicate adaptor: parse multipart form failed: %w", err)
- }
- mf = c.Request.MultipartForm
- }
- if mf == nil || len(mf.File) == 0 {
- return "", nil
- }
- if len(fieldCandidates) == 0 {
- fieldCandidates = []string{"image", "image[]", "image_prompt"}
- }
- var fileHeader *multipart.FileHeader
- for _, key := range fieldCandidates {
- if files := mf.File[key]; len(files) > 0 {
- fileHeader = files[0]
- break
- }
- }
- if fileHeader == nil {
- for _, files := range mf.File {
- if len(files) > 0 {
- fileHeader = files[0]
- break
- }
- }
- }
- if fileHeader == nil {
- return "", nil
- }
- file, err := fileHeader.Open()
- if err != nil {
- return "", fmt.Errorf("replicate adaptor: failed to open image file: %w", err)
- }
- defer file.Close()
- var body bytes.Buffer
- writer := multipart.NewWriter(&body)
- hdr := make(textproto.MIMEHeader)
- hdr.Set("Content-Disposition", fmt.Sprintf("form-data; name=\"content\"; filename=\"%s\"", fileHeader.Filename))
- contentType := fileHeader.Header.Get("Content-Type")
- if contentType == "" {
- contentType = "application/octet-stream"
- }
- hdr.Set("Content-Type", contentType)
- part, err := writer.CreatePart(hdr)
- if err != nil {
- writer.Close()
- return "", fmt.Errorf("replicate adaptor: create upload form failed: %w", err)
- }
- if _, err := io.Copy(part, file); err != nil {
- writer.Close()
- return "", fmt.Errorf("replicate adaptor: copy image content failed: %w", err)
- }
- formContentType := writer.FormDataContentType()
- writer.Close()
- baseURL := info.ChannelBaseUrl
- if baseURL == "" {
- baseURL = constant.ChannelBaseURLs[constant.ChannelTypeReplicate]
- }
- uploadURL := relaycommon.GetFullRequestURL(baseURL, "/v1/files", info.ChannelType)
- req, err := http.NewRequest(http.MethodPost, uploadURL, &body)
- if err != nil {
- return "", fmt.Errorf("replicate adaptor: create upload request failed: %w", err)
- }
- req.Header.Set("Content-Type", formContentType)
- req.Header.Set("Authorization", "Bearer "+info.ApiKey)
- resp, err := service.GetHttpClient().Do(req)
- if err != nil {
- return "", fmt.Errorf("replicate adaptor: upload image failed: %w", err)
- }
- defer resp.Body.Close()
- respBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return "", fmt.Errorf("replicate adaptor: read upload response failed: %w", err)
- }
- if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
- return "", fmt.Errorf("replicate adaptor: upload image failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
- }
- var uploadResp FileUploadResponse
- if err := common.Unmarshal(respBody, &uploadResp); err != nil {
- return "", fmt.Errorf("replicate adaptor: decode upload response failed: %w", err)
- }
- if uploadResp.Urls.Get == "" {
- return "", errors.New("replicate adaptor: upload response missing url")
- }
- return uploadResp.Urls.Get, nil
- }
- func (a *Adaptor) ConvertOpenAIRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeneralOpenAIRequest) (any, error) {
- return nil, errors.New("replicate adaptor: ConvertOpenAIRequest is not implemented")
- }
- func (a *Adaptor) ConvertRerankRequest(*gin.Context, int, dto.RerankRequest) (any, error) {
- return nil, errors.New("replicate adaptor: ConvertRerankRequest is not implemented")
- }
- func (a *Adaptor) ConvertEmbeddingRequest(*gin.Context, *relaycommon.RelayInfo, dto.EmbeddingRequest) (any, error) {
- return nil, errors.New("replicate adaptor: ConvertEmbeddingRequest is not implemented")
- }
- func (a *Adaptor) ConvertAudioRequest(*gin.Context, *relaycommon.RelayInfo, dto.AudioRequest) (io.Reader, error) {
- return nil, errors.New("replicate adaptor: ConvertAudioRequest is not implemented")
- }
- func (a *Adaptor) ConvertOpenAIResponsesRequest(*gin.Context, *relaycommon.RelayInfo, dto.OpenAIResponsesRequest) (any, error) {
- return nil, errors.New("replicate adaptor: ConvertOpenAIResponsesRequest is not implemented")
- }
- func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
- return nil, errors.New("replicate adaptor: ConvertClaudeRequest is not implemented")
- }
- func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
- return nil, errors.New("replicate adaptor: ConvertGeminiRequest is not implemented")
- }
|