adaptor.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. package replicate
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "mime/multipart"
  9. "net/http"
  10. "net/textproto"
  11. "strconv"
  12. "strings"
  13. "github.com/QuantumNous/new-api/common"
  14. "github.com/QuantumNous/new-api/constant"
  15. "github.com/QuantumNous/new-api/dto"
  16. "github.com/QuantumNous/new-api/relay/channel"
  17. relaycommon "github.com/QuantumNous/new-api/relay/common"
  18. relayconstant "github.com/QuantumNous/new-api/relay/constant"
  19. "github.com/QuantumNous/new-api/service"
  20. "github.com/QuantumNous/new-api/types"
  21. "github.com/gin-gonic/gin"
  22. )
  23. type Adaptor struct {
  24. }
  25. func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
  26. }
  27. func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
  28. if info == nil {
  29. return "", errors.New("replicate adaptor: relay info is nil")
  30. }
  31. if info.ChannelBaseUrl == "" {
  32. info.ChannelBaseUrl = constant.ChannelBaseURLs[constant.ChannelTypeReplicate]
  33. }
  34. requestPath := info.RequestURLPath
  35. if requestPath == "" {
  36. return info.ChannelBaseUrl, nil
  37. }
  38. return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestPath, info.ChannelType), nil
  39. }
  40. func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
  41. if info == nil {
  42. return errors.New("replicate adaptor: relay info is nil")
  43. }
  44. if info.ApiKey == "" {
  45. return errors.New("replicate adaptor: api key is required")
  46. }
  47. channel.SetupApiRequestHeader(info, c, req)
  48. req.Set("Authorization", "Bearer "+info.ApiKey)
  49. req.Set("Prefer", "wait")
  50. if req.Get("Content-Type") == "" {
  51. req.Set("Content-Type", "application/json")
  52. }
  53. if req.Get("Accept") == "" {
  54. req.Set("Accept", "application/json")
  55. }
  56. return nil
  57. }
  58. func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
  59. if info == nil {
  60. return nil, errors.New("replicate adaptor: relay info is nil")
  61. }
  62. if strings.TrimSpace(request.Prompt) == "" {
  63. if v := c.PostForm("prompt"); strings.TrimSpace(v) != "" {
  64. request.Prompt = v
  65. }
  66. }
  67. if strings.TrimSpace(request.Prompt) == "" {
  68. return nil, errors.New("replicate adaptor: prompt is required")
  69. }
  70. modelName := strings.TrimSpace(info.UpstreamModelName)
  71. if modelName == "" {
  72. modelName = strings.TrimSpace(request.Model)
  73. }
  74. if modelName == "" {
  75. modelName = ModelFlux11Pro
  76. }
  77. info.UpstreamModelName = modelName
  78. info.RequestURLPath = fmt.Sprintf("/v1/models/%s/predictions", modelName)
  79. inputPayload := make(map[string]any)
  80. inputPayload["prompt"] = request.Prompt
  81. if size := strings.TrimSpace(request.Size); size != "" {
  82. if aspect, width, height, ok := mapOpenAISizeToFlux(size); ok {
  83. if aspect != "" {
  84. if aspect == "custom" {
  85. inputPayload["aspect_ratio"] = "custom"
  86. if width > 0 {
  87. inputPayload["width"] = width
  88. }
  89. if height > 0 {
  90. inputPayload["height"] = height
  91. }
  92. } else {
  93. inputPayload["aspect_ratio"] = aspect
  94. }
  95. }
  96. }
  97. }
  98. if len(request.OutputFormat) > 0 {
  99. var outputFormat string
  100. if err := json.Unmarshal(request.OutputFormat, &outputFormat); err == nil && strings.TrimSpace(outputFormat) != "" {
  101. inputPayload["output_format"] = outputFormat
  102. }
  103. }
  104. if request.N > 0 {
  105. inputPayload["num_outputs"] = int(request.N)
  106. }
  107. if strings.EqualFold(request.Quality, "hd") || strings.EqualFold(request.Quality, "high") {
  108. inputPayload["prompt_upsampling"] = true
  109. }
  110. if info.RelayMode == relayconstant.RelayModeImagesEdits {
  111. imageURL, err := uploadFileFromForm(c, info, "image", "image[]", "image_prompt")
  112. if err != nil {
  113. return nil, err
  114. }
  115. if imageURL == "" {
  116. return nil, errors.New("replicate adaptor: image file is required for edits")
  117. }
  118. inputPayload["image_prompt"] = imageURL
  119. }
  120. if len(request.ExtraFields) > 0 {
  121. var extra map[string]any
  122. if err := common.Unmarshal(request.ExtraFields, &extra); err != nil {
  123. return nil, fmt.Errorf("replicate adaptor: failed to decode extra_fields: %w", err)
  124. }
  125. for key, val := range extra {
  126. inputPayload[key] = val
  127. }
  128. }
  129. for key, raw := range request.Extra {
  130. if strings.EqualFold(key, "input") {
  131. var extraInput map[string]any
  132. if err := common.Unmarshal(raw, &extraInput); err != nil {
  133. return nil, fmt.Errorf("replicate adaptor: failed to decode extra input: %w", err)
  134. }
  135. for k, v := range extraInput {
  136. inputPayload[k] = v
  137. }
  138. continue
  139. }
  140. if raw == nil {
  141. continue
  142. }
  143. var val any
  144. if err := common.Unmarshal(raw, &val); err != nil {
  145. return nil, fmt.Errorf("replicate adaptor: failed to decode extra field %s: %w", key, err)
  146. }
  147. inputPayload[key] = val
  148. }
  149. return map[string]any{
  150. "input": inputPayload,
  151. }, nil
  152. }
  153. func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
  154. return channel.DoApiRequest(a, c, info, requestBody)
  155. }
  156. func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (any, *types.NewAPIError) {
  157. if resp == nil {
  158. return nil, types.NewError(errors.New("replicate adaptor: empty response"), types.ErrorCodeBadResponse)
  159. }
  160. responseBody, err := io.ReadAll(resp.Body)
  161. if err != nil {
  162. return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
  163. }
  164. _ = resp.Body.Close()
  165. var prediction PredictionResponse
  166. if err := common.Unmarshal(responseBody, &prediction); err != nil {
  167. return nil, types.NewError(fmt.Errorf("replicate adaptor: failed to decode response: %w", err), types.ErrorCodeBadResponseBody)
  168. }
  169. if prediction.Error != nil {
  170. errMsg := prediction.Error.Message
  171. if errMsg == "" {
  172. errMsg = prediction.Error.Detail
  173. }
  174. if errMsg == "" {
  175. errMsg = prediction.Error.Code
  176. }
  177. if errMsg == "" {
  178. errMsg = "replicate adaptor: prediction error"
  179. }
  180. return nil, types.NewError(errors.New(errMsg), types.ErrorCodeBadResponse)
  181. }
  182. if prediction.Status != "" && !strings.EqualFold(prediction.Status, "succeeded") {
  183. return nil, types.NewError(fmt.Errorf("replicate adaptor: prediction status %q", prediction.Status), types.ErrorCodeBadResponse)
  184. }
  185. var urls []string
  186. appendOutput := func(value string) {
  187. value = strings.TrimSpace(value)
  188. if value == "" {
  189. return
  190. }
  191. urls = append(urls, value)
  192. }
  193. switch output := prediction.Output.(type) {
  194. case string:
  195. appendOutput(output)
  196. case []any:
  197. for _, item := range output {
  198. if str, ok := item.(string); ok {
  199. appendOutput(str)
  200. }
  201. }
  202. case nil:
  203. // no output
  204. default:
  205. if str, ok := output.(fmt.Stringer); ok {
  206. appendOutput(str.String())
  207. }
  208. }
  209. if len(urls) == 0 {
  210. return nil, types.NewError(errors.New("replicate adaptor: empty prediction output"), types.ErrorCodeBadResponseBody)
  211. }
  212. var imageReq *dto.ImageRequest
  213. if info != nil {
  214. if req, ok := info.Request.(*dto.ImageRequest); ok {
  215. imageReq = req
  216. }
  217. }
  218. wantsBase64 := imageReq != nil && strings.EqualFold(imageReq.ResponseFormat, "b64_json")
  219. imageResponse := dto.ImageResponse{
  220. Created: common.GetTimestamp(),
  221. Data: make([]dto.ImageData, 0),
  222. }
  223. if wantsBase64 {
  224. converted, convErr := downloadImagesToBase64(urls)
  225. if convErr != nil {
  226. return nil, types.NewError(convErr, types.ErrorCodeBadResponse)
  227. }
  228. for _, content := range converted {
  229. if content == "" {
  230. continue
  231. }
  232. imageResponse.Data = append(imageResponse.Data, dto.ImageData{B64Json: content})
  233. }
  234. } else {
  235. for _, url := range urls {
  236. if url == "" {
  237. continue
  238. }
  239. imageResponse.Data = append(imageResponse.Data, dto.ImageData{Url: url})
  240. }
  241. }
  242. if len(imageResponse.Data) == 0 {
  243. return nil, types.NewError(errors.New("replicate adaptor: no usable image data"), types.ErrorCodeBadResponse)
  244. }
  245. responseBytes, err := common.Marshal(imageResponse)
  246. if err != nil {
  247. return nil, types.NewError(fmt.Errorf("replicate adaptor: encode response failed: %w", err), types.ErrorCodeBadResponseBody)
  248. }
  249. c.Writer.Header().Set("Content-Type", "application/json")
  250. c.Writer.WriteHeader(http.StatusOK)
  251. _, _ = c.Writer.Write(responseBytes)
  252. usage := &dto.Usage{}
  253. return usage, nil
  254. }
  255. func (a *Adaptor) GetModelList() []string {
  256. return ModelList
  257. }
  258. func (a *Adaptor) GetChannelName() string {
  259. return ChannelName
  260. }
  261. func downloadImagesToBase64(urls []string) ([]string, error) {
  262. results := make([]string, 0, len(urls))
  263. for _, url := range urls {
  264. if strings.TrimSpace(url) == "" {
  265. continue
  266. }
  267. _, data, err := service.GetImageFromUrl(url)
  268. if err != nil {
  269. return nil, fmt.Errorf("replicate adaptor: failed to download image from %s: %w", url, err)
  270. }
  271. results = append(results, data)
  272. }
  273. return results, nil
  274. }
  275. func mapOpenAISizeToFlux(size string) (aspect string, width int, height int, ok bool) {
  276. parts := strings.Split(size, "x")
  277. if len(parts) != 2 {
  278. return "", 0, 0, false
  279. }
  280. w, err1 := strconv.Atoi(strings.TrimSpace(parts[0]))
  281. h, err2 := strconv.Atoi(strings.TrimSpace(parts[1]))
  282. if err1 != nil || err2 != nil || w <= 0 || h <= 0 {
  283. return "", 0, 0, false
  284. }
  285. switch {
  286. case w == h:
  287. return "1:1", 0, 0, true
  288. case w == 1792 && h == 1024:
  289. return "16:9", 0, 0, true
  290. case w == 1024 && h == 1792:
  291. return "9:16", 0, 0, true
  292. case w == 1536 && h == 1024:
  293. return "3:2", 0, 0, true
  294. case w == 1024 && h == 1536:
  295. return "2:3", 0, 0, true
  296. }
  297. rw, rh := reduceRatio(w, h)
  298. ratioStr := fmt.Sprintf("%d:%d", rw, rh)
  299. switch ratioStr {
  300. case "1:1", "16:9", "9:16", "3:2", "2:3", "4:5", "5:4", "3:4", "4:3":
  301. return ratioStr, 0, 0, true
  302. }
  303. width = normalizeFluxDimension(w)
  304. height = normalizeFluxDimension(h)
  305. return "custom", width, height, true
  306. }
  307. func reduceRatio(w, h int) (int, int) {
  308. g := gcd(w, h)
  309. if g == 0 {
  310. return w, h
  311. }
  312. return w / g, h / g
  313. }
  314. func gcd(a, b int) int {
  315. for b != 0 {
  316. a, b = b, a%b
  317. }
  318. if a < 0 {
  319. return -a
  320. }
  321. return a
  322. }
  323. func normalizeFluxDimension(value int) int {
  324. const (
  325. minDim = 256
  326. maxDim = 1440
  327. step = 32
  328. )
  329. if value < minDim {
  330. value = minDim
  331. }
  332. if value > maxDim {
  333. value = maxDim
  334. }
  335. remainder := value % step
  336. if remainder != 0 {
  337. if remainder >= step/2 {
  338. value += step - remainder
  339. } else {
  340. value -= remainder
  341. }
  342. }
  343. if value < minDim {
  344. value = minDim
  345. }
  346. if value > maxDim {
  347. value = maxDim
  348. }
  349. return value
  350. }
  351. func uploadFileFromForm(c *gin.Context, info *relaycommon.RelayInfo, fieldCandidates ...string) (string, error) {
  352. if info == nil {
  353. return "", errors.New("replicate adaptor: relay info is nil")
  354. }
  355. mf := c.Request.MultipartForm
  356. if mf == nil {
  357. if _, err := c.MultipartForm(); err != nil {
  358. return "", fmt.Errorf("replicate adaptor: parse multipart form failed: %w", err)
  359. }
  360. mf = c.Request.MultipartForm
  361. }
  362. if mf == nil || len(mf.File) == 0 {
  363. return "", nil
  364. }
  365. if len(fieldCandidates) == 0 {
  366. fieldCandidates = []string{"image", "image[]", "image_prompt"}
  367. }
  368. var fileHeader *multipart.FileHeader
  369. for _, key := range fieldCandidates {
  370. if files := mf.File[key]; len(files) > 0 {
  371. fileHeader = files[0]
  372. break
  373. }
  374. }
  375. if fileHeader == nil {
  376. for _, files := range mf.File {
  377. if len(files) > 0 {
  378. fileHeader = files[0]
  379. break
  380. }
  381. }
  382. }
  383. if fileHeader == nil {
  384. return "", nil
  385. }
  386. file, err := fileHeader.Open()
  387. if err != nil {
  388. return "", fmt.Errorf("replicate adaptor: failed to open image file: %w", err)
  389. }
  390. defer file.Close()
  391. var body bytes.Buffer
  392. writer := multipart.NewWriter(&body)
  393. hdr := make(textproto.MIMEHeader)
  394. hdr.Set("Content-Disposition", fmt.Sprintf("form-data; name=\"content\"; filename=\"%s\"", fileHeader.Filename))
  395. contentType := fileHeader.Header.Get("Content-Type")
  396. if contentType == "" {
  397. contentType = "application/octet-stream"
  398. }
  399. hdr.Set("Content-Type", contentType)
  400. part, err := writer.CreatePart(hdr)
  401. if err != nil {
  402. writer.Close()
  403. return "", fmt.Errorf("replicate adaptor: create upload form failed: %w", err)
  404. }
  405. if _, err := io.Copy(part, file); err != nil {
  406. writer.Close()
  407. return "", fmt.Errorf("replicate adaptor: copy image content failed: %w", err)
  408. }
  409. formContentType := writer.FormDataContentType()
  410. writer.Close()
  411. baseURL := info.ChannelBaseUrl
  412. if baseURL == "" {
  413. baseURL = constant.ChannelBaseURLs[constant.ChannelTypeReplicate]
  414. }
  415. uploadURL := relaycommon.GetFullRequestURL(baseURL, "/v1/files", info.ChannelType)
  416. req, err := http.NewRequest(http.MethodPost, uploadURL, &body)
  417. if err != nil {
  418. return "", fmt.Errorf("replicate adaptor: create upload request failed: %w", err)
  419. }
  420. req.Header.Set("Content-Type", formContentType)
  421. req.Header.Set("Authorization", "Bearer "+info.ApiKey)
  422. resp, err := service.GetHttpClient().Do(req)
  423. if err != nil {
  424. return "", fmt.Errorf("replicate adaptor: upload image failed: %w", err)
  425. }
  426. defer resp.Body.Close()
  427. respBody, err := io.ReadAll(resp.Body)
  428. if err != nil {
  429. return "", fmt.Errorf("replicate adaptor: read upload response failed: %w", err)
  430. }
  431. if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
  432. return "", fmt.Errorf("replicate adaptor: upload image failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
  433. }
  434. var uploadResp FileUploadResponse
  435. if err := common.Unmarshal(respBody, &uploadResp); err != nil {
  436. return "", fmt.Errorf("replicate adaptor: decode upload response failed: %w", err)
  437. }
  438. if uploadResp.Urls.Get == "" {
  439. return "", errors.New("replicate adaptor: upload response missing url")
  440. }
  441. return uploadResp.Urls.Get, nil
  442. }
  443. func (a *Adaptor) ConvertOpenAIRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeneralOpenAIRequest) (any, error) {
  444. return nil, errors.New("replicate adaptor: ConvertOpenAIRequest is not implemented")
  445. }
  446. func (a *Adaptor) ConvertRerankRequest(*gin.Context, int, dto.RerankRequest) (any, error) {
  447. return nil, errors.New("replicate adaptor: ConvertRerankRequest is not implemented")
  448. }
  449. func (a *Adaptor) ConvertEmbeddingRequest(*gin.Context, *relaycommon.RelayInfo, dto.EmbeddingRequest) (any, error) {
  450. return nil, errors.New("replicate adaptor: ConvertEmbeddingRequest is not implemented")
  451. }
  452. func (a *Adaptor) ConvertAudioRequest(*gin.Context, *relaycommon.RelayInfo, dto.AudioRequest) (io.Reader, error) {
  453. return nil, errors.New("replicate adaptor: ConvertAudioRequest is not implemented")
  454. }
  455. func (a *Adaptor) ConvertOpenAIResponsesRequest(*gin.Context, *relaycommon.RelayInfo, dto.OpenAIResponsesRequest) (any, error) {
  456. return nil, errors.New("replicate adaptor: ConvertOpenAIResponsesRequest is not implemented")
  457. }
  458. func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
  459. return nil, errors.New("replicate adaptor: ConvertClaudeRequest is not implemented")
  460. }
  461. func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
  462. return nil, errors.New("replicate adaptor: ConvertGeminiRequest is not implemented")
  463. }