Просмотр исходного кода

Merge pull request #1429 from RooVetGit/cte/resume-task-checkpoints-config

Choose the correct checkpoint storage strategy when resuming tasks
Chris Estreich 10 месяцев назад
Родитель
Сommit
0a19fb547e

+ 41 - 19
src/core/webview/ClineProvider.ts

@@ -7,14 +7,15 @@ import pWaitFor from "p-wait-for"
 import * as path from "path"
 import * as path from "path"
 import * as vscode from "vscode"
 import * as vscode from "vscode"
 import simpleGit from "simple-git"
 import simpleGit from "simple-git"
-import { setPanel } from "../../activate/registerCommands"
 
 
+import { setPanel } from "../../activate/registerCommands"
 import { ApiConfiguration, ApiProvider, ModelInfo, API_CONFIG_KEYS } from "../../shared/api"
 import { ApiConfiguration, ApiProvider, ModelInfo, API_CONFIG_KEYS } from "../../shared/api"
 import { findLast } from "../../shared/array"
 import { findLast } from "../../shared/array"
 import { supportPrompt } from "../../shared/support-prompt"
 import { supportPrompt } from "../../shared/support-prompt"
 import { GlobalFileNames } from "../../shared/globalFileNames"
 import { GlobalFileNames } from "../../shared/globalFileNames"
 import { SecretKey, GlobalStateKey, SECRET_KEYS, GLOBAL_STATE_KEYS } from "../../shared/globalState"
 import { SecretKey, GlobalStateKey, SECRET_KEYS, GLOBAL_STATE_KEYS } from "../../shared/globalState"
 import { HistoryItem } from "../../shared/HistoryItem"
 import { HistoryItem } from "../../shared/HistoryItem"
+import { CheckpointStorage } from "../../shared/checkpoints"
 import { ApiConfigMeta, ExtensionMessage } from "../../shared/ExtensionMessage"
 import { ApiConfigMeta, ExtensionMessage } from "../../shared/ExtensionMessage"
 import { checkoutDiffPayloadSchema, checkoutRestorePayloadSchema, WebviewMessage } from "../../shared/WebviewMessage"
 import { checkoutDiffPayloadSchema, checkoutRestorePayloadSchema, WebviewMessage } from "../../shared/WebviewMessage"
 import { Mode, PromptComponent, defaultModeSlug, ModeConfig } from "../../shared/modes"
 import { Mode, PromptComponent, defaultModeSlug, ModeConfig } from "../../shared/modes"
@@ -28,6 +29,7 @@ import { getTheme } from "../../integrations/theme/getTheme"
 import WorkspaceTracker from "../../integrations/workspace/WorkspaceTracker"
 import WorkspaceTracker from "../../integrations/workspace/WorkspaceTracker"
 import { McpHub } from "../../services/mcp/McpHub"
 import { McpHub } from "../../services/mcp/McpHub"
 import { McpServerManager } from "../../services/mcp/McpServerManager"
 import { McpServerManager } from "../../services/mcp/McpServerManager"
+import { ShadowCheckpointService } from "../../services/checkpoints/ShadowCheckpointService"
 import { fileExistsAtPath } from "../../utils/fs"
 import { fileExistsAtPath } from "../../utils/fs"
 import { playSound, setSoundEnabled, setSoundVolume } from "../../utils/sound"
 import { playSound, setSoundEnabled, setSoundVolume } from "../../utils/sound"
 import { singleCompletionHandler } from "../../utils/single-completion-handler"
 import { singleCompletionHandler } from "../../utils/single-completion-handler"
@@ -47,7 +49,7 @@ import { getOllamaModels } from "../../api/providers/ollama"
 import { getVsCodeLmModels } from "../../api/providers/vscode-lm"
 import { getVsCodeLmModels } from "../../api/providers/vscode-lm"
 import { getLmStudioModels } from "../../api/providers/lmstudio"
 import { getLmStudioModels } from "../../api/providers/lmstudio"
 import { ACTION_NAMES } from "../CodeActionProvider"
 import { ACTION_NAMES } from "../CodeActionProvider"
