Matt Rubens 2 месяцев назад
Родитель
Commit
e4847ed8d9

+ 7 - 0
packages/types/src/image-generation.ts

@@ -2,10 +2,16 @@
  * Image generation model constants
  */
 
+/**
+ * API method used for image generation.
+ *
+ * - "chat_completions": the image is requested through the chat completions endpoint
+ *   (`/chat/completions`).
+ * - "images_api": the image is requested through the images endpoint
+ *   (`/images/generations`).
+ */
+export type ImageGenerationApiMethod = "chat_completions" | "images_api"
+
 export interface ImageGenerationModel {
 	value: string // model identifier, e.g. "google/gemini-2.5-flash-image"
 	label: string // human-readable model name, e.g. "Gemini 2.5 Flash Image"
 	provider: ImageGenerationProvider
+	apiMethod?: ImageGenerationApiMethod // optional; when omitted, RooHandler.generateImage falls back to chat completions
 }
 
 export const IMAGE_GENERATION_MODELS: ImageGenerationModel[] = [
@@ -17,6 +23,7 @@ export const IMAGE_GENERATION_MODELS: ImageGenerationModel[] = [
 	// Roo Code Cloud models
 	{ value: "google/gemini-2.5-flash-image", label: "Gemini 2.5 Flash Image", provider: "roo" },
 	{ value: "google/gemini-3-pro-image", label: "Gemini 3 Pro Image", provider: "roo" },
+	{ value: "bfl/flux-2-pro", label: "BFL Flux 2 Pro", provider: "roo", apiMethod: "images_api" },
 ]
 
 /**

+ 4 - 1
src/api/providers/openrouter.ts

@@ -435,7 +435,8 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
 	}
 
 	/**
-	 * Generate an image using OpenRouter's image generation API
+	 * Generate an image using OpenRouter's image generation API (chat completions with modalities)
+	 * Note: OpenRouter only supports the chat completions approach, not the /images/generations endpoint
 	 * @param prompt The text prompt for image generation
 	 * @param model The model to use for generation
 	 * @param apiKey The OpenRouter API key (must be explicitly provided)
@@ -456,6 +457,8 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
 		}
 
 		const baseURL = this.options.openRouterBaseUrl || "https://openrouter.ai/api/v1"
+
+		// OpenRouter only supports chat completions approach for image generation
 		return generateImageWithProvider({
 			baseURL,
 			authToken: apiKey,

+ 25 - 4
src/api/providers/roo.ts

@@ -1,7 +1,7 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
 
-import { rooDefaultModelId, getApiProtocol } from "@roo-code/types"
+import { rooDefaultModelId, getApiProtocol, type ImageGenerationApiMethod } from "@roo-code/types"
 import { CloudService } from "@roo-code/cloud"
 
 import type { ApiHandlerOptions, ModelRecord } from "../../shared/api"
@@ -15,7 +15,7 @@ import type { ApiHandlerCreateMessageMetadata } from "../index"
 import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"
 import { getModels, getModelsFromCache } from "../providers/fetchers/modelCache"
 import { handleOpenAIError } from "./utils/openai-error-handler"
-import { generateImageWithProvider, ImageGenerationResult } from "./utils/image-generation"
+import { generateImageWithProvider, generateImageWithImagesApi, ImageGenerationResult } from "./utils/image-generation"
 import { t } from "../../i18n"
 
 // Extend OpenAI's CompletionUsage to include Roo specific fields
@@ -273,9 +273,15 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
 	 * @param prompt The text prompt for image generation
 	 * @param model The model to use for generation
 	 * @param inputImage Optional base64 encoded input image data URL
+	 * @param apiMethod The API method to use (chat_completions or images_api)
 	 * @returns The generated image data and format, or an error
 	 */
-	async generateImage(prompt: string, model: string, inputImage?: string): Promise<ImageGenerationResult> {
+	async generateImage(
+		prompt: string,
+		model: string,
+		inputImage?: string,
+		apiMethod?: ImageGenerationApiMethod,
+	): Promise<ImageGenerationResult> {
 		const sessionToken = getSessionToken()
 
 		if (!sessionToken || sessionToken === "unauthenticated") {
@@ -285,8 +291,23 @@ export class RooHandler extends BaseOpenAiCompatibleProvider<string> {
 			}
 		}
 
+		const baseURL = `${this.fetcherBaseURL}/v1`
+
+		// Use the specified API method, defaulting to chat_completions for backward compatibility
+		if (apiMethod === "images_api") {
+			return generateImageWithImagesApi({
+				baseURL,
+				authToken: sessionToken,
+				model,
+				prompt,
+				inputImage,
+				outputFormat: "png",
+			})
+		}
+
+		// Default to chat completions approach
 		return generateImageWithProvider({
-			baseURL: `${this.fetcherBaseURL}/v1`,
+			baseURL,
 			authToken: sessionToken,
 			model,
 			prompt,

+ 417 - 0
src/api/providers/utils/__tests__/image-generation.spec.ts

@@ -0,0 +1,417 @@
+import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"
+import { generateImageWithImagesApi, generateImageWithProvider } from "../image-generation"
+
+// Mock the i18n module
+vi.mock("../../../i18n", () => ({
+	t: (key: string, options?: any) => {
+		// Return a sensible mock for i18n
+		if (key === "tools:generateImage.failedWithMessage" && options?.message) {
+			return options.message
+		}
+		return key
+	},
+}))
+
+// Mock fetch globally
+global.fetch = vi.fn()
+global.FormData = vi.fn(() => ({
+	append: vi.fn(),
+})) as any
+global.Blob = vi.fn() as any
+global.atob = vi.fn((str: string) => {
+	return Buffer.from(str, "base64").toString("binary")
+})
+
+describe("generateImageWithImagesApi", () => {
+	beforeEach(() => {
+		vi.clearAllMocks()
+	})
+
+	afterEach(() => {
+		vi.clearAllMocks()
+	})
+
+	describe("image generation (text-to-image)", () => {
+		it("should successfully generate an image", async () => {
+			const mockBase64 = Buffer.from("fake image data").toString("base64")
+			const mockResponse = {
+				ok: true,
+				json: vi.fn().mockResolvedValue({
+					data: [{ b64_json: mockBase64 }],
+				}),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "A cute cat",
+				outputFormat: "png",
+			})
+
+			expect(result.success).toBe(true)
+			expect(result.imageData).toContain("data:image/png;base64,")
+			expect(result.imageFormat).toBe("png")
+
+			// Verify fetch was called with correct parameters
+			expect(global.fetch).toHaveBeenCalledWith(
+				"https://api.example.com/v1/images/generations",
+				expect.objectContaining({
+					method: "POST",
+					headers: expect.objectContaining({
+						Authorization: "Bearer test-token",
+						"Content-Type": "application/json",
+					}),
+				}),
+			)
+		})
+
+		it("should handle API errors gracefully", async () => {
+			const mockResponse = {
+				ok: false,
+				status: 400,
+				statusText: "Bad Request",
+				text: vi.fn().mockResolvedValue("{}"),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "A cute cat",
+			})
+
+			expect(result.success).toBe(false)
+			expect(result.error).toBeDefined()
+		})
+
+		it("should handle missing image data in response", async () => {
+			const mockResponse = {
+				ok: true,
+				json: vi.fn().mockResolvedValue({
+					data: [{}], // Missing b64_json and url
+				}),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "A cute cat",
+			})
+
+			expect(result.success).toBe(false)
+			expect(result.error).toBeDefined()
+		})
+
+		it("should handle URL response instead of b64_json", async () => {
+			// The API may return the image as a data URL in `url` rather than in `b64_json`.
+			// An empty string would be treated as "no image", so a real data URL is required here.
+			const mockDataUrl = `data:image/png;base64,${Buffer.from("fake image data").toString("base64")}`
+			const mockResponse = {
+				ok: true,
+				json: vi.fn().mockResolvedValue({
+					data: [{ url: mockDataUrl }],
+				}),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "A cute cat",
+			})
+
+			// A data URL must be passed through unchanged, with the format read from its MIME type
+			expect(result.success).toBe(true)
+			expect(result.imageData).toBe(mockDataUrl)
+			expect(result.imageFormat).toBe("png")
+		})
+
+		it("should handle external URL response", async () => {
+			const mockResponse = {
+				ok: true,
+				json: vi.fn().mockResolvedValue({
+					data: [{ url: "https://example.com/generated-image.png" }],
+				}),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "A cute cat",
+				outputFormat: "png",
+			})
+
+			expect(result.success).toBe(true)
+			expect(result.imageData).toBe("https://example.com/generated-image.png")
+			expect(result.imageFormat).toBe("png")
+		})
+
+		it("should handle empty data array in response", async () => {
+			const mockResponse = {
+				ok: true,
+				json: vi.fn().mockResolvedValue({
+					data: [],
+				}),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "A cute cat",
+			})
+
+			expect(result.success).toBe(false)
+			expect(result.error).toBeDefined()
+		})
+
+		it("should handle API error response", async () => {
+			const mockResponse = {
+				ok: true,
+				json: vi.fn().mockResolvedValue({
+					error: {
+						message: "Rate limit exceeded",
+						type: "rate_limit_error",
+					},
+				}),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "A cute cat",
+			})
+
+			expect(result.success).toBe(false)
+			expect(result.error).toBeDefined()
+		})
+
+		it("should include optional parameters when provided", async () => {
+			const mockBase64 = Buffer.from("fake image data").toString("base64")
+			const mockResponse = {
+				ok: true,
+				json: vi.fn().mockResolvedValue({
+					data: [{ b64_json: mockBase64 }],
+				}),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "A cute cat",
+				size: "1024x1024",
+				quality: "hd",
+				outputFormat: "png",
+			})
+
+			expect(result.success).toBe(true)
+
+			// Verify fetch was called with optional parameters
+			const callArgs = vi.mocked(global.fetch).mock.calls[0]
+			const body = JSON.parse(callArgs[1]?.body as string)
+			expect(body.size).toBe("1024x1024")
+			expect(body.quality).toBe("hd")
+		})
+
+		it("should handle network errors", async () => {
+			vi.mocked(global.fetch).mockRejectedValue(new Error("Network error"))
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "A cute cat",
+			})
+
+			expect(result.success).toBe(false)
+			expect(result.error).toContain("Network error")
+		})
+	})
+
+	describe("image editing", () => {
+		it("should use /images/generations endpoint with inputImage in request body", async () => {
+			const mockBase64 = Buffer.from("fake image data").toString("base64")
+			const mockResponse = {
+				ok: true,
+				json: vi.fn().mockResolvedValue({
+					data: [{ b64_json: mockBase64 }],
+				}),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			const inputImageDataUrl = `data:image/png;base64,${mockBase64}`
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "Make it blue",
+				inputImage: inputImageDataUrl,
+				outputFormat: "png",
+			})
+
+			expect(result.success).toBe(true)
+
+			// Verify /images/generations endpoint was used (not /images/edits)
+			const callUrl = vi.mocked(global.fetch).mock.calls[0][0]
+			expect(callUrl).toContain("/images/generations")
+		})
+
+		it("should handle edit operation errors", async () => {
+			// Simulate an HTTP 400 whose body carries no usable error message
+			const mockResponse = {
+				ok: false,
+				status: 400,
+				statusText: "Bad Request",
+				text: vi.fn().mockResolvedValue("{}"),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			// Non-empty data URL: an empty string would be indistinguishable from omitting inputImage
+			const inputImageDataUrl = `data:image/png;base64,${Buffer.from("fake image data").toString("base64")}`
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "Make it blue",
+				inputImage: inputImageDataUrl,
+			})
+
+			expect(result.success).toBe(false)
+			expect(result.error).toBeDefined()
+		})
+	})
+
+	describe("output format handling", () => {
+		it("should use png format by default", async () => {
+			const mockBase64 = Buffer.from("fake image data").toString("base64")
+			const mockResponse = {
+				ok: true,
+				json: vi.fn().mockResolvedValue({
+					data: [{ b64_json: mockBase64 }],
+				}),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "A cute cat",
+			})
+
+			expect(result.imageFormat).toBe("png")
+			expect(result.imageData).toContain("data:image/png;base64,")
+		})
+
+		it("should use specified output format", async () => {
+			const mockBase64 = Buffer.from("fake image data").toString("base64")
+			const mockResponse = {
+				ok: true,
+				json: vi.fn().mockResolvedValue({
+					data: [{ b64_json: mockBase64 }],
+				}),
+			}
+
+			vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+			const result = await generateImageWithImagesApi({
+				baseURL: "https://api.example.com/v1",
+				authToken: "test-token",
+				model: "gpt-image-1",
+				prompt: "A cute cat",
+				outputFormat: "jpeg",
+			})
+
+			expect(result.imageFormat).toBe("jpeg")
+			expect(result.imageData).toContain("data:image/jpeg;base64,")
+		})
+	})
+})
+
+describe("generateImageWithProvider (chat completions)", () => {
+	beforeEach(() => {
+		vi.clearAllMocks()
+	})
+
+	afterEach(() => {
+		vi.clearAllMocks()
+	})
+
+	it("should use /chat/completions endpoint", async () => {
+		// The mocked assistant message carries the generated image as a non-empty data URL;
+		// an empty URL would not represent a successfully generated image.
+		const mockDataUrl = `data:image/png;base64,${Buffer.from("fake image data").toString("base64")}`
+		const mockResponse = {
+			ok: true,
+			json: vi.fn().mockResolvedValue({
+				choices: [
+					{
+						message: {
+							images: [
+								{
+									image_url: {
+										url: mockDataUrl,
+									},
+								},
+							],
+						},
+					},
+				],
+			}),
+		}
+
+		vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+		const result = await generateImageWithProvider({
+			baseURL: "https://api.example.com/v1",
+			authToken: "test-token",
+			model: "gpt-4-vision",
+			prompt: "A cute cat",
+		})
+
+		expect(result.success).toBe(true)
+
+		// Verify /chat/completions endpoint was used
+		const callUrl = vi.mocked(global.fetch).mock.calls[0][0]
+		expect(callUrl).toContain("/chat/completions")
+	})
+
+	it("should handle missing images in response", async () => {
+		const mockResponse = {
+			ok: true,
+			json: vi.fn().mockResolvedValue({
+				choices: [{ message: { content: "No images" } }],
+			}),
+		}
+
+		vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
+
+		const result = await generateImageWithProvider({
+			baseURL: "https://api.example.com/v1",
+			authToken: "test-token",
+			model: "gpt-4-vision",
+			prompt: "A cute cat",
+		})
+
+		expect(result.success).toBe(false)
+		expect(result.error).toBeDefined()
+	})
+})

