Ver código fonte

fix: prevent UI flicker and enable resumption after task cancellation (#8986)

Daniel 2 meses atrás
pai
commit
58edc71672

+ 100 - 7
src/core/task/Task.ts

@@ -737,7 +737,8 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 		// deallocated. (Although we set Cline = undefined in provider, that
 		// simply removes the reference to this instance, but the instance is
 		// still alive until this promise resolves or rejects.)
-		if (this.abort) {
+		// Exception: Allow resume asks even when aborted for soft-interrupt UX
+		if (this.abort && type !== "resume_task" && type !== "resume_completed_task") {
 			throw new Error(`[RooCode#ask] task ${this.taskId}.${this.instanceId} aborted`)
 		}
 
@@ -1255,7 +1256,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 		])
 	}
 
-	private async resumeTaskFromHistory() {
+	public async resumeTaskFromHistory() {
 		if (this.enableBridge) {
 			try {
 				await BridgeOrchestrator.subscribeToTask(this)
@@ -1347,6 +1348,13 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 
 		const { response, text, images } = await this.ask(askType) // Calls `postStateToWebview`.
 
+		// Reset abort flags AFTER user responds to resume ask.
+		// This is critical for the cancel → resume flow: when a task is soft-aborted
+		// (abandoned = false), we keep the instance alive but set abort = true.
+		// We only clear these flags after the user confirms they want to resume,
+		// preventing the old stream from continuing if abort was set.
+		this.resetAbortAndStreamingState()
+
 		let responseText: string | undefined
 		let responseImages: string[] | undefined
 
@@ -1525,6 +1533,86 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 		await this.initiateTaskLoop(newUserContent)
 	}
 
+	/**
+	 * Resets abort flags and streaming state to allow task resumption.
+	 * Centralizes the state reset logic used after user confirms task resumption.
+	 *
+	 * @private
+	 */
+	private resetAbortAndStreamingState(): void {
+		this.abort = false
+		this.abandoned = false
+		this.abortReason = undefined
+		this.didFinishAbortingStream = false
+		this.isStreaming = false
+
+		// Reset streaming-local fields to avoid stale state from previous stream
+		this.currentStreamingContentIndex = 0
+		this.currentStreamingDidCheckpoint = false
+		this.assistantMessageContent = []
+		this.didCompleteReadingStream = false
+		this.userMessageContent = []
+		this.userMessageContentReady = false
+		this.didRejectTool = false
+		this.didAlreadyUseTool = false
+		this.presentAssistantMessageLocked = false
+		this.presentAssistantMessageHasPendingUpdates = false
+		this.assistantMessageParser.reset()
+	}
+
+	/**
+	 * Present a resumable ask on an aborted task without rehydrating.
+	 * Used by soft-interrupt (cancelTask) to show Resume/Terminate UI.
+	 * Selects the appropriate ask type based on the last relevant message.
+	 * If the user clicks Resume, resets abort flags and continues the task loop.
+	 */
+	public async presentResumableAsk(): Promise<void> {
+		const lastClineMessage = this.clineMessages
+			.slice()
+			.reverse()
+			.find((m) => !(m.ask === "resume_task" || m.ask === "resume_completed_task"))
+
+		let askType: ClineAsk
+		if (lastClineMessage?.ask === "completion_result") {
+			askType = "resume_completed_task"
+		} else {
+			askType = "resume_task"
+		}
+
+		const { response, text, images } = await this.ask(askType)
+
+		// If user clicked Resume (not Terminate), reset abort flags and continue
+		if (response === "yesButtonClicked" || response === "messageResponse") {
+			// Reset abort flags to allow the loop to continue
+			this.resetAbortAndStreamingState()
+
+			// Prepare content for resuming the task loop
+			let userContent: Anthropic.Messages.ContentBlockParam[] = []
+
+			if (response === "messageResponse" && text) {
+				// User provided additional instructions
+				await this.say("user_feedback", text, images)
+				userContent.push({
+					type: "text",
+					text: `\n\nNew instructions for task continuation:\n<user_message>\n${text}\n</user_message>`,
+				})
+				if (images && images.length > 0) {
+					userContent.push(...formatResponse.imageBlocks(images))
+				}
+			} else {
+				// Simple resume with no new instructions
+				userContent.push({
+					type: "text",
+					text: "[TASK RESUMPTION] Resuming task...",
+				})
+			}
+
+			// Continue the task loop
+			await this.initiateTaskLoop(userContent)
+		}
+		// If user clicked Terminate (noButtonClicked), do nothing - task stays aborted
+	}
+
 	public async abortTask(isAbandoned = false) {
 		// Aborting task
 
@@ -1536,12 +1624,17 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 		this.abort = true
 		this.emit(RooCodeEventName.TaskAborted)
 
-		try {
-			this.dispose() // Call the centralized dispose method
-		} catch (error) {
-			console.error(`Error during task ${this.taskId}.${this.instanceId} disposal:`, error)
-			// Don't rethrow - we want abort to always succeed
+		// Only dispose if this is a hard abort (abandoned)
+		// For soft abort (user cancel), keep the instance alive so we can present a resumable ask
+		if (isAbandoned) {
+			try {
+				this.dispose() // Call the centralized dispose method
+			} catch (error) {
+				console.error(`Error during task ${this.taskId}.${this.instanceId} disposal:`, error)
+				// Don't rethrow - we want abort to always succeed
+			}
 		}
+
 		// Save the countdown message in the automatic retry or other content.
 		try {
 			// Save the countdown message in the automatic retry or other content.

+ 146 - 0
src/core/task/__tests__/Task.presentResumableAsk.abort-reset.spec.ts

@@ -0,0 +1,146 @@
+import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"
+import { ProviderSettings } from "@roo-code/types"
+
+import { Task } from "../Task"
+import { ClineProvider } from "../../webview/ClineProvider"
+
+// Mocks similar to Task.dispose.test.ts
+vi.mock("../../webview/ClineProvider")
+vi.mock("../../../integrations/terminal/TerminalRegistry", () => ({
+	TerminalRegistry: {
+		releaseTerminalsForTask: vi.fn(),
+	},
+}))
+vi.mock("../../ignore/RooIgnoreController")
+vi.mock("../../protect/RooProtectedController")
+vi.mock("../../context-tracking/FileContextTracker")
+vi.mock("../../../services/browser/UrlContentFetcher")
+vi.mock("../../../services/browser/BrowserSession")
+vi.mock("../../../integrations/editor/DiffViewProvider")
+vi.mock("../../tools/ToolRepetitionDetector")
+vi.mock("../../../api", () => ({
+	buildApiHandler: vi.fn(() => ({
+		getModel: () => ({ info: {}, id: "test-model" }),
+	})),
+}))
+vi.mock("../AutoApprovalHandler")
+
+// Mock TelemetryService
+vi.mock("@roo-code/telemetry", () => ({
+	TelemetryService: {
+		instance: {
+			captureTaskCreated: vi.fn(),
+			captureTaskRestarted: vi.fn(),
+			captureConversationMessage: vi.fn(),
+		},
+	},
+}))
+
+describe("Task.presentResumableAsk abort reset", () => {
+	let mockProvider: any
+	let mockApiConfiguration: ProviderSettings
+	let task: Task
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+
+		mockProvider = {
+			context: {
+				globalStorageUri: { fsPath: "/test/path" },
+			},
+			getState: vi.fn().mockResolvedValue({ mode: "code" }),
+			postStateToWebview: vi.fn().mockResolvedValue(undefined),
+			postMessageToWebview: vi.fn().mockResolvedValue(undefined),
+			updateTaskHistory: vi.fn().mockResolvedValue(undefined),
+			log: vi.fn(),
+		}
+
+		mockApiConfiguration = {
+			apiProvider: "anthropic",
+			apiKey: "test-key",
+		} as ProviderSettings
+
+		task = new Task({
+			provider: mockProvider as ClineProvider,
+			apiConfiguration: mockApiConfiguration,
+			startTask: false,
+		})
+	})
+
+	afterEach(() => {
+		// Ensure we don't leave event listeners dangling
+		task.dispose()
+	})
+
+	it("resets abort flags and continues the loop on yesButtonClicked", async () => {
+		// Arrange aborted state
+		task.abort = true
+		task.abortReason = "user_cancelled"
+		task.didFinishAbortingStream = true
+		task.isStreaming = true
+
+		// minimal message history
+		task.clineMessages = [{ ts: Date.now() - 1000, type: "say", say: "text", text: "prev" } as any]
+
+		// Spy and stub ask + loop
+		const askSpy = vi.spyOn(task as any, "ask").mockResolvedValue({ response: "yesButtonClicked" })
+		const loopSpy = vi.spyOn(task as any, "initiateTaskLoop").mockResolvedValue(undefined)
+
+		// Act
+		await task.presentResumableAsk()
+
+		// Assert ask was presented
+		expect(askSpy).toHaveBeenCalled()
+
+		// Abort flags cleared
+		expect(task.abort).toBe(false)
+		expect(task.abandoned).toBe(false)
+		expect(task.abortReason).toBeUndefined()
+		expect(task.didFinishAbortingStream).toBe(false)
+		expect(task.isStreaming).toBe(false)
+
+		// Streaming-local state cleared
+		expect(task.currentStreamingContentIndex).toBe(0)
+		expect(task.assistantMessageContent).toEqual([])
+		expect(task.userMessageContentReady).toBe(false)
+		expect(task.didRejectTool).toBe(false)
+		expect(task.presentAssistantMessageLocked).toBe(false)
+
+		// Loop resumed
+		expect(loopSpy).toHaveBeenCalledTimes(1)
+	})
+
+	it("includes user feedback when resuming with messageResponse", async () => {
+		task.abort = true
+		task.clineMessages = [{ ts: Date.now() - 1000, type: "say", say: "text", text: "prev" } as any]
+
+		const askSpy = vi
+			.spyOn(task as any, "ask")
+			.mockResolvedValue({ response: "messageResponse", text: "Continue with this", images: undefined })
+		const saySpy = vi.spyOn(task, "say").mockResolvedValue(undefined as any)
+		const loopSpy = vi.spyOn(task as any, "initiateTaskLoop").mockResolvedValue(undefined)
+
+		await task.presentResumableAsk()
+
+		expect(askSpy).toHaveBeenCalled()
+		expect(saySpy).toHaveBeenCalledWith("user_feedback", "Continue with this", undefined)
+		expect(loopSpy).toHaveBeenCalledTimes(1)
+	})
+
+	it("does nothing when user clicks Terminate (noButtonClicked)", async () => {
+		task.abort = true
+		task.abortReason = "user_cancelled"
+		task.clineMessages = [{ ts: Date.now() - 1000, type: "say", say: "text", text: "prev" } as any]
+
+		vi.spyOn(task as any, "ask").mockResolvedValue({ response: "noButtonClicked" })
+		const loopSpy = vi.spyOn(task as any, "initiateTaskLoop").mockResolvedValue(undefined)
+
+		await task.presentResumableAsk()
+
+		// Still aborted
+		expect(task.abort).toBe(true)
+		expect(task.abortReason).toBe("user_cancelled")
+		// No loop resume
+		expect(loopSpy).not.toHaveBeenCalled()
+	})
+})

+ 5 - 5
src/core/task/__tests__/Task.spec.ts

@@ -1722,12 +1722,12 @@ describe("Cline", () => {
 			// Mock the dispose method to track cleanup
 			const disposeSpy = vi.spyOn(task, "dispose").mockImplementation(() => {})
 
-			// Call abortTask
+			// Call abortTask (soft cancel - same path as UI Cancel button)
 			await task.abortTask()
 
-			// Verify the same behavior as Cancel button
+			// Verify the same behavior as Cancel button: soft abort sets abort flag but does not dispose
 			expect(task.abort).toBe(true)
-			expect(disposeSpy).toHaveBeenCalled()
+			expect(disposeSpy).not.toHaveBeenCalled()
 		})
 
 		it("should work with TaskLike interface", async () => {
@@ -1771,8 +1771,8 @@ describe("Cline", () => {
 			// Spy on console.error to verify error is logged
 			const consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {})
 
-			// abortTask should not throw even if dispose fails
-			await expect(task.abortTask()).resolves.not.toThrow()
+			// abortTask should not throw even if dispose fails (hard abort triggers dispose)
+			await expect(task.abortTask(true)).resolves.not.toThrow()
 
 			// Verify error was logged
 			expect(consoleErrorSpy).toHaveBeenCalledWith(expect.stringContaining("Error during task"), mockError)

+ 68 - 32
src/core/webview/ClineProvider.ts

@@ -50,6 +50,7 @@ import { Package } from "../../shared/package"
 import { findLast } from "../../shared/array"
 import { supportPrompt } from "../../shared/support-prompt"
 import { GlobalFileNames } from "../../shared/globalFileNames"
+import { safeJsonParse } from "../../shared/safeJsonParse"
 import type { ExtensionMessage, ExtensionState, MarketplaceInstalledMetadata } from "../../shared/ExtensionMessage"
 import { Mode, defaultModeSlug, getModeBySlug } from "../../shared/modes"
 import { experimentDefault } from "../../shared/experiments"
@@ -144,6 +145,10 @@ export class ClineProvider
 	private pendingOperations: Map<string, PendingEditOperation> = new Map()
 	private static readonly PENDING_OPERATION_TIMEOUT_MS = 30000 // 30 seconds
 
+	// Transactional state posting
+	private uiUpdatePaused: boolean = false
+	private pendingState: ExtensionState | null = null
+
 	public isViewLaunched = false
 	public settingsImportedAt?: number
 	public readonly latestAnnouncementId = "oct-2025-v3.29.0-cloud-agents" // v3.29.0 Cloud Agents announcement
@@ -1624,8 +1629,26 @@ export class ClineProvider
 		await this.postStateToWebview()
 	}
 
+	public beginStateTransaction(): void {
+		this.uiUpdatePaused = true
+	}
+
+	public async endStateTransaction(): Promise<void> {
+		this.uiUpdatePaused = false
+		if (this.pendingState) {
+			await this.view?.webview.postMessage({ type: "state", state: this.pendingState })
+			this.pendingState = null
+		}
+	}
+
 	async postStateToWebview() {
 		const state = await this.getStateToPostToWebview()
+
+		if (this.uiUpdatePaused) {
+			this.pendingState = state
+			return
+		}
+
 		this.postMessageToWebview({ type: "state", state })
 
 		// Check MDM compliance and send user to account tab if not compliant
@@ -2617,24 +2640,13 @@ export class ClineProvider
 
 		console.log(`[cancelTask] cancelling task ${task.taskId}.${task.instanceId}`)
 
-		const { historyItem, uiMessagesFilePath } = await this.getTaskWithId(task.taskId)
-
-		// Preserve parent and root task information for history item.
-		const rootTask = task.rootTask
-		const parentTask = task.parentTask
-
-		// Mark this as a user-initiated cancellation so provider-only rehydration can occur
+		// Mark this as a user-initiated cancellation
 		task.abortReason = "user_cancelled"
 
-		// Capture the current instance to detect if rehydrate already occurred elsewhere
-		const originalInstanceId = task.instanceId
-
-		// Begin abort (non-blocking)
-		task.abortTask()
-
-		// Immediately mark the original instance as abandoned to prevent any residual activity
-		task.abandoned = true
+		// Soft abort (isAbandoned = false) to keep the instance alive
+		await task.abortTask(false)
 
+		// Wait for abort to complete
 		await pWaitFor(
 			() =>
 				this.getCurrentTask()! === undefined ||
@@ -2651,28 +2663,52 @@ export class ClineProvider
 			console.error("Failed to abort task")
 		})
 
-		// Defensive safeguard: if current instance already changed, skip rehydrate
-		const current = this.getCurrentTask()
-		if (current && current.instanceId !== originalInstanceId) {
-			this.log(
-				`[cancelTask] Skipping rehydrate: current instance ${current.instanceId} != original ${originalInstanceId}`,
-			)
-			return
-		}
+		// Deterministic spinner stop: If the last api_req_started has no cost and no cancelReason,
+		// inject cancelReason to stop the spinner
+		try {
+			let lastApiReqStartedIndex = -1
+			for (let i = task.clineMessages.length - 1; i >= 0; i--) {
+				if (task.clineMessages[i].type === "say" && task.clineMessages[i].say === "api_req_started") {
+					lastApiReqStartedIndex = i
+					break
+				}
+			}
 
-		// Final race check before rehydrate to avoid duplicate rehydration
-		{
-			const currentAfterCheck = this.getCurrentTask()
-			if (currentAfterCheck && currentAfterCheck.instanceId !== originalInstanceId) {
-				this.log(
-					`[cancelTask] Skipping rehydrate after final check: current instance ${currentAfterCheck.instanceId} != original ${originalInstanceId}`,
+			if (lastApiReqStartedIndex !== -1) {
+				const lastApiReqStarted = task.clineMessages[lastApiReqStartedIndex]
+				const apiReqInfo = safeJsonParse<{ cost?: number; cancelReason?: string }>(
+					lastApiReqStarted.text || "{}",
+					{},
 				)
-				return
+
+				if (apiReqInfo && apiReqInfo.cost === undefined && apiReqInfo.cancelReason === undefined) {
+					apiReqInfo.cancelReason = "user_cancelled"
+					lastApiReqStarted.text = JSON.stringify(apiReqInfo)
+					await task.overwriteClineMessages([...task.clineMessages])
+					console.log(`[cancelTask] Injected cancelReason for deterministic spinner stop`)
+				}
 			}
+		} catch (error) {
+			console.error(`[cancelTask] Failed to inject cancelReason:`, error)
 		}
 
-		// Clears task again, so we need to abortTask manually above.
-		await this.createTaskWithHistoryItem({ ...historyItem, rootTask, parentTask })
+		// Update UI immediately to reflect current state
+		await this.postStateToWebview()
+
+		// Schedule non-blocking resumption to present "Resume Task" ask
+		// Use setImmediate to avoid blocking the webview handler
+		setImmediate(() => {
+			if (task && !task.abandoned) {
+				// Present a resume ask without rehydrating - just show the Resume/Terminate UI
+				task.presentResumableAsk().catch((error) => {
+					console.error(
+						`[cancelTask] Failed to present resume ask: ${
+							error instanceof Error ? error.message : String(error)
+						}`,
+					)
+				})
+			}
+		})
 	}
 
 	// Clear the current task without treating it as a subtask.

+ 60 - 0
src/core/webview/__tests__/ClineProvider.cancelTask.present-ask.spec.ts

@@ -0,0 +1,60 @@
+import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"
+import { ClineProvider } from "../ClineProvider"
+
+describe("ClineProvider.cancelTask - schedules presentResumableAsk", () => {
+	let provider: ClineProvider
+	let mockTask: any
+
+	beforeEach(() => {
+		vi.useFakeTimers()
+		// Create a bare instance without running constructor
+		provider = Object.create(ClineProvider.prototype) as ClineProvider
+
+		mockTask = {
+			taskId: "task-1",
+			instanceId: "inst-1",
+			abortReason: undefined,
+			abandoned: false,
+			abortTask: vi.fn().mockResolvedValue(undefined),
+			isStreaming: false,
+			didFinishAbortingStream: true,
+			isWaitingForFirstChunk: false,
+			// Last api_req_started without cost/cancelReason so provider injects cancelReason
+			clineMessages: [
+				{ ts: Date.now() - 2000, type: "say", say: "text", text: "hello" },
+				{ ts: Date.now() - 1000, type: "say", say: "api_req_started", text: JSON.stringify({}) },
+			],
+			overwriteClineMessages: vi.fn().mockResolvedValue(undefined),
+			presentResumableAsk: vi.fn().mockResolvedValue(undefined),
+		}
+
+		// Patch required instance methods used by cancelTask
+		;(provider as any).getCurrentTask = vi.fn().mockReturnValue(mockTask)
+		;(provider as any).postStateToWebview = vi.fn().mockResolvedValue(undefined)
+	})
+
+	afterEach(() => {
+		vi.useRealTimers()
+	})
+
+	it("injects cancelReason and schedules presentResumableAsk on soft cancel", async () => {
+		// Act
+		await (provider as any).cancelTask()
+
+		// Assert that abort was initiated
+		expect(mockTask.abortTask).toHaveBeenCalledWith(false)
+
+		// cancelReason injected for spinner stop
+		const last = mockTask.clineMessages.at(-1)
+		expect(last.say).toBe("api_req_started")
+		const parsed = JSON.parse(last.text || "{}")
+		expect(parsed.cancelReason).toBe("user_cancelled")
+
+		// presentResumableAsk is scheduled via setImmediate
+		expect(mockTask.presentResumableAsk).not.toHaveBeenCalled()
+		vi.runAllTimers()
+		// run microtasks as well to flush any pending promises in the scheduled callback
+		await Promise.resolve()
+		expect(mockTask.presentResumableAsk).toHaveBeenCalledTimes(1)
+	})
+})

+ 17 - 9
src/core/webview/webviewMessageHandler.ts

@@ -1033,18 +1033,26 @@ export const webviewMessageHandler = async (
 			const result = checkoutRestorePayloadSchema.safeParse(message.payload)
 
 			if (result.success) {
-				await provider.cancelTask()
+				// Begin transaction to buffer state updates
+				provider.beginStateTransaction()
 
 				try {
-					await pWaitFor(() => provider.getCurrentTask()?.isInitialized === true, { timeout: 3_000 })
-				} catch (error) {
-					vscode.window.showErrorMessage(t("common:errors.checkpoint_timeout"))
-				}
+					await provider.cancelTask()
 
-				try {
-					await provider.getCurrentTask()?.checkpointRestore(result.data)
-				} catch (error) {
-					vscode.window.showErrorMessage(t("common:errors.checkpoint_failed"))
+					try {
+						await pWaitFor(() => provider.getCurrentTask()?.isInitialized === true, { timeout: 3_000 })
+					} catch (error) {
+						vscode.window.showErrorMessage(t("common:errors.checkpoint_timeout"))
+					}
+
+					try {
+						await provider.getCurrentTask()?.checkpointRestore(result.data)
+					} catch (error) {
+						vscode.window.showErrorMessage(t("common:errors.checkpoint_failed"))
+					}
+				} finally {
+					// End transaction and post consolidated state
+					await provider.endStateTransaction()
 				}
 			}
 

+ 15 - 3
webview-ui/src/components/chat/ChatView.tsx

@@ -58,6 +58,7 @@ import { QueuedMessages } from "./QueuedMessages"
 import DismissibleUpsell from "../common/DismissibleUpsell"
 import { useCloudUpsell } from "@src/hooks/useCloudUpsell"
 import { Cloud } from "lucide-react"
+import { safeJsonParse } from "../../../../src/shared/safeJsonParse"
 
 export interface ChatViewProps {
 	isHidden: boolean
@@ -547,10 +548,21 @@ const ChatViewComponent: React.ForwardRefRenderFunction<ChatViewRef, ChatViewPro
 				lastApiReqStarted.text !== undefined &&
 				lastApiReqStarted.say === "api_req_started"
 			) {
-				const cost = JSON.parse(lastApiReqStarted.text).cost
+				const info = safeJsonParse(lastApiReqStarted.text)
 
-				if (cost === undefined) {
-					return true // API request has not finished yet.
+				// If cancelReason is defined, the stream has been cancelled (terminal state)
+				if (typeof info === "object" && info !== null) {
+					if ("cancelReason" in info && info.cancelReason !== undefined) {
+						return false
+					}
+
+					// Otherwise, check if cost is defined to determine if streaming is complete
+					if ("cost" in info && info.cost !== undefined) {
+						return false // API request has finished.
+					}
+
+					// If we have api_req_started without cost or cancelReason, streaming is in progress
+					return true
 				}
 			}
 		}