Przeglądaj źródła

fix: Make cancel button immediately responsive during streaming (#9448)

Daniel 1 miesiąc temu
rodzic
commit
1dd223d240

+ 73 - 3
src/core/task/Task.ts

@@ -215,6 +215,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 	providerRef: WeakRef<ClineProvider>
 	private readonly globalStoragePath: string
 	abort: boolean = false
+	currentRequestAbortController?: AbortController
 
 	// TaskStatus
 	idleAsk?: ClineMessage
@@ -1734,6 +1735,18 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 		await this.initiateTaskLoop(newUserContent)
 	}
 
+	/**
+	 * Cancels the current HTTP request if one is in progress.
+	 * This immediately aborts the underlying stream rather than waiting for the next chunk.
+	 */
+	public cancelCurrentRequest(): void {
+		if (this.currentRequestAbortController) {
+			console.log(`[Task#${this.taskId}.${this.instanceId}] Aborting current HTTP request`)
+			this.currentRequestAbortController.abort()
+			this.currentRequestAbortController = undefined
+		}
+	}
+
 	public async abortTask(isAbandoned = false) {
 		// Aborting task
 
@@ -1763,6 +1776,13 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 	public dispose(): void {
 		console.log(`[Task#dispose] disposing task ${this.taskId}.${this.instanceId}`)
 
+		// Cancel any in-progress HTTP request
+		try {
+			this.cancelCurrentRequest()
+		} catch (error) {
+			console.error("Error cancelling current request:", error)
+		}
+
 		// Remove provider profile change listener
 		try {
 			if (this.providerProfileChangeListener) {
@@ -2234,10 +2254,34 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 
 				try {
 					const iterator = stream[Symbol.asyncIterator]()
-					let item = await iterator.next()
+
+					// Helper to race iterator.next() with abort signal
+					const nextChunkWithAbort = async () => {
+						const nextPromise = iterator.next()
+
+						// If we have an abort controller, race it with the next chunk
+						if (this.currentRequestAbortController) {
+							const abortPromise = new Promise<never>((_, reject) => {
+								const signal = this.currentRequestAbortController!.signal
+								if (signal.aborted) {
+									reject(new Error("Request cancelled by user"))
+								} else {
+									signal.addEventListener("abort", () => {
+										reject(new Error("Request cancelled by user"))
+									})
+								}
+							})
+							return await Promise.race([nextPromise, abortPromise])
+						}
+
+						// No abort controller, just return the next chunk normally
+						return await nextPromise
+					}
+
+					let item = await nextChunkWithAbort()
 					while (!item.done) {
 						const chunk = item.value
-						item = await iterator.next()
+						item = await nextChunkWithAbort()
 						if (!chunk) {
 							// Sometimes chunk is undefined, no idea that can cause
 							// it, but this workaround seems to fix it.
@@ -2608,6 +2652,8 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 					}
 				} finally {
 					this.isStreaming = false
+					// Clean up the abort controller when streaming completes
+					this.currentRequestAbortController = undefined
 				}
 
 				// Need to call here in case the stream was aborted.
@@ -3221,6 +3267,10 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 			...(shouldIncludeTools ? { tools: allTools, tool_choice: "auto", toolProtocol } : {}),
 		}
 
+		// Create an AbortController to allow cancelling the request mid-stream
+		this.currentRequestAbortController = new AbortController()
+		const abortSignal = this.currentRequestAbortController.signal
+
 		// The provider accepts reasoning items alongside standard messages; cast to the expected parameter type.
 		const stream = this.api.createMessage(
 			systemPrompt,
@@ -3229,14 +3279,34 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 		)
 		const iterator = stream[Symbol.asyncIterator]()
 
+		// Set up abort handling - when the signal is aborted, clean up the controller reference
+		abortSignal.addEventListener("abort", () => {
+			console.log(`[Task#${this.taskId}.${this.instanceId}] AbortSignal triggered for current request`)
+			this.currentRequestAbortController = undefined
+		})
+
 		try {
 			// Awaiting first chunk to see if it will throw an error.
 			this.isWaitingForFirstChunk = true
-			const firstChunk = await iterator.next()
+
+			// Race between the first chunk and the abort signal
+			const firstChunkPromise = iterator.next()
+			const abortPromise = new Promise<never>((_, reject) => {
+				if (abortSignal.aborted) {
+					reject(new Error("Request cancelled by user"))
+				} else {
+					abortSignal.addEventListener("abort", () => {
+						reject(new Error("Request cancelled by user"))
+					})
+				}
+			})
+
+			const firstChunk = await Promise.race([firstChunkPromise, abortPromise])
 			yield firstChunk.value
 			this.isWaitingForFirstChunk = false
 		} catch (error) {
 			this.isWaitingForFirstChunk = false
+			this.currentRequestAbortController = undefined
 			const isContextWindowExceededError = checkContextWindowExceededError(error)
 
 			// If it's a context window error and we haven't exceeded max retries for this error type

+ 74 - 0
src/core/task/__tests__/Task.spec.ts

@@ -1770,6 +1770,80 @@ describe("Cline", () => {
 				consoleErrorSpy.mockRestore()
 			})
 		})
+
+		describe("cancelCurrentRequest", () => {
+			it("should cancel the current HTTP request via AbortController", () => {
+				const task = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "test task",
+					startTask: false,
+				})
+
+				// Create a real AbortController and spy on its abort method
+				const mockAbortController = new AbortController()
+				const abortSpy = vi.spyOn(mockAbortController, "abort")
+				task.currentRequestAbortController = mockAbortController
+
+				// Spy on console.log
+				const consoleLogSpy = vi.spyOn(console, "log").mockImplementation(() => {})
+
+				// Call cancelCurrentRequest
+				task.cancelCurrentRequest()
+
+				// Verify abort was called on the controller
+				expect(abortSpy).toHaveBeenCalled()
+
+				// Verify the controller was cleared
+				expect(task.currentRequestAbortController).toBeUndefined()
+
+				// Verify logging
+				expect(consoleLogSpy).toHaveBeenCalledWith(expect.stringContaining("Aborting current HTTP request"))
+
+				// Restore console.log
+				consoleLogSpy.mockRestore()
+			})
+
+			it("should handle missing AbortController gracefully", () => {
+				const task = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "test task",
+					startTask: false,
+				})
+
+				// Ensure no controller exists
+				task.currentRequestAbortController = undefined
+
+				// Should not throw when called with no controller
+				expect(() => task.cancelCurrentRequest()).not.toThrow()
+			})
+
+			it("should be called during dispose", () => {
+				const task = new Task({
+					provider: mockProvider,
+					apiConfiguration: mockApiConfig,
+					task: "test task",
+					startTask: false,
+				})
+
+				// Spy on cancelCurrentRequest
+				const cancelSpy = vi.spyOn(task, "cancelCurrentRequest")
+
+				// Mock other dispose operations
+				vi.spyOn(task.messageQueueService, "removeListener").mockImplementation(
+					() => task.messageQueueService as any,
+				)
+				vi.spyOn(task.messageQueueService, "dispose").mockImplementation(() => {})
+				vi.spyOn(task, "removeAllListeners").mockImplementation(() => task as any)
+
+				// Call dispose
+				task.dispose()
+
+				// Verify cancelCurrentRequest was called
+				expect(cancelSpy).toHaveBeenCalled()
+			})
+		})
 	})
 })
 

+ 4 - 0
src/core/webview/ClineProvider.ts

@@ -2734,6 +2734,10 @@ export class ClineProvider
 		// Capture the current instance to detect if rehydrate already occurred elsewhere
 		const originalInstanceId = task.instanceId
 
+		// Immediately cancel the underlying HTTP request if one is in progress
+		// This ensures the stream fails quickly rather than waiting for network timeout
+		task.cancelCurrentRequest()
+
 		// Begin abort (non-blocking)
 		task.abortTask()