-import { Cline } from "../Cline"
+import { Cline, ClineOptions } from "../Cline"
 import { openMention } from "../mentions"
 import { openMention } from "../mentions"
 import { getNonce } from "./getNonce"
 import { getNonce } from "./getNonce"
 import { getUri } from "./getUri"
 import { getUri } from "./getUri"
@@ -525,22 +527,43 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 		const modePrompt = customModePrompts?.[mode] as PromptComponent
 		const modePrompt = customModePrompts?.[mode] as PromptComponent
 		const effectiveInstructions = [globalInstructions, modePrompt?.customInstructions].filter(Boolean).join("\n\n")
 		const effectiveInstructions = [globalInstructions, modePrompt?.customInstructions].filter(Boolean).join("\n\n")
 
 
-		// TODO: The `checkpointStorage` value should be derived from the
-		// task data on disk; the current setting could be different than
-		// the setting at the time the task was created.
+		const taskId = historyItem.id
+		const globalStorageDir = this.contextProxy.globalStorageUri.fsPath
+		const workspaceDir = vscode.workspace.workspaceFolders?.[0]?.uri.fsPath ?? ""
+
+		const checkpoints: Pick<ClineOptions, "enableCheckpoints" | "checkpointStorage"> = {
+			enableCheckpoints,
+			checkpointStorage,
+		}
+
+		if (enableCheckpoints) {
+			try {
+				checkpoints.checkpointStorage = await ShadowCheckpointService.getTaskStorage({
+					taskId,
+					globalStorageDir,
+					workspaceDir,
+				})
+
+				this.log(
+					`[ClineProvider#initClineWithHistoryItem] Using ${checkpoints.checkpointStorage} storage for ${taskId}`,
+				)
+			} catch (error) {
+				checkpoints.enableCheckpoints = false
+				this.log(`[ClineProvider#initClineWithHistoryItem] Error getting task storage: ${error.message}`)
+			}
+		}
 
 
 		const newCline = new Cline({
 		const newCline = new Cline({
 			provider: this,
 			provider: this,
 			apiConfiguration,
 			apiConfiguration,
 			customInstructions: effectiveInstructions,
 			customInstructions: effectiveInstructions,
 			enableDiff,
 			enableDiff,
-			enableCheckpoints,
-			checkpointStorage,
+			...checkpoints,
 			fuzzyMatchThreshold,
 			fuzzyMatchThreshold,
 			historyItem,
 			historyItem,
 			experiments,
 			experiments,
 		})
 		})
-		// get this cline task number id from the history item and set it to newCline
+
 		newCline.setTaskNumber(historyItem.number)
 		newCline.setTaskNumber(historyItem.number)
 		await this.addClineToStack(newCline)
 		await this.addClineToStack(newCline)
 	}
 	}
@@ -2069,21 +2092,20 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 		// delete task from the task history state
 		// delete task from the task history state
 		await this.deleteTaskFromState(id)
 		await this.deleteTaskFromState(id)
 
 
-		// check if checkpoints are enabled
-		const { enableCheckpoints } = await this.getState()
 		// get the base directory of the project
 		// get the base directory of the project
 		const baseDir = vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0)
 		const baseDir = vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0)
 
 
-		// Delete checkpoints branch.
-		// TODO: Also delete the workspace branch if it exists.
-		if (enableCheckpoints && baseDir) {
-			const branchSummary = await simpleGit(baseDir)
-				.branch(["-D", `roo-code-checkpoints-${id}`])
-				.catch(() => undefined)
+		// Delete associated shadow repository or branch.
+		// TODO: Store `workspaceDir` in the `HistoryItem` object.
+		const globalStorageDir = this.contextProxy.globalStorageUri.fsPath
+		const workspaceDir = baseDir ?? ""
 
 
-			if (branchSummary) {
-				console.log(`[deleteTaskWithId${id}] deleted checkpoints branch`)
-			}
+		try {
+			await ShadowCheckpointService.deleteTask({ taskId: id, globalStorageDir, workspaceDir })
+		} catch (error) {
+			console.error(
+				`[deleteTaskWithId${id}] failed to delete associated shadow repository or branch: ${error instanceof Error ? error.message : String(error)}`,
+			)
 		}
 		}
 
 
 		// delete the entire task directory including checkpoints and all content
 		// delete the entire task directory including checkpoints and all content