+ 164 - 0
src/api/providers/utils/image-generation.ts

@@ -20,6 +20,18 @@ interface ImageGenerationResponse {
 	}
 }
 
+interface ImagesApiResponse { // JSON body of the /images/generations response, as read by generateImageWithImagesApi
+	data?: Array<{ // generated images; only the first entry is used
+		b64_json?: string // base64 image payload; checked first and wrapped in a data URL
+		url?: string // fallback when b64_json is absent: a "data:image/..." URL or an external URL
+	}>
+	error?: { // when present the call is reported as failed, even if the HTTP status was OK
+		message?: string // text placed into the failedWithMessage translation
+		type?: string
+		code?: string
+	}
+}
+
 export interface ImageGenerationResult {
 	success: boolean
 	imageData?: string
@@ -35,6 +47,17 @@ interface ImageGenerationOptions {
 	inputImage?: string
 }
 
+interface ImagesApiOptions { // arguments accepted by generateImageWithImagesApi
+	baseURL: string // API root; "/images/generations" is appended to it
+	authToken: string // sent as "Authorization: Bearer <authToken>"
+	model: string // model id; ids starting with "bfl/" get Black Forest Labs providerOptions
+	prompt: string
+	inputImage?: string // reference image (base64 or URL); only forwarded for "bfl/" models
+	size?: string // forwarded as `size` when non-empty
+	quality?: string // forwarded as `quality` when non-empty
+	outputFormat?: string // requested image format; defaults to "png"
+}
+
 /**
  * Shared image generation implementation for OpenRouter and Roo Code Cloud providers
  */
@@ -147,3 +170,144 @@ export async function generateImageWithProvider(options: ImageGenerationOptions)
 		}
 	}
 }
