Просмотр исходного кода

IPC fixes for task cancellation and queued messages (#11162)

Chris Estreich 1 неделя назад
Родитель
Сommit
e5fa5e8e46

+ 2 - 2
packages/evals/src/cli/runTaskInCli.ts

@@ -264,7 +264,7 @@ export const runTaskWithCli = async ({ run, task, publish, logger, jobToken }: R
 
 		if (rooTaskId && !isClientDisconnected) {
 			logger.info("cancelling task")
-			client.sendCommand({ commandName: TaskCommandName.CancelTask, data: rooTaskId })
+			client.sendCommand({ commandName: TaskCommandName.CancelTask })
 			await new Promise((resolve) => setTimeout(resolve, 5_000))
 		}
 
@@ -289,7 +289,7 @@ export const runTaskWithCli = async ({ run, task, publish, logger, jobToken }: R
 
 	if (rooTaskId && !isClientDisconnected) {
 		logger.info("closing task")
-		client.sendCommand({ commandName: TaskCommandName.CloseTask, data: rooTaskId })
+		client.sendCommand({ commandName: TaskCommandName.CloseTask })
 		await new Promise((resolve) => setTimeout(resolve, 2_000))
 	}
 

+ 2 - 2
packages/evals/src/cli/runTaskInVscode.ts

@@ -270,7 +270,7 @@ export const runTaskInVscode = async ({ run, task, publish, logger, jobToken }:
 
 		if (rooTaskId && !isClientDisconnected) {
 			logger.info("cancelling task")
-			client.sendCommand({ commandName: TaskCommandName.CancelTask, data: rooTaskId })
+			client.sendCommand({ commandName: TaskCommandName.CancelTask })
 			await new Promise((resolve) => setTimeout(resolve, 5_000)) // Allow some time for the task to cancel.
 		}
 
@@ -289,7 +289,7 @@ export const runTaskInVscode = async ({ run, task, publish, logger, jobToken }:
 
 	if (rooTaskId && !isClientDisconnected) {
 		logger.info("closing task")
-		client.sendCommand({ commandName: TaskCommandName.CloseTask, data: rooTaskId })
+		client.sendCommand({ commandName: TaskCommandName.CloseTask })
 		await new Promise((resolve) => setTimeout(resolve, 2_000)) // Allow some time for the window to close.
 	}
 

+ 2 - 2
packages/types/src/__tests__/ipc.test.ts

@@ -27,7 +27,7 @@ describe("IPC Types", () => {
 				const result = taskCommandSchema.safeParse(resumeTaskCommand)
 				expect(result.success).toBe(true)
 
-				if (result.success) {
+				if (result.success && result.data.commandName === TaskCommandName.ResumeTask) {
 					expect(result.data.commandName).toBe("ResumeTask")
 					expect(result.data.data).toBe("non-existent-task-id")
 				}
@@ -45,7 +45,7 @@ describe("IPC Types", () => {
 			const result = taskCommandSchema.safeParse(resumeTaskCommand)
 			expect(result.success).toBe(true)
 
-			if (result.success) {
+			if (result.success && result.data.commandName === TaskCommandName.ResumeTask) {
 				expect(result.data.commandName).toBe("ResumeTask")
 				expect(result.data.data).toBe("task-123")
 			}

+ 8 - 1
packages/types/src/events.ts

@@ -1,6 +1,6 @@
 import { z } from "zod"
 
-import { clineMessageSchema, tokenUsageSchema } from "./message.js"
+import { clineMessageSchema, queuedMessageSchema, tokenUsageSchema } from "./message.js"
 import { toolNamesSchema, toolUsageSchema } from "./tool.js"
 
 /**
@@ -35,6 +35,7 @@ export enum RooCodeEventName {
 	TaskModeSwitched = "taskModeSwitched",
 	TaskAskResponded = "taskAskResponded",
 	TaskUserMessage = "taskUserMessage",
+	QueuedMessagesUpdated = "queuedMessagesUpdated",
 
 	// Task Analytics
 	TaskTokenUsageUpdated = "taskTokenUsageUpdated",
@@ -100,6 +101,7 @@ export const rooCodeEventsSchema = z.object({
 	[RooCodeEventName.TaskModeSwitched]: z.tuple([z.string(), z.string()]),
 	[RooCodeEventName.TaskAskResponded]: z.tuple([z.string()]),
 	[RooCodeEventName.TaskUserMessage]: z.tuple([z.string()]),
+	[RooCodeEventName.QueuedMessagesUpdated]: z.tuple([z.string(), z.array(queuedMessageSchema)]),
 
 	[RooCodeEventName.TaskToolFailed]: z.tuple([z.string(), toolNamesSchema, z.string()]),
 	[RooCodeEventName.TaskTokenUsageUpdated]: z.tuple([z.string(), tokenUsageSchema, toolUsageSchema]),
@@ -217,6 +219,11 @@ export const taskEventSchema = z.discriminatedUnion("eventName", [
 		payload: rooCodeEventsSchema.shape[RooCodeEventName.TaskAskResponded],
 		taskId: z.number().optional(),
 	}),
+	z.object({
+		eventName: z.literal(RooCodeEventName.QueuedMessagesUpdated),
+		payload: rooCodeEventsSchema.shape[RooCodeEventName.QueuedMessagesUpdated],
+		taskId: z.number().optional(),
+	}),
 
 	// Task Analytics
 	z.object({

+ 0 - 2
packages/types/src/ipc.ts

@@ -64,11 +64,9 @@ export const taskCommandSchema = z.discriminatedUnion("commandName", [
 	}),
 	z.object({
 		commandName: z.literal(TaskCommandName.CancelTask),
-		data: z.string(),
 	}),
 	z.object({
 		commandName: z.literal(TaskCommandName.CloseTask),
-		data: z.string(),
 	}),
 	z.object({
 		commandName: z.literal(TaskCommandName.ResumeTask),

+ 1 - 0
packages/types/src/task.ts

@@ -154,6 +154,7 @@ export type TaskEvents = {
 	[RooCodeEventName.TaskModeSwitched]: [taskId: string, mode: string]
 	[RooCodeEventName.TaskAskResponded]: []
 	[RooCodeEventName.TaskUserMessage]: [taskId: string]
+	[RooCodeEventName.QueuedMessagesUpdated]: [taskId: string, messages: QueuedMessage[]]
 
 	// Task Analytics
 	[RooCodeEventName.TaskToolFailed]: [taskId: string, tool: ToolName, error: string]

+ 1 - 0
src/core/task/Task.ts

@@ -675,6 +675,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 
 		this.messageQueueStateChangedHandler = () => {
 			this.emit(RooCodeEventName.TaskUserMessage, this.taskId)
+			this.emit(RooCodeEventName.QueuedMessagesUpdated, this.taskId, this.messageQueueService.messages)
 			this.providerRef.deref()?.postStateToWebviewWithoutTaskHistory()
 		}
 

+ 18 - 26
src/extension/api.ts

@@ -30,7 +30,6 @@ export class API extends EventEmitter<RooCodeEvents> implements RooCodeAPI {
 	private readonly sidebarProvider: ClineProvider
 	private readonly context: vscode.ExtensionContext
 	private readonly ipc?: IpcServer
-	private readonly taskMap = new Map<string, ClineProvider>()
 	private readonly log: (...args: unknown[]) => void
 	private logfile?: string
 
@@ -65,35 +64,37 @@ export class API extends EventEmitter<RooCodeEvents> implements RooCodeAPI {
 			ipc.listen()
 			this.log(`[API] ipc server started: socketPath=${socketPath}, pid=${process.pid}, ppid=${process.ppid}`)
 
-			ipc.on(IpcMessageType.TaskCommand, async (_clientId, { commandName, data }) => {
-				switch (commandName) {
+			ipc.on(IpcMessageType.TaskCommand, async (_clientId, command) => {
+				switch (command.commandName) {
 					case TaskCommandName.StartNewTask:
-						this.log(`[API] StartNewTask -> ${data.text}, ${JSON.stringify(data.configuration)}`)
-						await this.startNewTask(data)
+						this.log(
+							`[API] StartNewTask -> ${command.data.text}, ${JSON.stringify(command.data.configuration)}`,
+						)
+						await this.startNewTask(command.data)
 						break
 					case TaskCommandName.CancelTask:
-						this.log(`[API] CancelTask -> ${data}`)
-						await this.cancelTask(data)
+						this.log(`[API] CancelTask`)
+						await this.cancelCurrentTask()
 						break
 					case TaskCommandName.CloseTask:
-						this.log(`[API] CloseTask -> ${data}`)
+						this.log(`[API] CloseTask`)
 						await vscode.commands.executeCommand("workbench.action.files.saveFiles")
 						await vscode.commands.executeCommand("workbench.action.closeWindow")
 						break
 					case TaskCommandName.ResumeTask:
-						this.log(`[API] ResumeTask -> ${data}`)
+						this.log(`[API] ResumeTask -> ${command.data}`)
 						try {
-							await this.resumeTask(data)
+							await this.resumeTask(command.data)
 						} catch (error) {
 							const errorMessage = error instanceof Error ? error.message : String(error)
-							this.log(`[API] ResumeTask failed for taskId ${data}: ${errorMessage}`)
+							this.log(`[API] ResumeTask failed for taskId ${command.data}: ${errorMessage}`)
 							// Don't rethrow - we want to prevent IPC server crashes
 							// The error is logged for debugging purposes
 						}
 						break
 					case TaskCommandName.SendMessage:
-						this.log(`[API] SendMessage -> ${data.text}`)
-						await this.sendMessage(data.text, data.images)
+						this.log(`[API] SendMessage -> ${command.data.text}`)
+						await this.sendMessage(command.data.text, command.data.images)
 						break
 				}
 			})
@@ -181,15 +182,6 @@ export class API extends EventEmitter<RooCodeEvents> implements RooCodeAPI {
 		await this.sidebarProvider.cancelTask()
 	}
 
-	public async cancelTask(taskId: string) {
-		const provider = this.taskMap.get(taskId)
-
-		if (provider) {
-			await provider.cancelTask()
-			this.taskMap.delete(taskId)
-		}
-	}
-
 	public async sendMessage(text?: string, images?: string[]) {
 		await this.sidebarProvider.postMessageToWebview({ type: "invoke", invoke: "sendMessage", text, images })
 	}
@@ -212,7 +204,6 @@ export class API extends EventEmitter<RooCodeEvents> implements RooCodeAPI {
 
 			task.on(RooCodeEventName.TaskStarted, async () => {
 				this.emit(RooCodeEventName.TaskStarted, task.taskId)
-				this.taskMap.set(task.taskId, provider)
 				await this.fileLog(`[${new Date().toISOString()}] taskStarted -> ${task.taskId}\n`)
 			})
 
@@ -221,8 +212,6 @@ export class API extends EventEmitter<RooCodeEvents> implements RooCodeAPI {
 					isSubtask: !!task.parentTaskId,
 				})
 
-				this.taskMap.delete(task.taskId)
-
 				await this.fileLog(
 					`[${new Date().toISOString()}] taskCompleted -> ${task.taskId} | ${JSON.stringify(tokenUsage, null, 2)} | ${JSON.stringify(toolUsage, null, 2)}\n`,
 				)
@@ -230,7 +219,6 @@ export class API extends EventEmitter<RooCodeEvents> implements RooCodeAPI {
 
 			task.on(RooCodeEventName.TaskAborted, () => {
 				this.emit(RooCodeEventName.TaskAborted, task.taskId)
-				this.taskMap.delete(task.taskId)
 			})
 
 			task.on(RooCodeEventName.TaskFocused, () => {
@@ -301,6 +289,10 @@ export class API extends EventEmitter<RooCodeEvents> implements RooCodeAPI {
 				this.emit(RooCodeEventName.TaskAskResponded, task.taskId)
 			})
 
+			task.on(RooCodeEventName.QueuedMessagesUpdated, (taskId, messages) => {
+				this.emit(RooCodeEventName.QueuedMessagesUpdated, taskId, messages)
+			})
+
 			// Task Analytics
 
 			task.on(RooCodeEventName.TaskToolFailed, (taskId, tool, error) => {