+ 1 - 2
src/services/checkpoints/RepoPerWorkspaceCheckpointService.ts

@@ -1,5 +1,4 @@
 import * as path from "path"
 import * as path from "path"
-import crypto from "crypto"
 
 
 import { CheckpointServiceOptions } from "./types"
 import { CheckpointServiceOptions } from "./types"
 import { ShadowCheckpointService } from "./ShadowCheckpointService"
 import { ShadowCheckpointService } from "./ShadowCheckpointService"
@@ -64,7 +63,7 @@ export class RepoPerWorkspaceCheckpointService extends ShadowCheckpointService {
 	}
 	}
 
 
 	public static create({ taskId, workspaceDir, shadowDir, log = console.log }: CheckpointServiceOptions) {
 	public static create({ taskId, workspaceDir, shadowDir, log = console.log }: CheckpointServiceOptions) {
-		const workspaceHash = crypto.createHash("sha256").update(workspaceDir).digest("hex").toString().slice(0, 8)
+		const workspaceHash = this.hashWorkspaceDir(workspaceDir)
 
 
 		return new RepoPerWorkspaceCheckpointService(
 		return new RepoPerWorkspaceCheckpointService(
 			taskId,
 			taskId,

+ 134 - 0
src/services/checkpoints/ShadowCheckpointService.ts

@@ -1,12 +1,15 @@
 import fs from "fs/promises"
 import fs from "fs/promises"
 import os from "os"
 import os from "os"
 import * as path from "path"
 import * as path from "path"
+import crypto from "crypto"
 import EventEmitter from "events"
 import EventEmitter from "events"
 
 
 import simpleGit, { SimpleGit } from "simple-git"
 import simpleGit, { SimpleGit } from "simple-git"
 import { globby } from "globby"
 import { globby } from "globby"
+import pWaitFor from "p-wait-for"
 
 
 import { fileExistsAtPath } from "../../utils/fs"
 import { fileExistsAtPath } from "../../utils/fs"
+import { CheckpointStorage } from "../../shared/checkpoints"
 
 
 import { GIT_DISABLED_SUFFIX } from "./constants"
 import { GIT_DISABLED_SUFFIX } from "./constants"
 import { CheckpointDiff, CheckpointResult, CheckpointEventMap } from "./types"
 import { CheckpointDiff, CheckpointResult, CheckpointEventMap } from "./types"
@@ -318,4 +321,135 @@ export abstract class ShadowCheckpointService extends EventEmitter {
 	override once<K extends keyof CheckpointEventMap>(event: K, listener: (data: CheckpointEventMap[K]) => void) {
 	override once<K extends keyof CheckpointEventMap>(event: K, listener: (data: CheckpointEventMap[K]) => void) {
 		return super.once(event, listener)
 		return super.once(event, listener)
 	}
 	}
+
+	/**
+	 * Storage
+	 */
+
+	public static hashWorkspaceDir(workspaceDir: string) {
+		return crypto.createHash("sha256").update(workspaceDir).digest("hex").toString().slice(0, 8)
+	}
+
+	protected static taskRepoDir({ taskId, globalStorageDir }: { taskId: string; globalStorageDir: string }) {
+		return path.join(globalStorageDir, "tasks", taskId, "checkpoints")
+	}
+
+	protected static workspaceRepoDir({
+		globalStorageDir,
+		workspaceDir,
+	}: {
+		globalStorageDir: string
+		workspaceDir: string
+	}) {
+		return path.join(globalStorageDir, "checkpoints", this.hashWorkspaceDir(workspaceDir))
+	}
+
+	public static async getTaskStorage({
+		taskId,
+		globalStorageDir,
+		workspaceDir,
+	}: {
+		taskId: string
+		globalStorageDir: string
+		workspaceDir: string
+	}): Promise<CheckpointStorage | undefined> {
+		// Is there a checkpoints repo in the task directory?
+		const taskRepoDir = this.taskRepoDir({ taskId, globalStorageDir })
+
+		if (await fileExistsAtPath(taskRepoDir)) {
+			return "task"
+		}
+
+		// Does the workspace checkpoints repo have a branch for this task?
+		const workspaceRepoDir = this.workspaceRepoDir({ globalStorageDir, workspaceDir })
+
+		if (!(await fileExistsAtPath(workspaceRepoDir))) {
+			return undefined
+		}
+
+		const git = simpleGit(workspaceRepoDir)
+		const branches = await git.branchLocal()
+
+		if (branches.all.includes(`roo-${taskId}`)) {
+			return "workspace"
+		}
+
+		return undefined
+	}
+
+	public static async deleteTask({
+		taskId,
+		globalStorageDir,
+		workspaceDir,
+	}: {
+		taskId: string
+		globalStorageDir: string
+		workspaceDir: string
+	}) {
+		const storage = await this.getTaskStorage({ taskId, globalStorageDir, workspaceDir })
+
+		if (storage === "task") {
+			const taskRepoDir = this.taskRepoDir({ taskId, globalStorageDir })
+			await fs.rm(taskRepoDir, { recursive: true, force: true })
+			console.log(`[${this.name}#deleteTask.${taskId}] removed ${taskRepoDir}`)
+		} else if (storage === "workspace") {
+			const workspaceRepoDir = this.workspaceRepoDir({ globalStorageDir, workspaceDir })
+			const branchName = `roo-${taskId}`
+			const git = simpleGit(workspaceRepoDir)
+			const success = await this.deleteBranch(git, branchName)
+
+			if (success) {
+				console.log(`[${this.name}#deleteTask.${taskId}] deleted branch ${branchName}`)
+			} else {
+				console.error(`[${this.name}#deleteTask.${taskId}] failed to delete branch ${branchName}`)
+			}
+		}
+	}
+
+	public static async deleteBranch(git: SimpleGit, branchName: string) {
+		const branches = await git.branchLocal()
+
+		if (!branches.all.includes(branchName)) {
+			console.error(`[${this.constructor.name}#deleteBranch] branch ${branchName} does not exist`)
+			return false
+		}
+
+		const currentBranch = await git.revparse(["--abbrev-ref", "HEAD"])
+
+		if (currentBranch === branchName) {
+			const worktree = await git.getConfig("core.worktree")
+
+			try {
+				await git.raw(["config", "--unset", "core.worktree"])
+				await git.reset(["--hard"])
+				await git.clean("f", ["-d"])
+				const defaultBranch = branches.all.includes("main") ? "main" : "master"
+				await git.checkout([defaultBranch, "--force"])
+
+				await pWaitFor(
+					async () => {
+						const newBranch = await git.revparse(["--abbrev-ref", "HEAD"])
+						return newBranch === defaultBranch
+					},
+					{ interval: 500, timeout: 2_000 },
+				)
+
+				await git.branch(["-D", branchName])
+				return true
+			} catch (error) {
+				console.error(
+					`[${this.constructor.name}#deleteBranch] failed to delete branch ${branchName}: ${error instanceof Error ? error.message : String(error)}`,
+				)
+
+				return false
+			} finally {
+				if (worktree.value) {
+					await git.addConfig("core.worktree", worktree.value)
+				}
+			}
+		} else {
+			await git.branch(["-D", branchName])
+			return true
+		}
+	}
 }
 }

+ 81 - 0
src/services/checkpoints/__tests__/ShadowCheckpointService.test.ts

@@ -8,6 +8,8 @@ import { EventEmitter } from "events"
 import { simpleGit, SimpleGit } from "simple-git"
 import { simpleGit, SimpleGit } from "simple-git"
 
 
 import { fileExistsAtPath } from "../../../utils/fs"
 import { fileExistsAtPath } from "../../../utils/fs"
+
+import { ShadowCheckpointService } from "../ShadowCheckpointService"
 import { RepoPerTaskCheckpointService } from "../RepoPerTaskCheckpointService"
 import { RepoPerTaskCheckpointService } from "../RepoPerTaskCheckpointService"
 import { RepoPerWorkspaceCheckpointService } from "../RepoPerWorkspaceCheckpointService"
 import { RepoPerWorkspaceCheckpointService } from "../RepoPerWorkspaceCheckpointService"
 
 
@@ -648,3 +650,82 @@ describe.each([
 		})
 		})
 	})
 	})
 })
 })