+
+/**
+ * Generate an image using an OpenAI-style Images API (`POST {baseURL}/images/generations`).
+ *
+ * Models whose id starts with "bfl/" (Black Forest Labs, e.g. Flux) receive the output format and
+ * the optional input image through `providerOptions.blackForestLabs`. All other models receive the
+ * standard `output_format` field, and `inputImage` is NOT forwarded for them.
+ *
+ * This function does not throw: HTTP failures, API-level errors, malformed responses and network
+ * errors are all reported through the returned result.
+ *
+ * @param options Endpoint, credentials, model, prompt and optional generation settings
+ * @returns `{ success: true, imageData, imageFormat }` where `imageData` is a data URL or an
+ *          external URL, or `{ success: false, error }` with a user-facing message
+ */
+export async function generateImageWithImagesApi(options: ImagesApiOptions): Promise<ImageGenerationResult> {
+	const { baseURL, authToken, model, prompt, inputImage, size, quality, outputFormat = "png" } = options
+
+	try {
+		// Strip trailing slashes so a base URL such as "https://host/v1/" does not yield "//images/generations"
+		const url = `${baseURL.replace(/\/+$/, "")}/images/generations`
+
+		// Build the request body
+		// For BFL models, inputImage is passed via providerOptions.blackForestLabs.inputImage
+		const requestBody: Record<string, unknown> = {
+			model,
+			prompt,
+			n: 1,
+		}
+
+		// Add optional parameters
+		if (size) {
+			requestBody.size = size
+		}
+		if (quality) {
+			requestBody.quality = quality
+		}
+
+		// For BFL (Black Forest Labs) models like flux-pro-1.1, use providerOptions
+		if (model.startsWith("bfl/")) {
+			requestBody.providerOptions = {
+				blackForestLabs: {
+					outputFormat: outputFormat,
+					// inputImage: Base64 encoded image or URL of image to use as reference
+					...(inputImage && { inputImage }),
+				},
+			}
+		} else {
+			// For other models, use standard output_format parameter
+			// (inputImage is intentionally not sent: only the BFL branch forwards it)
+			requestBody.output_format = outputFormat
+		}
+
+		const fetchOptions: RequestInit = {
+			method: "POST",
+			headers: {
+				Authorization: `Bearer ${authToken}`,
+				"Content-Type": "application/json",
+				"HTTP-Referer": "https://github.com/RooVetGit/Roo-Code",
+				"X-Title": "Roo Code",
+			},
+			body: JSON.stringify(requestBody),
+		}
+
+		const response = await fetch(url, fetchOptions)
+
+		if (!response.ok) {
+			const errorText = await response.text()
+			let errorMessage = t("tools:generateImage.failedWithStatus", {
+				status: response.status,
+				statusText: response.statusText,
+			})
+
+			// Prefer the API's own error message when the body is JSON and contains one
+			try {
+				const errorJson = JSON.parse(errorText)
+				if (errorJson.error?.message) {
+					errorMessage = t("tools:generateImage.failedWithMessage", {
+						message: errorJson.error.message,
+					})
+				}
+			} catch {
+				// Use default error message
+			}
+			return {
+				success: false,
+				error: errorMessage,
+			}
+		}
+
+		const result: ImagesApiResponse = await response.json()
+
+		// Some backends report failures in the body of an otherwise successful HTTP response
+		if (result.error) {
+			// Every field of the error object is optional; never interpolate `undefined` into the message
+			const detail = result.error.message || result.error.code || result.error.type
+			return {
+				success: false,
+				error: detail
+					? t("tools:generateImage.failedWithMessage", { message: detail })
+					: t("tools:generateImage.unknownError"),
+			}
+		}
+
+		// Extract the generated image from the response
+		const images = result.data
+		if (!images || images.length === 0) {
+			return {
+				success: false,
+				error: t("tools:generateImage.noImageGenerated"),
+			}
+		}
+
+		// Only one image is requested (n: 1), so only the first entry is inspected
+		const imageItem = images[0]
+
+		// Handle b64_json response (most common)
+		if (imageItem?.b64_json) {
+			// Convert base64 to data URL
+			const dataUrl = `data:image/${outputFormat};base64,${imageItem.b64_json}`
+			return {
+				success: true,
+				imageData: dataUrl,
+				imageFormat: outputFormat,
+			}
+		}
+
+		// Handle URL response (fallback)
+		if (imageItem?.url) {
+			// If it's already a data URL, use it directly
+			if (imageItem.url.startsWith("data:image/")) {
+				// Take the format from the data URL's MIME subtype, falling back to the requested format
+				const formatMatch = imageItem.url.match(/^data:image\/(\w+);/)
+				const format = formatMatch?.[1] || outputFormat
+				return {
+					success: true,
+					imageData: imageItem.url,
+					imageFormat: format,
+				}
+			}
+			// For external URLs, return as-is (the caller will need to handle fetching)
+			return {
+				success: true,
+				imageData: imageItem.url,
+				imageFormat: outputFormat,
+			}
+		}
+
+		// The first entry carried neither b64_json nor a non-empty url
+		return {
+			success: false,
+			error: t("tools:generateImage.invalidImageData"),
+		}
+	} catch (error) {
+		// Network failures and JSON parse errors end up here
+		return {
+			success: false,
+			error: error instanceof Error ? error.message : t("tools:generateImage.unknownError"),
+		}
+	}
+}

