Просмотр исходного кода

feat: add optional input image parameter to image generation tool (#7525)

Co-authored-by: Roo Code <[email protected]>
Co-authored-by: Daniel Riccio <[email protected]>
roomote[bot] 4 месяцев назад
Родитель
Сommit
b22a618ee2

+ 21 - 2
src/api/providers/openrouter.ts

@@ -275,9 +275,15 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
 	 * @param prompt The text prompt for image generation
 	 * @param model The model to use for generation
 	 * @param apiKey The OpenRouter API key (must be explicitly provided)
+	 * @param inputImage Optional base64 encoded input image data URL
 	 * @returns The generated image data and format, or an error
 	 */
-	async generateImage(prompt: string, model: string, apiKey: string): Promise<ImageGenerationResult> {
+	async generateImage(
+		prompt: string,
+		model: string,
+		apiKey: string,
+		inputImage?: string,
+	): Promise<ImageGenerationResult> {
 		if (!apiKey) {
 			return {
 				success: false,
@@ -299,7 +305,20 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
 					messages: [
 						{
 							role: "user",
-							content: prompt,
+							content: inputImage
+								? [
+										{
+											type: "text",
+											text: prompt,
+										},
+										{
+											type: "image_url",
+											image_url: {
+												url: inputImage,
+											},
+										},
+									]
+								: prompt,
 						},
 					],
 					modalities: ["image", "text"],

+ 19 - 3
src/core/prompts/tools/generate-image.ts

@@ -2,19 +2,35 @@ import { ToolArgs } from "./types"
 
 export function getGenerateImageDescription(args: ToolArgs): string {
 	return `## generate_image
-Description: Request to generate an image using AI models through OpenRouter API. This tool creates images from text prompts and saves them to the specified path.
+Description: Request to generate or edit an image using AI models through OpenRouter API. This tool can create new images from text prompts or modify existing images based on your instructions. When an input image is provided, the AI will apply the requested edits, transformations, or enhancements to that image.
 Parameters:
-- prompt: (required) The text prompt describing the image to generate
-- path: (required) The file path where the generated image should be saved (relative to the current workspace directory ${args.cwd}). The tool will automatically add the appropriate image extension if not provided.
+- prompt: (required) The text prompt describing what to generate or how to edit the image
+- path: (required) The file path where the generated/edited image should be saved (relative to the current workspace directory ${args.cwd}). The tool will automatically add the appropriate image extension if not provided.
+- image: (optional) The file path to an input image to edit or transform (relative to the current workspace directory ${args.cwd}). Supported formats: PNG, JPG, JPEG, GIF, WEBP.
 Usage:
 <generate_image>
 <prompt>Your image description here</prompt>
 <path>path/to/save/image.png</path>
+<image>path/to/input/image.jpg</image>
 </generate_image>
 
 Example: Requesting to generate a sunset image
 <generate_image>
 <prompt>A beautiful sunset over mountains with vibrant orange and purple colors</prompt>
 <path>images/sunset.png</path>
+</generate_image>
+
+Example: Editing an existing image
+<generate_image>
+<prompt>Transform this image into a watercolor painting style</prompt>
+<path>images/watercolor-output.png</path>
+<image>images/original-photo.jpg</image>
+</generate_image>
+
+Example: Upscaling and enhancing an image
+<generate_image>
+<prompt>Upscale this image to higher resolution, enhance details, improve clarity and sharpness while maintaining the original content and composition</prompt>
+<path>images/enhanced-photo.png</path>
+<image>images/low-res-photo.jpg</image>
 </generate_image>`
 }

+ 313 - 0
src/core/tools/__tests__/generateImageTool.test.ts

@@ -0,0 +1,313 @@
+import { describe, it, expect, vi, beforeEach } from "vitest"
+import { generateImageTool } from "../generateImageTool"
+import { ToolUse } from "../../../shared/tools"
+import { Task } from "../../task/Task"
+import * as fs from "fs/promises"
+import * as pathUtils from "../../../utils/pathUtils"
+import * as fileUtils from "../../../utils/fs"
+import { formatResponse } from "../../prompts/responses"
+import { EXPERIMENT_IDS } from "../../../shared/experiments"
+import { OpenRouterHandler } from "../../../api/providers/openrouter"
+
+// Mock dependencies
+vi.mock("fs/promises")
+vi.mock("../../../utils/pathUtils")
+vi.mock("../../../utils/fs")
+vi.mock("../../../utils/safeWriteJson")
+vi.mock("../../../api/providers/openrouter")
+
+describe("generateImageTool", () => {
+	let mockCline: any
+	let mockAskApproval: any
+	let mockHandleError: any
+	let mockPushToolResult: any
+	let mockRemoveClosingTag: any
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+
+		// Setup mock Cline instance
+		mockCline = {
+			cwd: "/test/workspace",
+			consecutiveMistakeCount: 0,
+			recordToolError: vi.fn(),
+			recordToolUsage: vi.fn(),
+			sayAndCreateMissingParamError: vi.fn().mockResolvedValue("Missing parameter error"),
+			say: vi.fn(),
+			rooIgnoreController: {
+				validateAccess: vi.fn().mockReturnValue(true),
+			},
+			rooProtectedController: {
+				isWriteProtected: vi.fn().mockReturnValue(false),
+			},
+			providerRef: {
+				deref: vi.fn().mockReturnValue({
+					getState: vi.fn().mockResolvedValue({
+						experiments: {
+							[EXPERIMENT_IDS.IMAGE_GENERATION]: true,
+						},
+						apiConfiguration: {
+							openRouterImageGenerationSettings: {
+								openRouterApiKey: "test-api-key",
+								selectedModel: "google/gemini-2.5-flash-image-preview",
+							},
+						},
+					}),
+				}),
+			},
+			fileContextTracker: {
+				trackFileContext: vi.fn(),
+			},
+			didEditFile: false,
+		}
+
+		mockAskApproval = vi.fn().mockResolvedValue(true)
+		mockHandleError = vi.fn()
+		mockPushToolResult = vi.fn()
+		mockRemoveClosingTag = vi.fn((tag, content) => content || "")
+
+		// Mock file system operations
+		vi.mocked(fileUtils.fileExistsAtPath).mockResolvedValue(true)
+		vi.mocked(fs.readFile).mockResolvedValue(Buffer.from("fake-image-data"))
+		vi.mocked(fs.mkdir).mockResolvedValue(undefined)
+		vi.mocked(fs.writeFile).mockResolvedValue(undefined)
+		vi.mocked(pathUtils.isPathOutsideWorkspace).mockReturnValue(false)
+	})
+
+	describe("partial block handling", () => {
+		it("should return early when block is partial", async () => {
+			const partialBlock: ToolUse = {
+				type: "tool_use",
+				name: "generate_image",
+				params: {
+					prompt: "Generate a test image",
+					path: "test-image.png",
+				},
+				partial: true,
+			}
+
+			await generateImageTool(
+				mockCline as Task,
+				partialBlock,
+				mockAskApproval,
+				mockHandleError,
+				mockPushToolResult,
+				mockRemoveClosingTag,
+			)
+
+			// Should not process anything when partial
+			expect(mockAskApproval).not.toHaveBeenCalled()
+			expect(mockPushToolResult).not.toHaveBeenCalled()
+			expect(mockCline.say).not.toHaveBeenCalled()
+		})
+
+		it("should return early when block is partial even with image parameter", async () => {
+			const partialBlock: ToolUse = {
+				type: "tool_use",
+				name: "generate_image",
+				params: {
+					prompt: "Upscale this image",
+					path: "upscaled-image.png",
+					image: "source-image.png",
+				},
+				partial: true,
+			}
+
+			await generateImageTool(
+				mockCline as Task,
+				partialBlock,
+				mockAskApproval,
+				mockHandleError,
+				mockPushToolResult,
+				mockRemoveClosingTag,
+			)
+
+			// Should not process anything when partial
+			expect(mockAskApproval).not.toHaveBeenCalled()
+			expect(mockPushToolResult).not.toHaveBeenCalled()
+			expect(mockCline.say).not.toHaveBeenCalled()
+			expect(fs.readFile).not.toHaveBeenCalled()
+		})
+
+		it("should process when block is not partial", async () => {
+			const completeBlock: ToolUse = {
+				type: "tool_use",
+				name: "generate_image",
+				params: {
+					prompt: "Generate a test image",
+					path: "test-image.png",
+				},
+				partial: false,
+			}
+
+			// Mock the OpenRouterHandler generateImage method
+			const mockGenerateImage = vi.fn().mockResolvedValue({
+				success: true,
+				imageData: "",
+			})
+
+			vi.mocked(OpenRouterHandler).mockImplementation(
+				() =>
+					({
+						generateImage: mockGenerateImage,
+					}) as any,
+			)
+
+			await generateImageTool(
+				mockCline as Task,
+				completeBlock,
+				mockAskApproval,
+				mockHandleError,
+				mockPushToolResult,
+				mockRemoveClosingTag,
+			)
+
+			// Should process the complete block
+			expect(mockAskApproval).toHaveBeenCalled()
+			expect(mockGenerateImage).toHaveBeenCalled()
+			expect(mockPushToolResult).toHaveBeenCalled()
+		})
+	})
+
+	describe("missing parameters", () => {
+		it("should handle missing prompt parameter", async () => {
+			const block: ToolUse = {
+				type: "tool_use",
+				name: "generate_image",
+				params: {
+					path: "test-image.png",
+				},
+				partial: false,
+			}
+
+			await generateImageTool(
+				mockCline as Task,
+				block,
+				mockAskApproval,
+				mockHandleError,
+				mockPushToolResult,
+				mockRemoveClosingTag,
+			)
+
+			expect(mockCline.consecutiveMistakeCount).toBe(1)
+			expect(mockCline.recordToolError).toHaveBeenCalledWith("generate_image")
+			expect(mockCline.sayAndCreateMissingParamError).toHaveBeenCalledWith("generate_image", "prompt")
+			expect(mockPushToolResult).toHaveBeenCalledWith("Missing parameter error")
+		})
+
+		it("should handle missing path parameter", async () => {
+			const block: ToolUse = {
+				type: "tool_use",
+				name: "generate_image",
+				params: {
+					prompt: "Generate a test image",
+				},
+				partial: false,
+			}
+
+			await generateImageTool(
+				mockCline as Task,
+				block,
+				mockAskApproval,
+				mockHandleError,
+				mockPushToolResult,
+				mockRemoveClosingTag,
+			)
+
+			expect(mockCline.consecutiveMistakeCount).toBe(1)
+			expect(mockCline.recordToolError).toHaveBeenCalledWith("generate_image")
+			expect(mockCline.sayAndCreateMissingParamError).toHaveBeenCalledWith("generate_image", "path")
+			expect(mockPushToolResult).toHaveBeenCalledWith("Missing parameter error")
+		})
+	})
+
+	describe("experiment validation", () => {
+		it("should error when image generation experiment is disabled", async () => {
+			// Disable the experiment
+			mockCline.providerRef.deref().getState.mockResolvedValue({
+				experiments: {
+					[EXPERIMENT_IDS.IMAGE_GENERATION]: false,
+				},
+			})
+
+			const block: ToolUse = {
+				type: "tool_use",
+				name: "generate_image",
+				params: {
+					prompt: "Generate a test image",
+					path: "test-image.png",
+				},
+				partial: false,
+			}
+
+			await generateImageTool(
+				mockCline as Task,
+				block,
+				mockAskApproval,
+				mockHandleError,
+				mockPushToolResult,
+				mockRemoveClosingTag,
+			)
+
+			expect(mockPushToolResult).toHaveBeenCalledWith(
+				formatResponse.toolError(
+					"Image generation is an experimental feature that must be enabled in settings. Please enable 'Image Generation' in the Experimental Settings section.",
+				),
+			)
+		})
+	})
+
+	describe("input image validation", () => {
+		it("should handle non-existent input image", async () => {
+			vi.mocked(fileUtils.fileExistsAtPath).mockResolvedValue(false)
+
+			const block: ToolUse = {
+				type: "tool_use",
+				name: "generate_image",
+				params: {
+					prompt: "Upscale this image",
+					path: "upscaled.png",
+					image: "non-existent.png",
+				},
+				partial: false,
+			}
+
+			await generateImageTool(
+				mockCline as Task,
+				block,
+				mockAskApproval,
+				mockHandleError,
+				mockPushToolResult,
+				mockRemoveClosingTag,
+			)
+
+			expect(mockCline.say).toHaveBeenCalledWith("error", expect.stringContaining("Input image not found"))
+			expect(mockPushToolResult).toHaveBeenCalledWith(expect.stringContaining("Input image not found"))
+		})
+
+		it("should handle unsupported image format", async () => {
+			const block: ToolUse = {
+				type: "tool_use",
+				name: "generate_image",
+				params: {
+					prompt: "Upscale this image",
+					path: "upscaled.png",
+					image: "test.bmp", // Unsupported format
+				},
+				partial: false,
+			}
+
+			await generateImageTool(
+				mockCline as Task,
+				block,
+				mockAskApproval,
+				mockHandleError,
+				mockPushToolResult,
+				mockRemoveClosingTag,
+			)
+
+			expect(mockCline.say).toHaveBeenCalledWith("error", expect.stringContaining("Unsupported image format"))
+			expect(mockPushToolResult).toHaveBeenCalledWith(expect.stringContaining("Unsupported image format"))
+		})
+	})
+})

+ 70 - 4
src/core/tools/generateImageTool.ts

@@ -24,6 +24,7 @@ export async function generateImageTool(
 ) {
 	const prompt: string | undefined = block.params.prompt
 	const relPath: string | undefined = block.params.path
+	const inputImagePath: string | undefined = block.params.image
 
 	// Check if the experiment is enabled
 	const provider = cline.providerRef.deref()
@@ -39,8 +40,7 @@ export async function generateImageTool(
 		return
 	}
 
-	if (block.partial && (!prompt || !relPath)) {
-		// Wait for complete parameters
+	if (block.partial) {
 		return
 	}
 
@@ -66,6 +66,66 @@ export async function generateImageTool(
 		return
 	}
 
+	// If input image is provided, validate it exists and can be read
+	let inputImageData: string | undefined
+	if (inputImagePath) {
+		const inputImageFullPath = path.resolve(cline.cwd, inputImagePath)
+
+		// Check if input image exists
+		const inputImageExists = await fileExistsAtPath(inputImageFullPath)
+		if (!inputImageExists) {
+			await cline.say("error", `Input image not found: ${getReadablePath(cline.cwd, inputImagePath)}`)
+			pushToolResult(
+				formatResponse.toolError(`Input image not found: ${getReadablePath(cline.cwd, inputImagePath)}`),
+			)
+			return
+		}
+
+		// Validate input image access permissions
+		const inputImageAccessAllowed = cline.rooIgnoreController?.validateAccess(inputImagePath)
+		if (!inputImageAccessAllowed) {
+			await cline.say("rooignore_error", inputImagePath)
+			pushToolResult(formatResponse.toolError(formatResponse.rooIgnoreError(inputImagePath)))
+			return
+		}
+
+		// Read the input image file
+		try {
+			const imageBuffer = await fs.readFile(inputImageFullPath)
+			const imageExtension = path.extname(inputImageFullPath).toLowerCase().replace(".", "")
+
+			// Validate image format
+			const supportedFormats = ["png", "jpg", "jpeg", "gif", "webp"]
+			if (!supportedFormats.includes(imageExtension)) {
+				await cline.say(
+					"error",
+					`Unsupported image format: ${imageExtension}. Supported formats: ${supportedFormats.join(", ")}`,
+				)
+				pushToolResult(
+					formatResponse.toolError(
+						`Unsupported image format: ${imageExtension}. Supported formats: ${supportedFormats.join(", ")}`,
+					),
+				)
+				return
+			}
+
+			// Convert to base64 data URL
+			const mimeType = imageExtension === "jpg" ? "jpeg" : imageExtension
+			inputImageData = `data:image/${mimeType};base64,${imageBuffer.toString("base64")}`
+		} catch (error) {
+			await cline.say(
+				"error",
+				`Failed to read input image: ${error instanceof Error ? error.message : "Unknown error"}`,
+			)
+			pushToolResult(
+				formatResponse.toolError(
+					`Failed to read input image: ${error instanceof Error ? error.message : "Unknown error"}`,
+				),
+			)
+			return
+		}
+	}
+
 	// Check if file is write-protected
 	const isWriteProtected = cline.rooProtectedController?.isWriteProtected(relPath) || false
 
@@ -110,6 +170,7 @@ export async function generateImageTool(
 			const approvalMessage = JSON.stringify({
 				...sharedMessageProps,
 				content: prompt,
+				...(inputImagePath && { inputImage: getReadablePath(cline.cwd, inputImagePath) }),
 			})
 
 			const didApprove = await askApproval("tool", approvalMessage, undefined, isWriteProtected)
@@ -121,8 +182,13 @@ export async function generateImageTool(
 			// Create a temporary OpenRouter handler with minimal options
 			const openRouterHandler = new OpenRouterHandler({} as any)
 
-			// Call the generateImage method with the explicit API key
-			const result = await openRouterHandler.generateImage(prompt, selectedModel, openRouterApiKey)
+			// Call the generateImage method with the explicit API key and optional input image
+			const result = await openRouterHandler.generateImage(
+				prompt,
+				selectedModel,
+				openRouterApiKey,
+				inputImageData,
+			)
 
 			if (!result.success) {
 				await cline.say("error", result.error || "Failed to generate image")

+ 2 - 1
src/shared/tools.ts

@@ -66,6 +66,7 @@ export const toolParamNames = [
 	"args",
 	"todos",
 	"prompt",
+	"image",
 ] as const
 
 export type ToolParamName = (typeof toolParamNames)[number]
@@ -167,7 +168,7 @@ export interface SearchAndReplaceToolUse extends ToolUse {
 
 export interface GenerateImageToolUse extends ToolUse {
 	name: "generate_image"
-	params: Partial<Pick<Record<ToolParamName, string>, "prompt" | "path">>
+	params: Partial<Pick<Record<ToolParamName, string>, "prompt" | "path" | "image">>
 }
 
 // Define tool group configuration