+
+describe("ShadowCheckpointService", () => {
+	const taskId = "test-task-storage"
+	const tmpDir = path.join(os.tmpdir(), "CheckpointService")
+	const globalStorageDir = path.join(tmpDir, "global-storage-dir")
+	const workspaceDir = path.join(tmpDir, "workspace-dir")
+	const workspaceHash = ShadowCheckpointService.hashWorkspaceDir(workspaceDir)
+
+	beforeEach(async () => {
+		await fs.mkdir(globalStorageDir, { recursive: true })
+		await fs.mkdir(workspaceDir, { recursive: true })
+	})
+
+	afterEach(async () => {
+		await fs.rm(globalStorageDir, { recursive: true, force: true })
+		await fs.rm(workspaceDir, { recursive: true, force: true })
+	})
+
+	describe("getTaskStorage", () => {
+		it("returns 'task' when task repo exists", async () => {
+			const service = RepoPerTaskCheckpointService.create({
+				taskId,
+				shadowDir: globalStorageDir,
+				workspaceDir,
+				log: () => {},
+			})
+
+			await service.initShadowGit()
+
+			const storage = await ShadowCheckpointService.getTaskStorage({ taskId, globalStorageDir, workspaceDir })
+			expect(storage).toBe("task")
+		})
+
+		it("returns 'workspace' when workspace repo exists with task branch", async () => {
+			const service = RepoPerWorkspaceCheckpointService.create({
+				taskId,
+				shadowDir: globalStorageDir,
+				workspaceDir,
+				log: () => {},
+			})
+
+			await service.initShadowGit()
+
+			const storage = await ShadowCheckpointService.getTaskStorage({ taskId, globalStorageDir, workspaceDir })
+			expect(storage).toBe("workspace")
+		})
+
+		it("returns undefined when no repos exist", async () => {
+			const storage = await ShadowCheckpointService.getTaskStorage({ taskId, globalStorageDir, workspaceDir })
+			expect(storage).toBeUndefined()
+		})
+
+		it("returns undefined when workspace repo exists but has no task branch", async () => {
+			// Setup: Create workspace repo without the task branch
+			const workspaceRepoDir = path.join(globalStorageDir, "checkpoints", workspaceHash)
+			await fs.mkdir(workspaceRepoDir, { recursive: true })
+
+			// Create git repo without adding the specific branch
+			const git = simpleGit(workspaceRepoDir)
+			await git.init()
+			await git.addConfig("user.name", "Roo Code")
+			await git.addConfig("user.email", "[email protected]")
+
+			// We need to create a commit, but we won't create the specific branch
+			const testFile = path.join(workspaceRepoDir, "test.txt")
+			await fs.writeFile(testFile, "Test content")
+			await git.add(".")
+			await git.commit("Initial commit")
+
+			const storage = await ShadowCheckpointService.getTaskStorage({
+				taskId,
+				globalStorageDir,
+				workspaceDir,
+			})
+
+			expect(storage).toBeUndefined()
+		})
+	})
+})