+ 14 - 10
src/core/tools/GenerateImageTool.ts

@@ -135,24 +135,28 @@ export class GenerateImageTool extends BaseTool<"generate_image"> {
 
 		// Get the selected model
 		let selectedModel = state?.openRouterImageGenerationSelectedModel
+		let modelInfo = undefined
 
-		// Verify the selected model matches the selected provider
-		// If not, default to first model of the selected provider
+		// Find the model info matching both value AND provider
+		// (since the same model value can exist for multiple providers)
 		if (selectedModel) {
-			const modelInfo = IMAGE_GENERATION_MODELS.find((m) => m.value === selectedModel)
-			if (!modelInfo || modelInfo.provider !== imageProvider) {
-				// Model doesn't match provider, use first model for selected provider
+			modelInfo = IMAGE_GENERATION_MODELS.find((m) => m.value === selectedModel && m.provider === imageProvider)
+			if (!modelInfo) {
+				// Model doesn't exist for this provider, use first model for selected provider
 				const providerModels = IMAGE_GENERATION_MODELS.filter((m) => m.provider === imageProvider)
-				selectedModel = providerModels[0]?.value || IMAGE_GENERATION_MODEL_IDS[0]
+				modelInfo = providerModels[0]
+				selectedModel = modelInfo?.value || IMAGE_GENERATION_MODEL_IDS[0]
 			}
 		} else {
 			// No model selected, use first model for selected provider
 			const providerModels = IMAGE_GENERATION_MODELS.filter((m) => m.provider === imageProvider)
-			selectedModel = providerModels[0]?.value || IMAGE_GENERATION_MODEL_IDS[0]
+			modelInfo = providerModels[0]
+			selectedModel = modelInfo?.value || IMAGE_GENERATION_MODEL_IDS[0]
 		}
 
 		// Use the provider selection
 		const modelProvider = imageProvider
+		const apiMethod = modelInfo?.apiMethod
 
 		// Validate API key for OpenRouter
 		const openRouterApiKey = state?.openRouterImageApiKey
@@ -192,11 +196,11 @@ export class GenerateImageTool extends BaseTool<"generate_image"> {
 
 			let result
 			if (modelProvider === "roo") {
-				// Use Roo Code Cloud provider
+				// Use Roo Code Cloud provider (supports both chat completions and images API)
 				const rooHandler = new RooHandler({} as any)
-				result = await rooHandler.generateImage(prompt, selectedModel, inputImageData)
+				result = await rooHandler.generateImage(prompt, selectedModel, inputImageData, apiMethod)
 			} else {
-				// Use OpenRouter provider
+				// Use OpenRouter provider (only supports chat completions API)
 				const openRouterHandler = new OpenRouterHandler({} as any)
 				result = await openRouterHandler.generateImage(prompt, selectedModel, openRouterApiKey!, inputImageData)
 			}