Просмотр исходного кода

Merge pull request #1357 from RooVetGit/cte/checkpoint-per-task

Add a repo-per-workspace checkpoint service in addition to repo-per-task
Chris Estreich 10 месяцев назад
Родитель
Сommit
fd60c9466b

+ 112 - 50
src/core/Cline.ts

@@ -1,6 +1,6 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import cloneDeep from "clone-deep"
-import { DiffStrategy, getDiffStrategy, UnifiedDiffStrategy } from "./diff/DiffStrategy"
+import { DiffStrategy, getDiffStrategy } from "./diff/DiffStrategy"
 import { validateToolUse, isToolAllowedForMode, ToolName } from "./mode-validator"
 import delay from "delay"
 import fs from "fs/promises"
@@ -13,7 +13,11 @@ import * as vscode from "vscode"
 import { ApiHandler, buildApiHandler } from "../api"
 import { ApiStream } from "../api/transform/stream"
 import { DIFF_VIEW_URI_SCHEME, DiffViewProvider } from "../integrations/editor/DiffViewProvider"
-import { ShadowCheckpointService } from "../services/checkpoints/ShadowCheckpointService"
+import {
+	CheckpointServiceOptions,
+	RepoPerTaskCheckpointService,
+	RepoPerWorkspaceCheckpointService,
+} from "../services/checkpoints"
 import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
 import {
 	extractTextFromFile,
@@ -77,6 +81,7 @@ export type ClineOptions = {
 	customInstructions?: string
 	enableDiff?: boolean
 	enableCheckpoints?: boolean
+	checkpointStorage?: "task" | "workspace"
 	fuzzyMatchThreshold?: number
 	task?: string
 	images?: string[]
@@ -115,8 +120,9 @@ export class Cline {
 	isInitialized = false
 
 	// checkpoints
-	enableCheckpoints: boolean = false
-	private checkpointService?: ShadowCheckpointService
+	private enableCheckpoints: boolean
+	private checkpointStorage: "task" | "workspace"
+	private checkpointService?: RepoPerTaskCheckpointService | RepoPerWorkspaceCheckpointService
 
 	// streaming
 	isWaitingForFirstChunk = false
@@ -136,7 +142,8 @@ export class Cline {
 		apiConfiguration,
 		customInstructions,
 		enableDiff,
-		enableCheckpoints,
+		enableCheckpoints = false,
+		checkpointStorage = "task",
 		fuzzyMatchThreshold,
 		task,
 		images,
@@ -160,7 +167,8 @@ export class Cline {
 		this.fuzzyMatchThreshold = fuzzyMatchThreshold ?? 1.0
 		this.providerRef = new WeakRef(provider)
 		this.diffViewProvider = new DiffViewProvider(cwd)
-		this.enableCheckpoints = enableCheckpoints ?? false
+		this.enableCheckpoints = enableCheckpoints
+		this.checkpointStorage = checkpointStorage
 
 		// Initialize diffStrategy based on current state
 		this.updateDiffStrategy(Experiments.isEnabled(experiments ?? {}, EXPERIMENT_IDS.DIFF_STRATEGY))
@@ -747,7 +755,8 @@ export class Cline {
 	}
 
 	private async initiateTaskLoop(userContent: UserContent): Promise<void> {
-		this.initializeCheckpoints()
+		// Kicks off the checkpoints initialization process in the background.
+		this.getCheckpointService()
 
 		let nextUserContent = userContent
 		let includeFileDetails = true
@@ -3352,9 +3361,13 @@ export class Cline {
 
 	// Checkpoints
 
-	private async initializeCheckpoints() {
+	private getCheckpointService() {
 		if (!this.enableCheckpoints) {
-			return
+			return undefined
+		}
+
+		if (this.checkpointService) {
+			return this.checkpointService
 		}
 
 		const log = (message: string) => {
@@ -3368,47 +3381,45 @@ export class Cline {
 		}
 
 		try {
-			if (this.checkpointService) {
-				log("[Cline#initializeCheckpoints] checkpointService already initialized")
-				return
-			}
-
 			const workspaceDir = vscode.workspace.workspaceFolders?.map((folder) => folder.uri.fsPath).at(0)
 
 			if (!workspaceDir) {
 				log("[Cline#initializeCheckpoints] workspace folder not found, disabling checkpoints")
 				this.enableCheckpoints = false
-				return
+				return undefined
 			}
 
-			const shadowDir = this.providerRef.deref()?.context.globalStorageUri.fsPath
+			const globalStorageDir = this.providerRef.deref()?.context.globalStorageUri.fsPath
 
-			if (!shadowDir) {
-				log("[Cline#initializeCheckpoints] shadowDir not found, disabling checkpoints")
+			if (!globalStorageDir) {
+				log("[Cline#initializeCheckpoints] globalStorageDir not found, disabling checkpoints")
 				this.enableCheckpoints = false
-				return
+				return undefined
 			}
 
-			const service = await ShadowCheckpointService.create({ taskId: this.taskId, workspaceDir, shadowDir, log })
-
-			if (!service) {
-				log("[Cline#initializeCheckpoints] failed to create checkpoint service, disabling checkpoints")
-				this.enableCheckpoints = false
-				return
+			const options: CheckpointServiceOptions = {
+				taskId: this.taskId,
+				workspaceDir,
+				shadowDir: globalStorageDir,
+				log,
 			}
 
-			service.on("initialize", ({ workspaceDir, created, duration }) => {
+			const service =
+				this.checkpointStorage === "task"
+					? RepoPerTaskCheckpointService.create(options)
+					: RepoPerWorkspaceCheckpointService.create(options)
+
+			service.on("initialize", () => {
 				try {
-					if (created) {
-						log(`[Cline#initializeCheckpoints] created new shadow repo (${workspaceDir}) in ${duration}ms`)
-					} else {
-						log(
-							`[Cline#initializeCheckpoints] found existing shadow repo (${workspaceDir}) in ${duration}ms`,
-						)
-					}
+					const isCheckpointNeeded =
+						typeof this.clineMessages.find(({ say }) => say === "checkpoint_saved") === "undefined"
 
 					this.checkpointService = service
-					this.checkpointSave()
+
+					if (isCheckpointNeeded) {
+						log("[Cline#initializeCheckpoints] no checkpoints found, saving initial checkpoint")
+						this.checkpointSave()
+					}
 				} catch (err) {
 					log("[Cline#initializeCheckpoints] caught error in on('initialize'), disabling checkpoints")
 					this.enableCheckpoints = false
@@ -3417,41 +3428,77 @@ export class Cline {
 
 			service.on("checkpoint", ({ isFirst, fromHash: from, toHash: to }) => {
 				try {
-					log(`[Cline#initializeCheckpoints] ${isFirst ? "initial" : "incremental"} checkpoint saved: ${to}`)
 					this.providerRef.deref()?.postMessageToWebview({ type: "currentCheckpointUpdated", text: to })
 
-					this.say("checkpoint_saved", to, undefined, undefined, { isFirst, from, to }).catch((e) =>
-						console.error("Error saving checkpoint message:", e),
-					)
+					this.say("checkpoint_saved", to, undefined, undefined, { isFirst, from, to }).catch((err) => {
+						log("[Cline#initializeCheckpoints] caught unexpected error in say('checkpoint_saved')")
+						console.error(err)
+					})
 				} catch (err) {
-					log("[Cline#initializeCheckpoints] caught error in on('checkpoint'), disabling checkpoints")
+					log(
+						"[Cline#initializeCheckpoints] caught unexpected error in on('checkpoint'), disabling checkpoints",
+					)
+					console.error(err)
 					this.enableCheckpoints = false
 				}
 			})
 
-			service.initShadowGit()
+			service.initShadowGit().catch((err) => {
+				log("[Cline#initializeCheckpoints] caught unexpected error in initShadowGit, disabling checkpoints")
+				console.error(err)
+				this.enableCheckpoints = false
+			})
+
+			return service
 		} catch (err) {
-			log("[Cline#initializeCheckpoints] caught error in initializeCheckpoints(), disabling checkpoints")
+			log("[Cline#initializeCheckpoints] caught unexpected error, disabling checkpoints")
 			this.enableCheckpoints = false
+			return undefined
+		}
+	}
+
+	private async getInitializedCheckpointService({
+		interval = 250,
+		timeout = 15_000,
+	}: { interval?: number; timeout?: number } = {}) {
+		const service = this.getCheckpointService()
+
+		if (!service || service.isInitialized) {
+			return service
+		}
+
+		try {
+			await pWaitFor(
+				() => {
+					console.log("[Cline#getCheckpointService] waiting for service to initialize")
+					return service.isInitialized
+				},
+				{ interval, timeout },
+			)
+			return service
+		} catch (err) {
+			return undefined
 		}
 	}
 
 	public async checkpointDiff({
 		ts,
+		previousCommitHash,
 		commitHash,
 		mode,
 	}: {
 		ts: number
+		previousCommitHash?: string
 		commitHash: string
 		mode: "full" | "checkpoint"
 	}) {
-		if (!this.checkpointService || !this.enableCheckpoints) {
+		const service = await this.getInitializedCheckpointService()
+
+		if (!service) {
 			return
 		}
 
-		let previousCommitHash = undefined
-
-		if (mode === "checkpoint") {
+		if (!previousCommitHash && mode === "checkpoint") {
 			const previousCheckpoint = this.clineMessages
 				.filter(({ say }) => say === "checkpoint_saved")
 				.sort((a, b) => b.ts - a.ts)
@@ -3461,7 +3508,7 @@ export class Cline {
 		}
 
 		try {
-			const changes = await this.checkpointService.getDiff({ from: previousCommitHash, to: commitHash })
+			const changes = await service.getDiff({ from: previousCommitHash, to: commitHash })
 
 			if (!changes?.length) {
 				vscode.window.showInformationMessage("No changes found.")
@@ -3488,12 +3535,25 @@ export class Cline {
 	}
 
 	public checkpointSave() {
-		if (!this.checkpointService || !this.enableCheckpoints) {
+		const service = this.getCheckpointService()
+
+		if (!service) {
+			return
+		}
+
+		if (!service.isInitialized) {
+			this.providerRef
+				.deref()
+				?.log("[checkpointSave] checkpoints didn't initialize in time, disabling checkpoints for this task")
+			this.enableCheckpoints = false
 			return
 		}
 
 		// Start the checkpoint process in the background.
-		this.checkpointService.saveCheckpoint(`Task: ${this.taskId}, Time: ${Date.now()}`)
+		service.saveCheckpoint(`Task: ${this.taskId}, Time: ${Date.now()}`).catch((err) => {
+			console.error("[Cline#checkpointSave] caught unexpected error, disabling checkpoints", err)
+			this.enableCheckpoints = false
+		})
 	}
 
 	public async checkpointRestore({
@@ -3505,7 +3565,9 @@ export class Cline {
 		commitHash: string
 		mode: "preview" | "restore"
 	}) {
-		if (!this.checkpointService || !this.enableCheckpoints) {
+		const service = await this.getInitializedCheckpointService()
+
+		if (!service) {
 			return
 		}
 
@@ -3516,7 +3578,7 @@ export class Cline {
 		}
 
 		try {
-			await this.checkpointService.restoreCheckpoint(commitHash)
+			await service.restoreCheckpoint(commitHash)
 
 			await this.providerRef.deref()?.postMessageToWebview({ type: "currentCheckpointUpdated", text: commitHash })
 

+ 15 - 0
src/services/checkpoints/RepoPerTaskCheckpointService.ts

@@ -0,0 +1,15 @@
+import * as path from "path"
+
+import { CheckpointServiceOptions } from "./types"
+import { ShadowCheckpointService } from "./ShadowCheckpointService"
+
+export class RepoPerTaskCheckpointService extends ShadowCheckpointService {
+	public static create({ taskId, workspaceDir, shadowDir, log = console.log }: CheckpointServiceOptions) {
+		return new RepoPerTaskCheckpointService(
+			taskId,
+			path.join(shadowDir, "tasks", taskId, "checkpoints"),
+			workspaceDir,
+			log,
+		)
+	}
+}

+ 76 - 0
src/services/checkpoints/RepoPerWorkspaceCheckpointService.ts

@@ -0,0 +1,76 @@
+import * as path from "path"
+import crypto from "crypto"
+
+import { CheckpointServiceOptions } from "./types"
+import { ShadowCheckpointService } from "./ShadowCheckpointService"
+
+export class RepoPerWorkspaceCheckpointService extends ShadowCheckpointService {
+	private async checkoutTaskBranch(source: string) {
+		if (!this.git) {
+			throw new Error("Shadow git repo not initialized")
+		}
+
+		const startTime = Date.now()
+		const branch = `roo-${this.taskId}`
+		const currentBranch = await this.git.revparse(["--abbrev-ref", "HEAD"])
+
+		if (currentBranch === branch) {
+			return
+		}
+
+		this.log(`[${this.constructor.name}#checkoutTaskBranch{${source}}] checking out ${branch}`)
+		const branches = await this.git.branchLocal()
+		let exists = branches.all.includes(branch)
+
+		if (!exists) {
+			await this.git.checkoutLocalBranch(branch)
+		} else {
+			await this.git.checkout(branch)
+		}
+
+		const duration = Date.now() - startTime
+
+		this.log(
+			`[${this.constructor.name}#checkoutTaskBranch{${source}}] ${exists ? "checked out" : "created"} branch "${branch}" in ${duration}ms`,
+		)
+	}
+
+	override async initShadowGit() {
+		return await super.initShadowGit(() => this.checkoutTaskBranch("initShadowGit"))
+	}
+
+	override async saveCheckpoint(message: string) {
+		await this.checkoutTaskBranch("saveCheckpoint")
+		return super.saveCheckpoint(message)
+	}
+
+	override async restoreCheckpoint(commitHash: string) {
+		await this.checkoutTaskBranch("restoreCheckpoint")
+		await super.restoreCheckpoint(commitHash)
+	}
+
+	override async getDiff({ from, to }: { from?: string; to?: string }) {
+		if (!this.git) {
+			throw new Error("Shadow git repo not initialized")
+		}
+
+		await this.checkoutTaskBranch("getDiff")
+
+		if (!from && to) {
+			from = `${to}~`
+		}
+
+		return super.getDiff({ from, to })
+	}
+
+	public static create({ taskId, workspaceDir, shadowDir, log = console.log }: CheckpointServiceOptions) {
+		const workspaceHash = crypto.createHash("sha256").update(workspaceDir).digest("hex").toString().slice(0, 8)
+
+		return new RepoPerWorkspaceCheckpointService(
+			taskId,
+			path.join(shadowDir, "checkpoints", workspaceHash),
+			workspaceDir,
+			log,
+		)
+	}
+}

+ 117 - 92
src/services/checkpoints/ShadowCheckpointService.ts

@@ -1,58 +1,70 @@
 import fs from "fs/promises"
 import os from "os"
 import * as path from "path"
-import { globby } from "globby"
+import EventEmitter from "events"
+
 import simpleGit, { SimpleGit } from "simple-git"
+import { globby } from "globby"
 
 import { GIT_DISABLED_SUFFIX, GIT_EXCLUDES } from "./constants"
-import { CheckpointService, CheckpointServiceOptions, CheckpointEventEmitter } from "./types"
+import { CheckpointDiff, CheckpointResult, CheckpointEventMap } from "./types"
 
-export interface ShadowCheckpointServiceOptions extends CheckpointServiceOptions {
-	shadowDir: string
-}
+export abstract class ShadowCheckpointService extends EventEmitter {
+	public readonly taskId: string
+	public readonly checkpointsDir: string
+	public readonly workspaceDir: string
 
-export class ShadowCheckpointService extends CheckpointEventEmitter implements CheckpointService {
-	public readonly version = 1
+	protected _checkpoints: string[] = []
+	protected _baseHash?: string
 
-	private _checkpoints: string[] = []
-	private _baseHash?: string
-	private _isInitialized = false
+	protected readonly dotGitDir: string
+	protected git?: SimpleGit
+	protected readonly log: (message: string) => void
+	protected shadowGitConfigWorktree?: string
 
 	public get baseHash() {
 		return this._baseHash
 	}
 
-	private set baseHash(value: string | undefined) {
+	protected set baseHash(value: string | undefined) {
 		this._baseHash = value
 	}
 
 	public get isInitialized() {
-		return this._isInitialized
+		return !!this.git
 	}
 
-	private set isInitialized(value: boolean) {
-		this._isInitialized = value
-	}
+	constructor(taskId: string, checkpointsDir: string, workspaceDir: string, log: (message: string) => void) {
+		super()
+
+		const homedir = os.homedir()
+		const desktopPath = path.join(homedir, "Desktop")
+		const documentsPath = path.join(homedir, "Documents")
+		const downloadsPath = path.join(homedir, "Downloads")
+		const protectedPaths = [homedir, desktopPath, documentsPath, downloadsPath]
+
+		if (protectedPaths.includes(workspaceDir)) {
+			throw new Error(`Cannot use checkpoints in ${workspaceDir}`)
+		}
 
-	private readonly shadowGitDir: string
-	private shadowGitConfigWorktree?: string
+		this.taskId = taskId
+		this.checkpointsDir = checkpointsDir
+		this.workspaceDir = workspaceDir
 
-	private constructor(
-		public readonly taskId: string,
-		public readonly git: SimpleGit,
-		public readonly shadowDir: string,
-		public readonly workspaceDir: string,
-		private readonly log: (message: string) => void,
-	) {
-		super()
-		this.shadowGitDir = path.join(this.shadowDir, "tasks", this.taskId, "checkpoints", ".git")
+		this.dotGitDir = path.join(this.checkpointsDir, ".git")
+		this.log = log
 	}
 
-	public async initShadowGit() {
-		if (this.isInitialized) {
-			return
+	public async initShadowGit(onInit?: () => Promise<void>) {
+		if (this.git) {
+			throw new Error("Shadow git repo already initialized")
 		}
 
+		await fs.mkdir(this.checkpointsDir, { recursive: true })
+		const git = simpleGit(this.checkpointsDir)
+		const gitVersion = await git.version()
+		this.log(`[${this.constructor.name}#create] git = ${gitVersion}`)
+
 		const fileExistsAtPath = (path: string) =>
 			fs
 				.access(path)
@@ -62,9 +74,9 @@ export class ShadowCheckpointService extends CheckpointEventEmitter implements C
 		let created = false
 		const startTime = Date.now()
 
-		if (await fileExistsAtPath(this.shadowGitDir)) {
-			this.log(`[CheckpointService#initShadowGit] shadow git repo already exists at ${this.shadowGitDir}`)
-			const worktree = await this.getShadowGitConfigWorktree()
+		if (await fileExistsAtPath(this.dotGitDir)) {
+			this.log(`[${this.constructor.name}#initShadowGit] shadow git repo already exists at ${this.dotGitDir}`)
+			const worktree = await this.getShadowGitConfigWorktree(git)
 
 			if (worktree !== this.workspaceDir) {
 				throw new Error(
@@ -72,15 +84,15 @@ export class ShadowCheckpointService extends CheckpointEventEmitter implements C
 				)
 			}
 
-			this.baseHash = await this.git.revparse(["--abbrev-ref", "HEAD"])
+			this.baseHash = await git.revparse(["HEAD"])
 		} else {
-			this.log(`[CheckpointService#initShadowGit] creating shadow git repo at ${this.workspaceDir}`)
+			this.log(`[${this.constructor.name}#initShadowGit] creating shadow git repo at ${this.checkpointsDir}`)
 
-			await this.git.init()
-			await this.git.addConfig("core.worktree", this.workspaceDir) // Sets the working tree to the current workspace.
-			await this.git.addConfig("commit.gpgSign", "false") // Disable commit signing for shadow repo.
-			await this.git.addConfig("user.name", "Roo Code")
-			await this.git.addConfig("user.email", "[email protected]")
+			await git.init()
+			await git.addConfig("core.worktree", this.workspaceDir) // Sets the working tree to the current workspace.
+			await git.addConfig("commit.gpgSign", "false") // Disable commit signing for shadow repo.
+			await git.addConfig("user.name", "Roo Code")
+			await git.addConfig("user.email", "[email protected]")
 
 			let lfsPatterns: string[] = [] // Get LFS patterns from workspace if they exist.
 
@@ -95,7 +107,7 @@ export class ShadowCheckpointService extends CheckpointEventEmitter implements C
 				}
 			} catch (error) {
 				this.log(
-					`[CheckpointService#initShadowGit] failed to read .gitattributes: ${error instanceof Error ? error.message : String(error)}`,
+					`[${this.constructor.name}#initShadowGit] failed to read .gitattributes: ${error instanceof Error ? error.message : String(error)}`,
 				)
 			}
 
@@ -104,21 +116,23 @@ export class ShadowCheckpointService extends CheckpointEventEmitter implements C
 			// .git/info/exclude is local to the shadow git repo, so it's not
 			// shared with the main repo - and won't conflict with user's
 			// .gitignore.
-			await fs.mkdir(path.join(this.shadowGitDir, "info"), { recursive: true })
-			const excludesPath = path.join(this.shadowGitDir, "info", "exclude")
+			await fs.mkdir(path.join(this.dotGitDir, "info"), { recursive: true })
+			const excludesPath = path.join(this.dotGitDir, "info", "exclude")
 			await fs.writeFile(excludesPath, [...GIT_EXCLUDES, ...lfsPatterns].join("\n"))
-			await this.stageAll()
-			const { commit } = await this.git.commit("initial commit", { "--allow-empty": null })
+			await this.stageAll(git)
+			const { commit } = await git.commit("initial commit", { "--allow-empty": null })
 			this.baseHash = commit
-			this.log(`[CheckpointService#initShadowGit] base commit is ${commit}`)
-
 			created = true
 		}
 
 		const duration = Date.now() - startTime
-		this.log(`[CheckpointService#initShadowGit] initialized shadow git in ${duration}ms`)
+		this.log(
+			`[${this.constructor.name}#initShadowGit] initialized shadow repo with base commit ${this.baseHash} in ${duration}ms`,
+		)
+
+		this.git = git
 
-		this.isInitialized = true
+		await onInit?.()
 
 		this.emit("initialize", {
 			type: "initialize",
@@ -127,16 +141,19 @@ export class ShadowCheckpointService extends CheckpointEventEmitter implements C
 			created,
 			duration,
 		})
+
+		return { created, duration }
 	}
 
-	private async stageAll() {
+	private async stageAll(git: SimpleGit) {
+		// await writeExcludesFile(gitPath, await getLfsPatterns(this.cwd)).
 		await this.renameNestedGitRepos(true)
 
 		try {
-			await this.git.add(".")
+			await git.add(".")
 		} catch (error) {
 			this.log(
-				`[CheckpointService#stageAll] failed to add files to git: ${error instanceof Error ? error.message : String(error)}`,
+				`[${this.constructor.name}#stageAll] failed to add files to git: ${error instanceof Error ? error.message : String(error)}`,
 			)
 		} finally {
 			await this.renameNestedGitRepos(false)
@@ -172,23 +189,23 @@ export class ShadowCheckpointService extends CheckpointEventEmitter implements C
 			try {
 				await fs.rename(fullPath, newPath)
 				this.log(
-					`[CheckpointService#renameNestedGitRepos] ${disable ? "disabled" : "enabled"} nested git repo ${gitPath}`,
+					`[${this.constructor.name}#renameNestedGitRepos] ${disable ? "disabled" : "enabled"} nested git repo ${gitPath}`,
 				)
 			} catch (error) {
 				this.log(
-					`[CheckpointService#renameNestedGitRepos] failed to ${disable ? "disable" : "enable"} nested git repo ${gitPath}: ${error instanceof Error ? error.message : String(error)}`,
+					`[${this.constructor.name}#renameNestedGitRepos] failed to ${disable ? "disable" : "enable"} nested git repo ${gitPath}: ${error instanceof Error ? error.message : String(error)}`,
 				)
 			}
 		}
 	}
 
-	public async getShadowGitConfigWorktree() {
+	private async getShadowGitConfigWorktree(git: SimpleGit) {
 		if (!this.shadowGitConfigWorktree) {
 			try {
-				this.shadowGitConfigWorktree = (await this.git.getConfig("core.worktree")).value || undefined
+				this.shadowGitConfigWorktree = (await git.getConfig("core.worktree")).value || undefined
 			} catch (error) {
 				this.log(
-					`[CheckpointService#getShadowGitConfigWorktree] failed to get core.worktree: ${error instanceof Error ? error.message : String(error)}`,
+					`[${this.constructor.name}#getShadowGitConfigWorktree] failed to get core.worktree: ${error instanceof Error ? error.message : String(error)}`,
 				)
 			}
 		}
@@ -196,27 +213,39 @@ export class ShadowCheckpointService extends CheckpointEventEmitter implements C
 		return this.shadowGitConfigWorktree
 	}
 
-	public async saveCheckpoint(message: string) {
+	public async saveCheckpoint(message: string): Promise<CheckpointResult | undefined> {
 		try {
-			this.log("[CheckpointService#saveCheckpoint] starting checkpoint save")
+			this.log(`[${this.constructor.name}#saveCheckpoint] starting checkpoint save`)
 
-			if (!this.isInitialized) {
+			if (!this.git) {
 				throw new Error("Shadow git repo not initialized")
 			}
 
 			const startTime = Date.now()
-			await this.stageAll()
+			await this.stageAll(this.git)
 			const result = await this.git.commit(message)
 			const isFirst = this._checkpoints.length === 0
 			const fromHash = this._checkpoints[this._checkpoints.length - 1] ?? this.baseHash!
-			const toHash = result.commit ?? fromHash
+			const toHash = result.commit || fromHash
 			this._checkpoints.push(toHash)
 			const duration = Date.now() - startTime
-			this.emit("checkpoint", { type: "checkpoint", isFirst, fromHash, toHash, duration })
-			return result.commit ? result : undefined
+
+			if (isFirst || result.commit) {
+				this.emit("checkpoint", { type: "checkpoint", isFirst, fromHash, toHash, duration })
+			}
+
+			if (result.commit) {
+				this.log(
+					`[${this.constructor.name}#saveCheckpoint] checkpoint saved in ${duration}ms -> ${result.commit}`,
+				)
+				return result
+			} else {
+				this.log(`[${this.constructor.name}#saveCheckpoint] found no changes to commit in ${duration}ms`)
+				return undefined
+			}
 		} catch (e) {
 			const error = e instanceof Error ? e : new Error(String(e))
-			this.log(`[CheckpointService#saveCheckpoint] failed to create checkpoint: ${error.message}`)
+			this.log(`[${this.constructor.name}#saveCheckpoint] failed to create checkpoint: ${error.message}`)
 			this.emit("error", { type: "error", error })
 			throw error
 		}
@@ -224,7 +253,9 @@ export class ShadowCheckpointService extends CheckpointEventEmitter implements C
 
 	public async restoreCheckpoint(commitHash: string) {
 		try {
-			if (!this.isInitialized) {
+			this.log(`[${this.constructor.name}#restoreCheckpoint] starting checkpoint restore`)
+
+			if (!this.git) {
 				throw new Error("Shadow git repo not initialized")
 			}
 
@@ -241,17 +272,17 @@ export class ShadowCheckpointService extends CheckpointEventEmitter implements C
 
 			const duration = Date.now() - start
 			this.emit("restore", { type: "restore", commitHash, duration })
-			this.log(`[CheckpointService#restoreCheckpoint] restored checkpoint ${commitHash} in ${duration}ms`)
+			this.log(`[${this.constructor.name}#restoreCheckpoint] restored checkpoint ${commitHash} in ${duration}ms`)
 		} catch (e) {
 			const error = e instanceof Error ? e : new Error(String(e))
-			this.log(`[CheckpointService#restoreCheckpoint] failed to restore checkpoint: ${error.message}`)
+			this.log(`[${this.constructor.name}#restoreCheckpoint] failed to restore checkpoint: ${error.message}`)
 			this.emit("error", { type: "error", error })
 			throw error
 		}
 	}
 
-	public async getDiff({ from, to }: { from?: string; to?: string }) {
-		if (!this.isInitialized) {
+	public async getDiff({ from, to }: { from?: string; to?: string }): Promise<CheckpointDiff[]> {
+		if (!this.git) {
 			throw new Error("Shadow git repo not initialized")
 		}
 
@@ -262,11 +293,12 @@ export class ShadowCheckpointService extends CheckpointEventEmitter implements C
 		}
 
 		// Stage all changes so that untracked files appear in diff summary.
-		await this.stageAll()
+		await this.stageAll(this.git)
 
+		this.log(`[${this.constructor.name}#getDiff] diffing ${to ? `${from}..${to}` : `${from}..HEAD`}`)
 		const { files } = to ? await this.git.diffSummary([`${from}..${to}`]) : await this.git.diffSummary([from])
 
-		const cwdPath = (await this.getShadowGitConfigWorktree()) || this.workspaceDir || ""
+		const cwdPath = (await this.getShadowGitConfigWorktree(this.git)) || this.workspaceDir || ""
 
 		for (const file of files) {
 			const relPath = file.file
@@ -283,30 +315,23 @@ export class ShadowCheckpointService extends CheckpointEventEmitter implements C
 		return result
 	}
 
-	public static async create({ taskId, shadowDir, workspaceDir, log = console.log }: ShadowCheckpointServiceOptions) {
-		try {
-			await simpleGit().version()
-		} catch (error) {
-			log("[CheckpointService#create] git is not installed")
-			throw new Error("Git must be installed to use checkpoints.")
-		}
+	/**
+	 * EventEmitter
+	 */
 
-		const homedir = os.homedir()
-		const desktopPath = path.join(homedir, "Desktop")
-		const documentsPath = path.join(homedir, "Documents")
-		const downloadsPath = path.join(homedir, "Downloads")
-		const protectedPaths = [homedir, desktopPath, documentsPath, downloadsPath]
+	override emit<K extends keyof CheckpointEventMap>(event: K, data: CheckpointEventMap[K]) {
+		return super.emit(event, data)
+	}
 
-		if (protectedPaths.includes(workspaceDir)) {
-			throw new Error(`Cannot use checkpoints in ${workspaceDir}`)
-		}
+	override on<K extends keyof CheckpointEventMap>(event: K, listener: (data: CheckpointEventMap[K]) => void) {
+		return super.on(event, listener)
+	}
 
-		const checkpointsDir = path.join(shadowDir, "tasks", taskId, "checkpoints")
-		await fs.mkdir(checkpointsDir, { recursive: true })
-		const gitDir = path.join(checkpointsDir, ".git")
-		const git = simpleGit(path.dirname(gitDir))
+	override off<K extends keyof CheckpointEventMap>(event: K, listener: (data: CheckpointEventMap[K]) => void) {
+		return super.off(event, listener)
+	}
 
-		log(`[CheckpointService#create] taskId = ${taskId}, workspaceDir = ${workspaceDir}, shadowDir = ${shadowDir}`)
-		return new ShadowCheckpointService(taskId, git, shadowDir, workspaceDir, log)
+	override once<K extends keyof CheckpointEventMap>(event: K, listener: (data: CheckpointEventMap[K]) => void) {
+		return super.once(event, listener)
 	}
 }

+ 62 - 57
src/services/checkpoints/__tests__/ShadowCheckpointService.test.ts

@@ -7,65 +7,69 @@ import { EventEmitter } from "events"
 
 import { simpleGit, SimpleGit } from "simple-git"
 
-import { ShadowCheckpointService } from "../ShadowCheckpointService"
+import { RepoPerTaskCheckpointService } from "../RepoPerTaskCheckpointService"
+import { RepoPerWorkspaceCheckpointService } from "../RepoPerWorkspaceCheckpointService"
 
 jest.mock("globby", () => ({
 	globby: jest.fn().mockResolvedValue([]),
 }))
 
-const tmpDir = path.join(os.tmpdir(), "test-ShadowCheckpointService")
-
-describe("ShadowCheckpointService", () => {
+const tmpDir = path.join(os.tmpdir(), "CheckpointService")
+
+const initRepo = async ({
+	workspaceDir,
+	userName = "Roo Code",
+	userEmail = "[email protected]",
+	testFileName = "test.txt",
+	textFileContent = "Hello, world!",
+}: {
+	workspaceDir: string
+	userName?: string
+	userEmail?: string
+	testFileName?: string
+	textFileContent?: string
+}) => {
+	// Create a temporary directory for testing.
+	await fs.mkdir(workspaceDir, { recursive: true })
+
+	// Initialize git repo.
+	const git = simpleGit(workspaceDir)
+	await git.init()
+	await git.addConfig("user.name", userName)
+	await git.addConfig("user.email", userEmail)
+
+	// Create test file.
+	const testFile = path.join(workspaceDir, testFileName)
+	await fs.writeFile(testFile, textFileContent)
+
+	// Create initial commit.
+	await git.add(".")
+	await git.commit("Initial commit")!
+
+	return { git, testFile }
+}
+
+describe.each([
+	[RepoPerTaskCheckpointService, "RepoPerTaskCheckpointService"],
+	[RepoPerWorkspaceCheckpointService, "RepoPerWorkspaceCheckpointService"],
+])("CheckpointService", (klass, prefix) => {
 	const taskId = "test-task"
 
 	let workspaceGit: SimpleGit
 	let testFile: string
-	let service: ShadowCheckpointService
-
-	const initRepo = async ({
-		workspaceDir,
-		userName = "Roo Code",
-		userEmail = "[email protected]",
-		testFileName = "test.txt",
-		textFileContent = "Hello, world!",
-	}: {
-		workspaceDir: string
-		userName?: string
-		userEmail?: string
-		testFileName?: string
-		textFileContent?: string
-	}) => {
-		// Create a temporary directory for testing.
-		await fs.mkdir(workspaceDir, { recursive: true })
-
-		// Initialize git repo.
-		const git = simpleGit(workspaceDir)
-		await git.init()
-		await git.addConfig("user.name", userName)
-		await git.addConfig("user.email", userEmail)
-
-		// Create test file.
-		const testFile = path.join(workspaceDir, testFileName)
-		await fs.writeFile(testFile, textFileContent)
-
-		// Create initial commit.
-		await git.add(".")
-		await git.commit("Initial commit")!
-
-		return { git, testFile }
-	}
+	let service: RepoPerTaskCheckpointService | RepoPerWorkspaceCheckpointService
 
 	beforeEach(async () => {
 		jest.mocked(require("globby").globby).mockClear().mockResolvedValue([])
 
-		const shadowDir = path.join(tmpDir, `shadow-${Date.now()}`)
+		const shadowDir = path.join(tmpDir, `${prefix}-${Date.now()}`)
 		const workspaceDir = path.join(tmpDir, `workspace-${Date.now()}`)
 		const repo = await initRepo({ workspaceDir })
 
-		testFile = repo.testFile
 		workspaceGit = repo.git
+		testFile = repo.testFile
 
-		service = await ShadowCheckpointService.create({ taskId, shadowDir, workspaceDir, log: () => {} })
+		service = await klass.create({ taskId, shadowDir, workspaceDir, log: () => {} })
 		await service.initShadowGit()
 	})
 
@@ -77,14 +81,14 @@ describe("ShadowCheckpointService", () => {
 		await fs.rm(tmpDir, { recursive: true, force: true })
 	})
 
-	describe("getDiff", () => {
+	describe(`${klass.name}#getDiff`, () => {
 		it("returns the correct diff between commits", async () => {
 			await fs.writeFile(testFile, "Ahoy, world!")
-			const commit1 = await service.saveCheckpoint("First checkpoint")
+			const commit1 = await service.saveCheckpoint("Ahoy, world!")
 			expect(commit1?.commit).toBeTruthy()
 
 			await fs.writeFile(testFile, "Goodbye, world!")
-			const commit2 = await service.saveCheckpoint("Second checkpoint")
+			const commit2 = await service.saveCheckpoint("Goodbye, world!")
 			expect(commit2?.commit).toBeTruthy()
 
 			const diff1 = await service.getDiff({ to: commit1!.commit })
@@ -94,7 +98,7 @@ describe("ShadowCheckpointService", () => {
 			expect(diff1[0].content.before).toBe("Hello, world!")
 			expect(diff1[0].content.after).toBe("Ahoy, world!")
 
-			const diff2 = await service.getDiff({ to: commit2!.commit })
+			const diff2 = await service.getDiff({ from: service.baseHash, to: commit2!.commit })
 			expect(diff2).toHaveLength(1)
 			expect(diff2[0].paths.relative).toBe("test.txt")
 			expect(diff2[0].paths.absolute).toBe(testFile)
@@ -140,7 +144,7 @@ describe("ShadowCheckpointService", () => {
 		})
 	})
 
-	describe("saveCheckpoint", () => {
+	describe(`${klass.name}#saveCheckpoint`, () => {
 		it("creates a checkpoint if there are pending changes", async () => {
 			await fs.writeFile(testFile, "Ahoy, world!")
 			const commit1 = await service.saveCheckpoint("First checkpoint")
@@ -296,9 +300,9 @@ describe("ShadowCheckpointService", () => {
 		})
 	})
 
-	describe("create", () => {
+	describe(`${klass.name}#create`, () => {
 		it("initializes a git repository if one does not already exist", async () => {
-			const shadowDir = path.join(tmpDir, `shadow2-${Date.now()}`)
+			const shadowDir = path.join(tmpDir, `${prefix}2-${Date.now()}`)
 			const workspaceDir = path.join(tmpDir, `workspace2-${Date.now()}`)
 			await fs.mkdir(workspaceDir)
 
@@ -307,10 +311,11 @@ describe("ShadowCheckpointService", () => {
 			expect(await fs.readFile(newTestFile, "utf-8")).toBe("Hello, world!")
 
 			// Ensure the git repository was initialized.
-			const gitDir = path.join(shadowDir, "tasks", taskId, "checkpoints", ".git")
-			await expect(fs.stat(gitDir)).rejects.toThrow()
-			const newService = await ShadowCheckpointService.create({ taskId, shadowDir, workspaceDir, log: () => {} })
-			await newService.initShadowGit()
+			const newService = await klass.create({ taskId, shadowDir, workspaceDir, log: () => {} })
+			const { created } = await newService.initShadowGit()
+			expect(created).toBeTruthy()
+
+			const gitDir = path.join(newService.checkpointsDir, ".git")
 			expect(await fs.stat(gitDir)).toBeTruthy()
 
 			// Save a new checkpoint: Ahoy, world!
@@ -327,15 +332,15 @@ describe("ShadowCheckpointService", () => {
 			await newService.restoreCheckpoint(commit1!.commit)
 			expect(await fs.readFile(newTestFile, "utf-8")).toBe("Ahoy, world!")
 
-			await fs.rm(newService.shadowDir, { recursive: true, force: true })
+			await fs.rm(newService.checkpointsDir, { recursive: true, force: true })
 			await fs.rm(newService.workspaceDir, { recursive: true, force: true })
 		})
 	})
 
-	describe("events", () => {
+	describe(`${klass.name}#events`, () => {
 		it("emits initialize event when service is created", async () => {
-			const shadowDir = path.join(tmpDir, `shadow-event-test-${Date.now()}`)
-			const workspaceDir = path.join(tmpDir, `workspace-event-test-${Date.now()}`)
+			const shadowDir = path.join(tmpDir, `${prefix}3-${Date.now()}`)
+			const workspaceDir = path.join(tmpDir, `workspace3-${Date.now()}`)
 			await fs.mkdir(workspaceDir, { recursive: true })
 
 			const newTestFile = path.join(workspaceDir, "test.txt")
@@ -345,7 +350,7 @@ describe("ShadowCheckpointService", () => {
 			const emitSpy = jest.spyOn(EventEmitter.prototype, "emit")
 
 			// Create the service - this will trigger the initialize event.
-			const newService = await ShadowCheckpointService.create({ taskId, shadowDir, workspaceDir, log: () => {} })
+			const newService = await klass.create({ taskId, shadowDir, workspaceDir, log: () => {} })
 			await newService.initShadowGit()
 
 			// Find the initialize event in the emit calls.

+ 4 - 0
src/services/checkpoints/index.ts

@@ -0,0 +1,4 @@
+export type { CheckpointServiceOptions } from "./types"
+
+export { RepoPerTaskCheckpointService } from "./RepoPerTaskCheckpointService"
+export { RepoPerWorkspaceCheckpointService } from "./RepoPerWorkspaceCheckpointService"

+ 3 - 33
src/services/checkpoints/types.ts

@@ -1,5 +1,4 @@
-import EventEmitter from "events"
-import { CommitResult } from "simple-git"
+import { CommitResult, SimpleGit } from "simple-git"
 
 export type CheckpointResult = Partial<CommitResult> & Pick<CommitResult, "commit">
 
@@ -14,25 +13,14 @@ export type CheckpointDiff = {
 	}
 }
 
-export interface CheckpointService {
-	saveCheckpoint(message: string): Promise<CheckpointResult | undefined>
-	restoreCheckpoint(commit: string): Promise<void>
-	getDiff(range: { from?: string; to?: string }): Promise<CheckpointDiff[]>
-	workspaceDir: string
-	baseHash?: string
-	version: number
-}
-
 export interface CheckpointServiceOptions {
 	taskId: string
 	workspaceDir: string
+	shadowDir: string // globalStorageUri.fsPath
+
 	log?: (message: string) => void
 }
 
-/**
- * EventEmitter
- */
-
 export interface CheckpointEventMap {
 	initialize: { type: "initialize"; workspaceDir: string; baseHash: string; created: boolean; duration: number }
 	checkpoint: {
@@ -45,21 +33,3 @@ export interface CheckpointEventMap {
 	restore: { type: "restore"; commitHash: string; duration: number }
 	error: { type: "error"; error: Error }
 }
-
-export class CheckpointEventEmitter extends EventEmitter {
-	override emit<K extends keyof CheckpointEventMap>(event: K, data: CheckpointEventMap[K]): boolean {
-		return super.emit(event, data)
-	}
-
-	override on<K extends keyof CheckpointEventMap>(event: K, listener: (data: CheckpointEventMap[K]) => void): this {
-		return super.on(event, listener)
-	}
-
-	override off<K extends keyof CheckpointEventMap>(event: K, listener: (data: CheckpointEventMap[K]) => void): this {
-		return super.off(event, listener)
-	}
-
-	override once<K extends keyof CheckpointEventMap>(event: K, listener: (data: CheckpointEventMap[K]) => void): this {
-		return super.once(event, listener)
-	}
-}

+ 1 - 0
src/shared/WebviewMessage.ts

@@ -122,6 +122,7 @@ export interface WebviewMessage {
 
 export const checkoutDiffPayloadSchema = z.object({
 	ts: z.number(),
+	previousCommitHash: z.string().optional(),
 	commitHash: z.string(),
 	mode: z.enum(["full", "checkpoint"]),
 })

+ 7 - 3
webview-ui/src/components/chat/checkpoints/CheckpointMenu.tsx

@@ -20,13 +20,17 @@ export const CheckpointMenu = ({ ts, commitHash, currentHash, checkpoint }: Chec
 
 	const isCurrent = currentHash === commitHash
 	const isFirst = checkpoint.isFirst
-
 	const isDiffAvailable = !isFirst
 	const isRestoreAvailable = !isFirst || !isCurrent
 
+	const previousCommitHash = checkpoint?.from
+
 	const onCheckpointDiff = useCallback(() => {
-		vscode.postMessage({ type: "checkpointDiff", payload: { ts, commitHash, mode: "checkpoint" } })
-	}, [ts, commitHash])
+		vscode.postMessage({
+			type: "checkpointDiff",
+			payload: { ts, previousCommitHash, commitHash, mode: "checkpoint" },
+		})
+	}, [ts, previousCommitHash, commitHash])
 
 	const onPreview = useCallback(() => {
 		vscode.postMessage({ type: "checkpointRestore", payload: { ts, commitHash, mode: "preview" } })

+ 6 - 1
webview-ui/src/components/history/HistoryView.tsx

@@ -299,7 +299,12 @@ const HistoryView = ({ onDone }: HistoryViewProps) => {
 												{formatLargeNumber(item.tokensOut || 0)}
 											</span>
 										</div>
-										{!item.totalCost && <ExportButton itemId={item.id} />}
+										{!item.totalCost && (
+											<div className="flex flex-row gap-1">
+												<CopyButton itemTask={item.task} />
+												<ExportButton itemId={item.id} />
+											</div>
+										)}
 									</div>
 
 									{!!item.cacheWrites && (