Browse Source

Enhance prompt button for openrouter

Matt Rubens 1 year ago
parent
commit
111abdbb2c

+ 4 - 0
src/api/index.ts

@@ -11,6 +11,10 @@ import { GeminiHandler } from "./providers/gemini"
 import { OpenAiNativeHandler } from "./providers/openai-native"
 import { ApiStream } from "./transform/stream"
 
+export interface SingleCompletionHandler {
+	completePrompt(prompt: string): Promise<string>
+}
+
 export interface ApiHandler {
 	createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream
 	getModel(): { id: string; info: ModelInfo }

+ 162 - 0
src/api/providers/__tests__/openrouter.test.ts

@@ -51,6 +51,14 @@ describe('OpenRouterHandler', () => {
         })
     })
 
+    test('getModel returns default model info when options are not provided', () => {
+        const handler = new OpenRouterHandler({})
+        const result = handler.getModel()
+        
+        expect(result.id).toBe('anthropic/claude-3.5-sonnet:beta')
+        expect(result.info.supportsPromptCache).toBe(true)
+    })
+
     test('createMessage generates correct stream chunks', async () => {
         const handler = new OpenRouterHandler(mockOptions)
         const mockStream = {
@@ -118,4 +126,158 @@ describe('OpenRouterHandler', () => {
             stream: true
         }))
     })
+
+    test('createMessage with middle-out transform enabled', async () => {
+        const handler = new OpenRouterHandler({
+            ...mockOptions,
+            openRouterUseMiddleOutTransform: true
+        })
+        const mockStream = {
+            async *[Symbol.asyncIterator]() {
+                yield {
+                    id: 'test-id',
+                    choices: [{
+                        delta: {
+                            content: 'test response'
+                        }
+                    }]
+                }
+            }
+        }
+
+        const mockCreate = jest.fn().mockResolvedValue(mockStream)
+        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+            completions: { create: mockCreate }
+        } as any
+        ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
+
+        await handler.createMessage('test', []).next()
+
+        expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
+            transforms: ['middle-out']
+        }))
+    })
+
+    test('createMessage with Claude model adds cache control', async () => {
+        const handler = new OpenRouterHandler({
+            ...mockOptions,
+            openRouterModelId: 'anthropic/claude-3.5-sonnet'
+        })
+        const mockStream = {
+            async *[Symbol.asyncIterator]() {
+                yield {
+                    id: 'test-id',
+                    choices: [{
+                        delta: {
+                            content: 'test response'
+                        }
+                    }]
+                }
+            }
+        }
+
+        const mockCreate = jest.fn().mockResolvedValue(mockStream)
+        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+            completions: { create: mockCreate }
+        } as any
+        ;(axios.get as jest.Mock).mockResolvedValue({ data: { data: {} } })
+
+        const messages: Anthropic.Messages.MessageParam[] = [
+            { role: 'user', content: 'message 1' },
+            { role: 'assistant', content: 'response 1' },
+            { role: 'user', content: 'message 2' }
+        ]
+
+        await handler.createMessage('test system', messages).next()
+
+        expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({
+            messages: expect.arrayContaining([
+                expect.objectContaining({
+                    role: 'system',
+                    content: expect.arrayContaining([
+                        expect.objectContaining({
+                            cache_control: { type: 'ephemeral' }
+                        })
+                    ])
+                })
+            ])
+        }))
+    })
+
+    test('createMessage handles API errors', async () => {
+        const handler = new OpenRouterHandler(mockOptions)
+        const mockStream = {
+            async *[Symbol.asyncIterator]() {
+                yield {
+                    error: {
+                        message: 'API Error',
+                        code: 500
+                    }
+                }
+            }
+        }
+
+        const mockCreate = jest.fn().mockResolvedValue(mockStream)
+        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+            completions: { create: mockCreate }
+        } as any
+
+        const generator = handler.createMessage('test', [])
+        await expect(generator.next()).rejects.toThrow('OpenRouter API Error 500: API Error')
+    })
+
+    test('completePrompt returns correct response', async () => {
+        const handler = new OpenRouterHandler(mockOptions)
+        const mockResponse = {
+            choices: [{
+                message: {
+                    content: 'test completion'
+                }
+            }]
+        }
+
+        const mockCreate = jest.fn().mockResolvedValue(mockResponse)
+        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+            completions: { create: mockCreate }
+        } as any
+
+        const result = await handler.completePrompt('test prompt')
+
+        expect(result).toBe('test completion')
+        expect(mockCreate).toHaveBeenCalledWith({
+            model: mockOptions.openRouterModelId,
+            messages: [{ role: 'user', content: 'test prompt' }],
+            temperature: 0,
+            stream: false
+        })
+    })
+
+    test('completePrompt handles API errors', async () => {
+        const handler = new OpenRouterHandler(mockOptions)
+        const mockError = {
+            error: {
+                message: 'API Error',
+                code: 500
+            }
+        }
+
+        const mockCreate = jest.fn().mockResolvedValue(mockError)
+        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+            completions: { create: mockCreate }
+        } as any
+
+        await expect(handler.completePrompt('test prompt'))
+            .rejects.toThrow('OpenRouter API Error 500: API Error')
+    })
+
+    test('completePrompt handles unexpected errors', async () => {
+        const handler = new OpenRouterHandler(mockOptions)
+        const mockCreate = jest.fn().mockRejectedValue(new Error('Unexpected error'))
+        ;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
+            completions: { create: mockCreate }
+        } as any
+
+        await expect(handler.completePrompt('test prompt'))
+            .rejects.toThrow('OpenRouter completion error: Unexpected error')
+    })
 })

