Просмотр исходного кода

feat: auto-switch to imported mode with architect fallback (#9003)

Co-authored-by: roomote[bot] <219738659+roomote[bot]@users.noreply.github.com>
Co-authored-by: Seth Miller <[email protected]>
Co-authored-by: heyseth <[email protected]>
Co-authored-by: Roo Code <[email protected]>
Daniel 1 месяц назад
Родитель
Сommit
b17760bf57

+ 5 - 2
src/core/config/CustomModesManager.ts

@@ -40,6 +40,7 @@ interface ExportResult {
 
 interface ImportResult {
 	success: boolean
+	slug?: string
 	error?: string
 }
 
@@ -411,7 +412,7 @@ export class CustomModesManager {
 				const errorMessage = `Invalid mode configuration: ${errorMessages}`
 				logger.error("Mode validation failed", { slug, errors: validationResult.error.errors })
 				vscode.window.showErrorMessage(t("common:customModes.errors.updateFailed", { error: errorMessage }))
-				return
+				throw new Error(errorMessage)
 			}
 
 			const isProjectMode = config.source === "project"
@@ -457,6 +458,7 @@ export class CustomModesManager {
 			const errorMessage = error instanceof Error ? error.message : String(error)
 			logger.error("Failed to update custom mode", { slug, error: errorMessage })
 			vscode.window.showErrorMessage(t("common:customModes.errors.updateFailed", { error: errorMessage }))
+			throw error
 		}
 	}
 
@@ -989,7 +991,8 @@ export class CustomModesManager {
 			// Refresh the modes after import
 			await this.refreshMergedState()
 
-			return { success: true }
+			// Return the imported mode's slug so the UI can activate it
+			return { success: true, slug: importData.customModes[0]?.slug }
 		} catch (error) {
 			const errorMessage = error instanceof Error ? error.message : String(error)
 			logger.error("Failed to import mode with rules", { error: errorMessage })

+ 38 - 32
src/core/webview/webviewMessageHandler.ts

@@ -2044,40 +2044,45 @@ export const webviewMessageHandler = async (
 			break
 		case "updateCustomMode":
 			if (message.modeConfig) {
-				// Check if this is a new mode or an update to an existing mode
-				const existingModes = await provider.customModesManager.getCustomModes()
-				const isNewMode = !existingModes.some((mode) => mode.slug === message.modeConfig?.slug)
-
-				await provider.customModesManager.updateCustomMode(message.modeConfig.slug, message.modeConfig)
-				// Update state after saving the mode
-				const customModes = await provider.customModesManager.getCustomModes()
-				await updateGlobalState("customModes", customModes)
-				await updateGlobalState("mode", message.modeConfig.slug)
-				await provider.postStateToWebview()
-
-				// Track telemetry for custom mode creation or update
-				if (TelemetryService.hasInstance()) {
-					if (isNewMode) {
-						// This is a new custom mode
-						TelemetryService.instance.captureCustomModeCreated(
-							message.modeConfig.slug,
-							message.modeConfig.name,
-						)
-					} else {
-						// Determine which setting was changed by comparing objects
-						const existingMode = existingModes.find((mode) => mode.slug === message.modeConfig?.slug)
-						const changedSettings = existingMode
-							? Object.keys(message.modeConfig).filter(
-									(key) =>
-										JSON.stringify((existingMode as Record<string, unknown>)[key]) !==
-										JSON.stringify((message.modeConfig as Record<string, unknown>)[key]),
-								)
-							: []
+				try {
+					// Check if this is a new mode or an update to an existing mode
+					const existingModes = await provider.customModesManager.getCustomModes()
+					const isNewMode = !existingModes.some((mode) => mode.slug === message.modeConfig?.slug)
+
+					await provider.customModesManager.updateCustomMode(message.modeConfig.slug, message.modeConfig)
+					// Update state after saving the mode
+					const customModes = await provider.customModesManager.getCustomModes()
+					await updateGlobalState("customModes", customModes)
+					await updateGlobalState("mode", message.modeConfig.slug)
+					await provider.postStateToWebview()
 
-						if (changedSettings.length > 0) {
-							TelemetryService.instance.captureModeSettingChanged(changedSettings[0])
+					// Track telemetry for custom mode creation or update
+					if (TelemetryService.hasInstance()) {
+						if (isNewMode) {
+							// This is a new custom mode
+							TelemetryService.instance.captureCustomModeCreated(
+								message.modeConfig.slug,
+								message.modeConfig.name,
+							)
+						} else {
+							// Determine which setting was changed by comparing objects
+							const existingMode = existingModes.find((mode) => mode.slug === message.modeConfig?.slug)
+							const changedSettings = existingMode
+								? Object.keys(message.modeConfig).filter(
+										(key) =>
+											JSON.stringify((existingMode as Record<string, unknown>)[key]) !==
+											JSON.stringify((message.modeConfig as Record<string, unknown>)[key]),
+									)
+								: []
+
+							if (changedSettings.length > 0) {
+								TelemetryService.instance.captureModeSettingChanged(changedSettings[0])
+							}
 						}
 					}
+				} catch (error) {
+					// Error already shown to user by updateCustomMode
+					// Just prevent unhandled rejection and skip state updates
 				}
 			}
 			break
@@ -2285,10 +2290,11 @@ export const webviewMessageHandler = async (
 						await updateGlobalState("customModes", customModes)
 						await provider.postStateToWebview()
 
-						// Send success message to webview
+						// Send success message to webview, include the imported slug so UI can switch
 						provider.postMessageToWebview({
 							type: "importModeResult",
 							success: true,
+							slug: result.slug,
 						})
 
 						// Show success message

+ 41 - 1
webview-ui/src/components/modes/ModesView.tsx

@@ -20,6 +20,7 @@ import {
 	getCustomInstructions,
 	getAllModes,
 	findModeBySlug as findCustomModeBySlug,
+	defaultModeSlug,
 } from "@roo/modes"
 import { TOOL_GROUPS } from "@roo/tools"
 
@@ -55,6 +56,8 @@ const availableGroups = (Object.keys(TOOL_GROUPS) as ToolGroup[]).filter((group)
 
 type ModeSource = "global" | "project"
 
+type ImportModeResult = { type: "importModeResult"; success: boolean; slug?: string; error?: string }
+
 type ModesViewProps = {
 	onDone: () => void
 }
@@ -186,6 +189,29 @@ const ModesView = ({ onDone }: ModesViewProps) => {
 		[visualMode, switchMode],
 	)
 
+	// Refs to track latest state/functions for message handler (which has no dependencies)
+	const handleModeSwitchRef = useRef(handleModeSwitch)
+	const customModesRef = useRef(customModes)
+	const switchModeRef = useRef(switchMode)
+
+	// Update refs when dependencies change
+	useEffect(() => {
+		handleModeSwitchRef.current = handleModeSwitch
+	}, [handleModeSwitch])
+
+	useEffect(() => {
+		customModesRef.current = customModes
+	}, [customModes])
+
+	useEffect(() => {
+		switchModeRef.current = switchMode
+	}, [switchMode])
+
+	// Sync visualMode with backend mode changes to prevent desync
+	useEffect(() => {
+		setVisualMode(mode)
+	}, [mode])
+
 	// Handler for popover open state change
 	const onOpenChange = useCallback((open: boolean) => {
 		setOpen(open)
@@ -460,7 +486,21 @@ const ModesView = ({ onDone }: ModesViewProps) => {
 				setIsImporting(false)
 				setShowImportDialog(false)
 
-				if (!message.success) {
+				if (message.success) {
+					const { slug } = message as ImportModeResult
+					if (slug) {
+						// Try switching using the freshest mode list available
+						const all = getAllModes(customModesRef.current)
+						const importedMode = all.find((m) => m.slug === slug)
+						if (importedMode) {
+							handleModeSwitchRef.current(importedMode)
+						} else {
+							// Fallback: slug not yet in state (race condition) - select default mode
+							setVisualMode(defaultModeSlug)
+							switchModeRef.current?.(defaultModeSlug)
+						}
+					}
+				} else {
 					// Only log error if it's not a cancellation
 					if (message.error !== "cancelled") {
 						console.error("Failed to import mode:", message.error)

+ 156 - 0
webview-ui/src/components/modes/__tests__/ModesView.import-switch.spec.tsx

@@ -0,0 +1,156 @@
+// npx vitest src/components/modes/__tests__/ModesView.import-switch.spec.tsx
+
+import { render, waitFor } from "@/utils/test-utils"
+import ModesView from "../ModesView"
+import { ExtensionStateContext } from "@src/context/ExtensionStateContext"
+import { vscode } from "@src/utils/vscode"
+import { defaultModeSlug } from "@roo/modes"
+
+// Mock vscode API
+vitest.mock("@src/utils/vscode", () => ({
+	vscode: {
+		postMessage: vitest.fn(),
+	},
+}))
+
+const mockExtensionState = {
+	customModePrompts: {},
+	listApiConfigMeta: [
+		{ id: "config1", name: "Config 1" },
+		{ id: "config2", name: "Config 2" },
+	],
+	enhancementApiConfigId: "",
+	setEnhancementApiConfigId: vitest.fn(),
+	mode: "code",
+	customModes: [],
+	customSupportPrompts: [],
+	currentApiConfigName: "",
+	customInstructions: "",
+	setCustomInstructions: vitest.fn(),
+}
+
+const renderModesView = (props = {}) => {
+	const mockOnDone = vitest.fn()
+	return render(
+		<ExtensionStateContext.Provider value={{ ...mockExtensionState, ...props } as any}>
+			<ModesView onDone={mockOnDone} />
+		</ExtensionStateContext.Provider>,
+	)
+}
+
+Element.prototype.scrollIntoView = vitest.fn()
+
+describe("ModesView Import Auto-Switch", () => {
+	beforeEach(() => {
+		vitest.clearAllMocks()
+	})
+
+	it("should auto-switch to imported mode when found in current state", async () => {
+		const importedModeSlug = "custom-test-mode"
+		const customModes = [
+			{
+				slug: importedModeSlug,
+				name: "Custom Test Mode",
+				roleDefinition: "Test role",
+				groups: [],
+			},
+		]
+
+		renderModesView({ customModes })
+
+		// Simulate successful import message with the mode already in state
+		const importMessage = {
+			data: {
+				type: "importModeResult",
+				success: true,
+				slug: importedModeSlug,
+			},
+		}
+
+		window.dispatchEvent(new MessageEvent("message", importMessage))
+
+		// Wait for the mode switch message to be sent
+		await waitFor(() => {
+			expect(vscode.postMessage).toHaveBeenCalledWith({
+				type: "mode",
+				text: importedModeSlug,
+			})
+		})
+	})
+
+	it("should fallback to architect mode when imported slug not yet in state (race condition)", async () => {
+		const importedModeSlug = "custom-new-mode"
+
+		// Render without the imported mode in customModes (simulating race condition)
+		renderModesView({ customModes: [] })
+
+		// Simulate successful import message but mode not yet in state
+		const importMessage = {
+			data: {
+				type: "importModeResult",
+				success: true,
+				slug: importedModeSlug,
+			},
+		}
+
+		window.dispatchEvent(new MessageEvent("message", importMessage))
+
+		// Wait for the fallback to default mode (architect)
+		await waitFor(() => {
+			expect(vscode.postMessage).toHaveBeenCalledWith({
+				type: "mode",
+				text: defaultModeSlug,
+			})
+		})
+	})
+
+	it("should not switch modes on import failure", async () => {
+		renderModesView()
+
+		// Simulate failed import message
+		const importMessage = {
+			data: {
+				type: "importModeResult",
+				success: false,
+				error: "Import failed",
+			},
+		}
+
+		window.dispatchEvent(new MessageEvent("message", importMessage))
+
+		// Wait a bit to ensure no mode switch happens
+		await new Promise((resolve) => setTimeout(resolve, 100))
+
+		// Verify no mode switch message was sent
+		expect(vscode.postMessage).not.toHaveBeenCalledWith(
+			expect.objectContaining({
+				type: "mode",
+			}),
+		)
+	})
+
+	it("should not switch modes on cancelled import", async () => {
+		renderModesView()
+
+		// Simulate cancelled import message
+		const importMessage = {
+			data: {
+				type: "importModeResult",
+				success: false,
+				error: "cancelled",
+			},
+		}
+
+		window.dispatchEvent(new MessageEvent("message", importMessage))
+
+		// Wait a bit to ensure no mode switch happens
+		await new Promise((resolve) => setTimeout(resolve, 100))
+
+		// Verify no mode switch message was sent
+		expect(vscode.postMessage).not.toHaveBeenCalledWith(
+			expect.objectContaining({
+				type: "mode",
+			}),
+		)
+	})
+})