+ 32 - 3
src/api/providers/openrouter.ts

@@ -4,11 +4,11 @@ import OpenAI from "openai"
 import { ApiHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, openRouterDefaultModelId, openRouterDefaultModelInfo } from "../../shared/api"
 import { convertToOpenAiMessages } from "../transform/openai-format"
-import { ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
+import { ApiStream, ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
 import delay from "delay"
 
 // Add custom interface for OpenRouter params
-interface OpenRouterChatCompletionParams extends OpenAI.Chat.ChatCompletionCreateParamsStreaming {
+type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
     transforms?: string[];
 }
 
@@ -17,7 +17,12 @@ interface OpenRouterApiStreamUsageChunk extends ApiStreamUsageChunk {
     fullResponseText: string;
 }
 
-export class OpenRouterHandler implements ApiHandler {
+// Interface for providers that support single completions
+export interface SingleCompletionHandler {
+    completePrompt(prompt: string): Promise<string>
+}
+
+export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: OpenAI
 
@@ -184,4 +189,28 @@ export class OpenRouterHandler implements ApiHandler {
 		}
 		return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		try {
+			const response = await this.client.chat.completions.create({
+				model: this.getModel().id,
+				messages: [{ role: "user", content: prompt }],
+				temperature: 0,
+				stream: false
+			})
+
+			if ("error" in response) {
+				const error = response.error as { message?: string; code?: number }
+				throw new Error(`OpenRouter API Error ${error?.code}: ${error?.message}`)
+			}
+
+			const completion = response as OpenAI.Chat.ChatCompletion
+			return completion.choices[0]?.message?.content || ""
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`OpenRouter completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
 }

+ 18 - 1
src/core/Cline.ts

@@ -8,7 +8,7 @@ import pWaitFor from "p-wait-for"
 import * as path from "path"
 import { serializeError } from "serialize-error"
 import * as vscode from "vscode"
-import { ApiHandler, buildApiHandler } from "../api"
+import { ApiHandler, SingleCompletionHandler, buildApiHandler } from "../api"
 import { ApiStream } from "../api/transform/stream"
 import { DiffViewProvider } from "../integrations/editor/DiffViewProvider"
 import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
@@ -49,6 +49,7 @@ import { truncateHalfConversation } from "./sliding-window"
 import { ClineProvider, GlobalFileNames } from "./webview/ClineProvider"
 import { detectCodeOmission } from "../integrations/editor/detect-omission"
 import { BrowserSession } from "../services/browser/BrowserSession"
+import { OpenRouterHandler } from "../api/providers/openrouter"
 
 const cwd =
 	vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0) ?? path.join(os.homedir(), "Desktop") // may or may not exist but fs checking existence would immediately ask for permission which would be bad UX, need to come up with a better solution
@@ -126,6 +127,22 @@ export class Cline {
 		}
 	}
 
+	async enhancePrompt(promptText: string): Promise<string> {
+		if (!promptText) {
+			throw new Error("No prompt text provided")
+		}
+
+		const prompt = `Generate an enhanced version of this prompt (reply with only the enhanced prompt, no bullet points): ${promptText}`
+		
+		// Check if the API handler supports completePrompt
+		if (this.api instanceof OpenRouterHandler) {
+			return this.api.completePrompt(prompt)
+		}
+
+		// Otherwise just return the prompt as is
+		return prompt;
+	}
+
 	// Storing task to disk for history
 
 	private async ensureTaskDirectoryExists(): Promise<string> {

+ 21 - 0
src/core/webview/ClineProvider.ts

@@ -23,6 +23,7 @@ import { openMention } from "../mentions"
 import { getNonce } from "./getNonce"
 import { getUri } from "./getUri"
 import { playSound, setSoundEnabled, setSoundVolume } from "../../utils/sound"
+import { enhancePrompt } from "../../utils/enhance-prompt"
 
 /*
 https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
@@ -637,6 +638,26 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 						await this.updateGlobalState("writeDelayMs", message.value)
 						await this.postStateToWebview()
 						break
+					case "enhancePrompt":
+						if (message.text) {
+							try {
+								const { apiConfiguration } = await this.getState()
+								const enhanceConfig = {
+									...apiConfiguration,
+									apiProvider: "openrouter" as const,
+									openRouterModelId: "gpt-4o",
+								}
+								const enhancedPrompt = await enhancePrompt(enhanceConfig, message.text)
+								await this.postMessageToWebview({
+									type: "enhancedPrompt",
+									text: enhancedPrompt
+								})
+							} catch (error) {
+								console.error("Error enhancing prompt:", error)
+								vscode.window.showErrorMessage("Failed to enhance prompt")
+							}
+						}
+						break
 				}
 			},
 			null,

+ 1 - 0
src/shared/ExtensionMessage.ts

@@ -19,6 +19,7 @@ export interface ExtensionMessage {
 		| "openRouterModels"
 		| "openAiModels"
 		| "mcpServers"
+		| "enhancedPrompt"
 	text?: string
 	action?:
 		| "chatButtonClicked"

+ 4 - 1
src/shared/WebviewMessage.ts

@@ -43,6 +43,9 @@ export interface WebviewMessage {
 		| "fuzzyMatchThreshold"
 		| "preferredLanguage"
 		| "writeDelayMs"
+		| "enhancePrompt"
+		| "enhancedPrompt"
+		| "draggedImages"
 	text?: string
 	disabled?: boolean
 	askResponse?: ClineAskResponse
@@ -52,10 +55,10 @@ export interface WebviewMessage {
 	value?: number
 	commands?: string[]
 	audioType?: AudioType
-	// For toggleToolAutoApprove
 	serverName?: string
 	toolName?: string
 	alwaysAllow?: boolean
+	dataUrls?: string[]
 }
 
 export type ClineAskResponse = "yesButtonClicked" | "noButtonClicked" | "messageResponse"

+ 80 - 0
src/utils/__tests__/enhance-prompt.test.ts

@@ -0,0 +1,80 @@
+import { enhancePrompt } from '../enhance-prompt'
+import { buildApiHandler } from '../../api'
+import { ApiConfiguration } from '../../shared/api'
+import { OpenRouterHandler } from '../../api/providers/openrouter'
+
+// Mock the buildApiHandler function
+jest.mock('../../api', () => ({
+    buildApiHandler: jest.fn()
+}))
+
+describe('enhancePrompt', () => {
+    const mockApiConfig: ApiConfiguration = {
+        apiProvider: 'openrouter',
+        apiKey: 'test-key',
+        openRouterApiKey: 'test-key',
+        openRouterModelId: 'test-model'
+    }
+
+    // Create a mock handler that looks like OpenRouterHandler
+    const mockHandler = {
+        completePrompt: jest.fn(),
+        createMessage: jest.fn(),
+        getModel: jest.fn()
+    }
+
+    // Make instanceof check work
+    Object.setPrototypeOf(mockHandler, OpenRouterHandler.prototype)
+
+    beforeEach(() => {
+        jest.clearAllMocks()
+        ;(buildApiHandler as jest.Mock).mockReturnValue(mockHandler)
+    })
+
+    it('should throw error for non-OpenRouter providers', async () => {
+        const nonOpenRouterConfig: ApiConfiguration = {
+            apiProvider: 'anthropic',
+            apiKey: 'test-key',
+            apiModelId: 'claude-3'
+        }
+        await expect(enhancePrompt(nonOpenRouterConfig, 'test')).rejects.toThrow('Prompt enhancement is only available with OpenRouter')
+    })
+
+    it('should enhance a valid prompt', async () => {
+        const inputPrompt = 'Write a function to sort an array'
+        const enhancedPrompt = 'Write a TypeScript function that implements an efficient sorting algorithm for a generic array, including error handling and type safety'
+        
+        mockHandler.completePrompt.mockResolvedValue(enhancedPrompt)
+
+        const result = await enhancePrompt(mockApiConfig, inputPrompt)
+
+        expect(result).toBe(enhancedPrompt)
+        expect(buildApiHandler).toHaveBeenCalledWith(mockApiConfig)
+        expect(mockHandler.completePrompt).toHaveBeenCalledWith(
+            expect.stringContaining(inputPrompt)
+        )
+    })
+
+    it('should throw error when no prompt text is provided', async () => {
+        await expect(enhancePrompt(mockApiConfig, '')).rejects.toThrow('No prompt text provided')
+        expect(mockHandler.completePrompt).not.toHaveBeenCalled()
+    })
+
+    it('should pass through API errors', async () => {
+        const inputPrompt = 'Test prompt'
+        mockHandler.completePrompt.mockRejectedValue('API error')
+
+        await expect(enhancePrompt(mockApiConfig, inputPrompt)).rejects.toBe('API error')
+    })
+
+    it('should pass the correct prompt format to the API', async () => {
+        const inputPrompt = 'Test prompt'
+        mockHandler.completePrompt.mockResolvedValue('Enhanced test prompt')
+        
+        await enhancePrompt(mockApiConfig, inputPrompt)
+
+        expect(mockHandler.completePrompt).toHaveBeenCalledWith(
+            'Generate an enhanced version of this prompt (reply with only the enhanced prompt, no other text or bullet points): Test prompt'
+        )
+    })
+})

+ 26 - 0
src/utils/enhance-prompt.ts

@@ -0,0 +1,26 @@
+import { ApiConfiguration } from "../shared/api"
+import { buildApiHandler } from "../api"
+import { OpenRouterHandler, SingleCompletionHandler } from "../api/providers/openrouter"
+
+/**
+ * Enhances a prompt using the OpenRouter API without creating a full Cline instance or task history.
+ * This is a lightweight alternative that only uses the API's completion functionality.
+ */
+export async function enhancePrompt(apiConfiguration: ApiConfiguration, promptText: string): Promise<string> {
+    if (!promptText) {
+        throw new Error("No prompt text provided")
+    }
+    if (apiConfiguration.apiProvider !== "openrouter") {
+        throw new Error("Prompt enhancement is only available with OpenRouter")
+    }
+    
+    const handler = buildApiHandler(apiConfiguration)
+    
+    // Type guard to check if handler is OpenRouterHandler
+    if (!(handler instanceof OpenRouterHandler)) {
+        throw new Error("Expected OpenRouter handler")
+    }
+    
+    const prompt = `Generate an enhanced version of this prompt (reply with only the enhanced prompt, no other text or bullet points): ${promptText}`
+    return handler.completePrompt(prompt)
+}

+ 146 - 97
webview-ui/src/components/chat/ChatTextArea.tsx

@@ -13,7 +13,7 @@ import { MAX_IMAGES_PER_MESSAGE } from "./ChatView"
 import ContextMenu from "./ContextMenu"
 import Thumbnails from "../common/Thumbnails"
 
-declare const vscode: any;
+import { vscode } from "../../utils/vscode"
 
 interface ChatTextAreaProps {
 	inputValue: string
@@ -44,8 +44,20 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
 		},
 		ref,
 	) => {
-		const { filePaths } = useExtensionState()
+		const { filePaths, apiConfiguration } = useExtensionState()
 		const [isTextAreaFocused, setIsTextAreaFocused] = useState(false)
+
+		// Handle enhanced prompt response
+		useEffect(() => {
+			const messageHandler = (event: MessageEvent) => {
+				const message = event.data
+				if (message.type === 'enhancedPrompt' && message.text) {
+					setInputValue(message.text)
+				}
+			}
+			window.addEventListener('message', messageHandler)
+			return () => window.removeEventListener('message', messageHandler)
+		}, [setInputValue])
 		const [thumbnailsHeight, setThumbnailsHeight] = useState(0)
 		const [textAreaBaseHeight, setTextAreaBaseHeight] = useState<number | undefined>(undefined)
 		const [showContextMenu, setShowContextMenu] = useState(false)
@@ -60,6 +72,63 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
 		const [intendedCursorPosition, setIntendedCursorPosition] = useState<number | null>(null)
 		const contextMenuContainerRef = useRef<HTMLDivElement>(null)
 
+		const [isEnhancingPrompt, setIsEnhancingPrompt] = useState(false)
+		
+		const handleEnhancePrompt = useCallback(() => {
+		    if (!textAreaDisabled) {
+		        const trimmedInput = inputValue.trim()
+		        if (trimmedInput) {
+		            setIsEnhancingPrompt(true)
+		            const message = {
+		                type: "enhancePrompt" as const,
+		                text: trimmedInput,
+		            }
+		            vscode.postMessage(message)
+		        } else {
+		            const promptDescription = "The 'Enhance Prompt' button helps improve your prompt by providing additional context, clarification, or rephrasing. Try typing a prompt in here and clicking the button again to see how it works."
+		            setInputValue(promptDescription)
+		        }
+		    }
+		}, [inputValue, textAreaDisabled, setInputValue])
+		
+		useEffect(() => {
+		    const messageHandler = (event: MessageEvent) => {
+		        const message = event.data
+		        if (message.type === 'enhancedPrompt') {
+		            setInputValue(message.text)
+		            setIsEnhancingPrompt(false)
+		        }
+		    }
+		    window.addEventListener('message', messageHandler)
+		    return () => window.removeEventListener('message', messageHandler)
+		}, [setInputValue])
+
+		// Handle enhanced prompt response
+		useEffect(() => {
+		    const messageHandler = (event: MessageEvent) => {
+		        const message = event.data
+		        if (message.type === 'enhancedPrompt') {
+		            setInputValue(message.text)
+		        }
+		    }
+		    window.addEventListener('message', messageHandler)
+		    return () => {
+		        window.removeEventListener('message', messageHandler)
+		    }
+		}, [setInputValue])
+
+		// Handle enhanced prompt response
+		useEffect(() => {
+		    const messageHandler = (event: MessageEvent) => {
+		        const message = event.data
+		        if (message.type === 'enhancedPrompt' && message.text) {
+		            setInputValue(message.text)
+		        }
+		    }
+		    window.addEventListener('message', messageHandler)
+		    return () => window.removeEventListener('message', messageHandler)
+		}, [setInputValue])
+
 		const queryItems = useMemo(() => {
 			return [
 				{ type: ContextMenuOptionType.Problems, value: "problems" },
@@ -423,68 +492,64 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
 		)
 
 		return (
-			<div
-				style={{
-					padding: "10px 15px",
-					opacity: textAreaDisabled ? 0.5 : 1,
-					position: "relative",
-					display: "flex",
-				}}
-				onDrop={async (e) => {
-					console.log("onDrop called")
-					e.preventDefault()
-					const files = Array.from(e.dataTransfer.files)
-					const text = e.dataTransfer.getData("text")
-					if (text) {
-						const newValue =
-							inputValue.slice(0, cursorPosition) + text + inputValue.slice(cursorPosition)
-						setInputValue(newValue)
-						const newCursorPosition = cursorPosition + text.length
-						setCursorPosition(newCursorPosition)
-						setIntendedCursorPosition(newCursorPosition)
-						return
-					}
-					const acceptedTypes = ["png", "jpeg", "webp"]
-					const imageFiles = files.filter((file) => {
-						const [type, subtype] = file.type.split("/")
-						return type === "image" && acceptedTypes.includes(subtype)
-					})
-					if (!shouldDisableImages && imageFiles.length > 0) {
-						const imagePromises = imageFiles.map((file) => {
-							return new Promise<string | null>((resolve) => {
-								const reader = new FileReader()
-								reader.onloadend = () => {
-									if (reader.error) {
-										console.error("Error reading file:", reader.error)
-										resolve(null)
-									} else {
-										const result = reader.result
-										console.log("File read successfully", result)
-										resolve(typeof result === "string" ? result : null)
-									}
+			<div style={{
+				padding: "10px 15px",
+				opacity: textAreaDisabled ? 0.5 : 1,
+				position: "relative",
+				display: "flex",
+			}}
+			onDrop={async (e) => {
+				e.preventDefault()
+				const files = Array.from(e.dataTransfer.files)
+				const text = e.dataTransfer.getData("text")
+				if (text) {
+					const newValue =
+						inputValue.slice(0, cursorPosition) + text + inputValue.slice(cursorPosition)
+					setInputValue(newValue)
+					const newCursorPosition = cursorPosition + text.length
+					setCursorPosition(newCursorPosition)
+					setIntendedCursorPosition(newCursorPosition)
+					return
+				}
+				const acceptedTypes = ["png", "jpeg", "webp"]
+				const imageFiles = files.filter((file) => {
+					const [type, subtype] = file.type.split("/")
+					return type === "image" && acceptedTypes.includes(subtype)
+				})
+				if (!shouldDisableImages && imageFiles.length > 0) {
+					const imagePromises = imageFiles.map((file) => {
+						return new Promise<string | null>((resolve) => {
+							const reader = new FileReader()
+							reader.onloadend = () => {
+								if (reader.error) {
+									console.error("Error reading file:", reader.error)
+									resolve(null)
+								} else {
+									const result = reader.result
+									resolve(typeof result === "string" ? result : null)
 								}
-								reader.readAsDataURL(file)
-							})
-						})
-						const imageDataArray = await Promise.all(imagePromises)
-						const dataUrls = imageDataArray.filter((dataUrl): dataUrl is string => dataUrl !== null)
-						if (dataUrls.length > 0) {
-							setSelectedImages((prevImages) => [...prevImages, ...dataUrls].slice(0, MAX_IMAGES_PER_MESSAGE))
-							if (typeof vscode !== 'undefined') {
-								vscode.postMessage({
-									type: 'draggedImages',
-									dataUrls: dataUrls
-								})
 							}
-						} else {
-							console.warn("No valid images were processed")
+							reader.readAsDataURL(file)
+						})
+					})
+					const imageDataArray = await Promise.all(imagePromises)
+					const dataUrls = imageDataArray.filter((dataUrl): dataUrl is string => dataUrl !== null)
+					if (dataUrls.length > 0) {
+						setSelectedImages((prevImages) => [...prevImages, ...dataUrls].slice(0, MAX_IMAGES_PER_MESSAGE))
+						if (typeof vscode !== 'undefined') {
+							vscode.postMessage({
+								type: 'draggedImages',
+								dataUrls: dataUrls
+							})
 						}
+					} else {
+						console.warn("No valid images were processed")
 					}
-				}}
-				onDragOver={(e) => {
-					e.preventDefault()
-				}}
-			>
+				}
+			}}
+			onDragOver={(e) => {
+				e.preventDefault()
+			}}>
 				{showContextMenu && (
 					<div ref={contextMenuContainerRef}>
 						<ContextMenu
@@ -533,7 +598,7 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
 						borderTop: 0,
 						borderColor: "transparent",
 						borderBottom: `${thumbnailsHeight + 6}px solid transparent`,
-						padding: "9px 49px 3px 9px",
+						padding: "9px 9px 25px 9px",
 					}}
 				/>
 				<DynamicTextArea
@@ -588,11 +653,11 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
 						borderTop: 0,
 						borderBottom: `${thumbnailsHeight + 6}px solid transparent`,
 						borderColor: "transparent",
+						padding: "9px 9px 25px 9px",
 						// borderRight: "54px solid transparent",
 						// borderLeft: "9px solid transparent", // NOTE: react-textarea-autosize doesn't calculate correct height when using borderLeft/borderRight so we need to use horizontal padding instead
 						// Instead of using boxShadow, we use a div with a border to better replicate the behavior when the textarea is focused
 						// boxShadow: "0px 0px 0px 1px var(--vscode-input-border)",
-						padding: "9px 49px 3px 9px",
 						cursor: textAreaDisabled ? "not-allowed" : undefined,
 						flex: 1,
 						zIndex: 1,
@@ -609,45 +674,29 @@ const ChatTextArea = forwardRef<HTMLTextAreaElement, ChatTextAreaProps>(
 							paddingTop: 4,
 							bottom: 14,
 							left: 22,
-							right: 67, // (54 + 9) + 4 extra padding
+							right: 67,
 							zIndex: 2,
 						}}
 					/>
 				)}
-				<div
-					style={{
-						position: "absolute",
-						right: 28,
-						display: "flex",
-						alignItems: "flex-end",
-						height: textAreaBaseHeight || 31,
-						bottom: 18,
-						zIndex: 2,
-					}}>
-					<div style={{ display: "flex", flexDirection: "row", alignItems: "center" }}>
-						<div
-							className={`input-icon-button ${
-								shouldDisableImages ? "disabled" : ""
-							} codicon codicon-device-camera`}
-							onClick={() => {
-								if (!shouldDisableImages) {
-									onSelectImages()
-								}
-							}}
-							style={{
-								marginRight: 5.5,
-								fontSize: 16.5,
-							}}
-						/>
-						<div
-							className={`input-icon-button ${textAreaDisabled ? "disabled" : ""} codicon codicon-send`}
-							onClick={() => {
-								if (!textAreaDisabled) {
-									onSend()
-								}
-							}}
-							style={{ fontSize: 15 }}></div>
-					</div>
+				<div className="button-row" style={{ position: "absolute", right: 20, display: "flex", alignItems: "center", height: 31, bottom: 8, zIndex: 2, justifyContent: "flex-end" }}>
+				  <span style={{ display: "flex", alignItems: "center", gap: 12 }}>
+				    {apiConfiguration?.apiProvider === "openrouter" && (
+				      <div style={{ display: "flex", alignItems: "center" }}>
+				        {isEnhancingPrompt && <span style={{ marginRight: 10, color: "var(--vscode-input-foreground)", opacity: 0.5 }}>Enhancing prompt...</span>}
+				        <span
+				          role="button"
+				          aria-label="enhance prompt"
+				          data-testid="enhance-prompt-button"
+				          className={`input-icon-button ${textAreaDisabled ? "disabled" : ""} codicon codicon-sparkle`}
+				          onClick={() => !textAreaDisabled && handleEnhancePrompt()}
+				          style={{ fontSize: 16.5 }}
+				        />
+				      </div>
+				    )}
+				    <span className={`input-icon-button ${shouldDisableImages ? "disabled" : ""} codicon codicon-device-camera`} onClick={() => !shouldDisableImages && onSelectImages()} style={{ fontSize: 16.5 }} />
+				    <span className={`input-icon-button ${textAreaDisabled ? "disabled" : ""} codicon codicon-send`} onClick={() => !textAreaDisabled && onSend()} style={{ fontSize: 15 }} />
+				  </span>
 				</div>
 			</div>
 		)

+ 185 - 0
webview-ui/src/components/chat/__tests__/ChatTextArea.test.tsx

@@ -0,0 +1,185 @@
+/* eslint-disable import/first */
+import React from 'react';
+import { render, fireEvent, screen } from '@testing-library/react';
+import '@testing-library/jest-dom';
+import ChatTextArea from '../ChatTextArea';
+import { useExtensionState } from '../../../context/ExtensionStateContext';
+import { vscode } from '../../../utils/vscode';
+
+// Mock modules
+jest.mock('../../../utils/vscode', () => ({
+  vscode: {
+    postMessage: jest.fn()
+  }
+}));
+jest.mock('../../../components/common/CodeBlock');
+jest.mock('../../../components/common/MarkdownBlock');
+
+// Get the mocked postMessage function
+const mockPostMessage = vscode.postMessage as jest.Mock;
+/* eslint-enable import/first */
+
+// Mock ExtensionStateContext
+jest.mock('../../../context/ExtensionStateContext');
+
+describe('ChatTextArea', () => {
+  const defaultProps = {
+    inputValue: '',
+    setInputValue: jest.fn(),
+    onSend: jest.fn(),
+    textAreaDisabled: false,
+    onSelectImages: jest.fn(),
+    shouldDisableImages: false,
+    placeholderText: 'Type a message...',
+    selectedImages: [],
+    setSelectedImages: jest.fn(),
+    onHeightChange: jest.fn(),
+  };
+
+  beforeEach(() => {
+    jest.clearAllMocks();
+    // Default mock implementation for useExtensionState
+    (useExtensionState as jest.Mock).mockReturnValue({
+      filePaths: [],
+      apiConfiguration: {
+        apiProvider: 'anthropic',
+      },
+    });
+  });
+
+  describe('enhance prompt button', () => {
+    it('should show enhance prompt button only when apiProvider is openrouter', () => {
+      // Test with non-openrouter provider
+      (useExtensionState as jest.Mock).mockReturnValue({
+        filePaths: [],
+        apiConfiguration: {
+          apiProvider: 'anthropic',
+        },
+      });
+
+      const { rerender } = render(<ChatTextArea {...defaultProps} />);
+      expect(screen.queryByTestId('enhance-prompt-button')).not.toBeInTheDocument();
+
+      // Test with openrouter provider
+      (useExtensionState as jest.Mock).mockReturnValue({
+        filePaths: [],
+        apiConfiguration: {
+          apiProvider: 'openrouter',
+        },
+      });
+
+      rerender(<ChatTextArea {...defaultProps} />);
+      const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i });
+      expect(enhanceButton).toBeInTheDocument();
+    });
+
+    it('should be disabled when textAreaDisabled is true', () => {
+      (useExtensionState as jest.Mock).mockReturnValue({
+        filePaths: [],
+        apiConfiguration: {
+          apiProvider: 'openrouter',
+        },
+      });
+
+      render(<ChatTextArea {...defaultProps} textAreaDisabled={true} />);
+      const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i });
+      expect(enhanceButton).toHaveClass('disabled');
+    });
+  });
+
+  describe('handleEnhancePrompt', () => {
+    it('should send message with correct configuration when clicked', () => {
+      const apiConfiguration = {
+        apiProvider: 'openrouter',
+        apiKey: 'test-key',
+      };
+
+      (useExtensionState as jest.Mock).mockReturnValue({
+        filePaths: [],
+        apiConfiguration,
+      });
+
+      render(<ChatTextArea {...defaultProps} inputValue="Test prompt" />);
+      
+      const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i });
+      fireEvent.click(enhanceButton);
+
+      expect(mockPostMessage).toHaveBeenCalledWith({
+        type: 'enhancePrompt',
+        text: 'Test prompt',
+      });
+    });
+
+    it('should not send message when input is empty', () => {
+      (useExtensionState as jest.Mock).mockReturnValue({
+        filePaths: [],
+        apiConfiguration: {
+          apiProvider: 'openrouter',
+        },
+      });
+
+      render(<ChatTextArea {...defaultProps} inputValue="" />);
+      
+      const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i });
+      fireEvent.click(enhanceButton);
+
+      expect(mockPostMessage).not.toHaveBeenCalled();
+    });
+
+    it('should show loading state while enhancing', () => {
+      (useExtensionState as jest.Mock).mockReturnValue({
+        filePaths: [],
+        apiConfiguration: {
+          apiProvider: 'openrouter',
+        },
+      });
+
+      render(<ChatTextArea {...defaultProps} inputValue="Test prompt" />);
+      
+      const enhanceButton = screen.getByRole('button', { name: /enhance prompt/i });
+      fireEvent.click(enhanceButton);
+
+      expect(screen.getByText('Enhancing prompt...')).toBeInTheDocument();
+    });
+  });
+
+  describe('effect dependencies', () => {
+    it('should update when apiConfiguration changes', () => {
+      const { rerender } = render(<ChatTextArea {...defaultProps} />);
+
+      // Update apiConfiguration
+      (useExtensionState as jest.Mock).mockReturnValue({
+        filePaths: [],
+        apiConfiguration: {
+          apiProvider: 'openrouter',
+          newSetting: 'test',
+        },
+      });
+
+      rerender(<ChatTextArea {...defaultProps} />);
+      
+      // Verify the enhance button appears after apiConfiguration changes
+      expect(screen.getByRole('button', { name: /enhance prompt/i })).toBeInTheDocument();
+    });
+  });
+
+  describe('enhanced prompt response', () => {
+    it('should update input value when receiving enhanced prompt', () => {
+      const setInputValue = jest.fn();
+      
+      render(<ChatTextArea {...defaultProps} setInputValue={setInputValue} />);
+
+      // Simulate receiving enhanced prompt message
+      window.dispatchEvent(
+        new MessageEvent('message', {
+          data: {
+            type: 'enhancedPrompt',
+            text: 'Enhanced test prompt',
+          },
+        })
+      );
+
+      expect(setInputValue).toHaveBeenCalledWith('Enhanced test prompt');
+    });
+  });
+});

+ 12 - 0
webview-ui/src/components/common/__mocks__/CodeBlock.tsx

@@ -0,0 +1,12 @@
+import * as React from 'react';
+
+interface CodeBlockProps {
+  children?: React.ReactNode;
+  language?: string;
+}
+
+const CodeBlock: React.FC<CodeBlockProps> = () => (
+  <div data-testid="mock-code-block">Mocked Code Block</div>
+);
+
+export default CodeBlock;

+ 12 - 0
webview-ui/src/components/common/__mocks__/MarkdownBlock.tsx

@@ -0,0 +1,12 @@
+import * as React from 'react';
+
+interface MarkdownBlockProps {
+  children?: React.ReactNode;
+  content?: string;
+}
+
+const MarkdownBlock: React.FC<MarkdownBlockProps> = ({ content }) => (
+  <div data-testid="mock-markdown-block">{content}</div>
+);
+
+export default MarkdownBlock;