Bladeren bron

Move checkpoint code into a separate module (#3291)

Chris Estreich 7 maanden geleden
bovenliggende
commit
b57e148dea

+ 39 - 373
src/core/Cline.ts

@@ -8,7 +8,6 @@ import cloneDeep from "clone-deep"
 import delay from "delay"
 import pWaitFor from "p-wait-for"
 import { serializeError } from "serialize-error"
-import * as vscode from "vscode"
 
 // schemas
 import { TokenUsage, ToolUsage, ToolName } from "../schemas"
@@ -43,10 +42,10 @@ import { McpHub } from "../services/mcp/McpHub"
 import { ToolRepetitionDetector } from "./ToolRepetitionDetector"
 import { McpServerManager } from "../services/mcp/McpServerManager"
 import { telemetryService } from "../services/telemetry/TelemetryService"
-import { CheckpointServiceOptions, RepoPerTaskCheckpointService } from "../services/checkpoints"
+import { RepoPerTaskCheckpointService } from "../services/checkpoints"
 
 // integrations
-import { DIFF_VIEW_URI_SCHEME, DiffViewProvider } from "../integrations/editor/DiffViewProvider"
+import { DiffViewProvider } from "../integrations/editor/DiffViewProvider"
 import { findToolName, formatContentBlockToMarkdown } from "../integrations/misc/export-markdown"
 import { RooTerminalProcess } from "../integrations/terminal/types"
 import { TerminalRegistry } from "../integrations/terminal/TerminalRegistry"
@@ -79,7 +78,6 @@ import { formatResponse } from "./prompts/responses"
 import { SYSTEM_PROMPT } from "./prompts/system"
 
 // ... everything else
-import { parseMentions } from "./mentions"
 import { FileContextTracker } from "./context-tracking/FileContextTracker"
 import { RooIgnoreController } from "./ignore/RooIgnoreController"
 import { type AssistantMessageContent, parseAssistantMessage } from "./assistant-message"
@@ -89,8 +87,15 @@ import { validateToolUse } from "./mode-validator"
 import { MultiSearchReplaceDiffStrategy } from "./diff/strategies/multi-search-replace"
 import { readApiMessages, saveApiMessages, readTaskMessages, saveTaskMessages, taskMetadata } from "./task-persistence"
 import { getEnvironmentDetails } from "./environment/getEnvironmentDetails"
-
-type UserContent = Array<Anthropic.Messages.ContentBlockParam>
+import {
+	type CheckpointDiffOptions,
+	type CheckpointRestoreOptions,
+	getCheckpointService,
+	checkpointSave,
+	checkpointRestore,
+	checkpointDiff,
+} from "./checkpoints"
+import { processUserContentMentions } from "./mentions/processUserContentMentions"
 
 export type ClineEvents = {
 	message: [{ action: "created" | "updated"; message: ClineMessage }]
@@ -144,7 +149,7 @@ export class Cline extends EventEmitter<ClineEvents> {
 
 	rooIgnoreController?: RooIgnoreController
 	fileContextTracker: FileContextTracker
-	private urlContentFetcher: UrlContentFetcher
+	urlContentFetcher: UrlContentFetcher
 	browserSession: BrowserSession
 	didEditFile: boolean = false
 	customInstructions?: string
@@ -180,9 +185,9 @@ export class Cline extends EventEmitter<ClineEvents> {
 	isInitialized = false
 
 	// checkpoints
-	private enableCheckpoints: boolean
-	private checkpointService?: RepoPerTaskCheckpointService
-	private checkpointServiceInitializing = false
+	enableCheckpoints: boolean
+	checkpointService?: RepoPerTaskCheckpointService
+	checkpointServiceInitializing = false
 
 	// streaming
 	isWaitingForFirstChunk = false
@@ -733,7 +738,7 @@ export class Cline extends EventEmitter<ClineEvents> {
 
 		// if the last message is a user message, we can need to get the assistant message before it to see if it made tool calls, and if so, fill in the remaining tool responses with 'interrupted'
 
-		let modifiedOldUserContent: UserContent // either the last message if its user message, or the user message before the last (assistant) message
+		let modifiedOldUserContent: Anthropic.Messages.ContentBlockParam[] // either the last message if its user message, or the user message before the last (assistant) message
 		let modifiedApiConversationHistory: Anthropic.Messages.MessageParam[] // need to remove the last user message to replace with new modified user message
 		if (existingApiConversationHistory.length > 0) {
 			const lastMessage = existingApiConversationHistory[existingApiConversationHistory.length - 1]
@@ -763,7 +768,7 @@ export class Cline extends EventEmitter<ClineEvents> {
 				const previousAssistantMessage: Anthropic.Messages.MessageParam | undefined =
 					existingApiConversationHistory[existingApiConversationHistory.length - 2]
 
-				const existingUserContent: UserContent = Array.isArray(lastMessage.content)
+				const existingUserContent: Anthropic.Messages.ContentBlockParam[] = Array.isArray(lastMessage.content)
 					? lastMessage.content
 					: [{ type: "text", text: lastMessage.content }]
 				if (previousAssistantMessage && previousAssistantMessage.role === "assistant") {
@@ -807,7 +812,7 @@ export class Cline extends EventEmitter<ClineEvents> {
 			throw new Error("Unexpected: No existing API conversation history")
 		}
 
-		let newUserContent: UserContent = [...modifiedOldUserContent]
+		let newUserContent: Anthropic.Messages.ContentBlockParam[] = [...modifiedOldUserContent]
 
 		const agoText = ((): string => {
 			const timestamp = lastClineMessage?.ts ?? Date.now()
@@ -915,9 +920,9 @@ export class Cline extends EventEmitter<ClineEvents> {
 
 	// Task Loop
 
-	private async initiateTaskLoop(userContent: UserContent): Promise<void> {
+	private async initiateTaskLoop(userContent: Anthropic.Messages.ContentBlockParam[]): Promise<void> {
 		// Kicks off the checkpoints initialization process in the background.
-		this.getCheckpointService()
+		getCheckpointService(this)
 
 		let nextUserContent = userContent
 		let includeFileDetails = true
@@ -951,7 +956,7 @@ export class Cline extends EventEmitter<ClineEvents> {
 	}
 
 	public async recursivelyMakeClineRequests(
-		userContent: UserContent,
+		userContent: Anthropic.Messages.ContentBlockParam[],
 		includeFileDetails: boolean = false,
 	): Promise<boolean> {
 		if (this.abort) {
@@ -1022,7 +1027,13 @@ export class Cline extends EventEmitter<ClineEvents> {
 			}),
 		)
 
-		const parsedUserContent = await this.parseUserContent(userContent)
+		const parsedUserContent = await processUserContentMentions({
+			userContent,
+			cwd: this.cwd,
+			urlContentFetcher: this.urlContentFetcher,
+			fileContextTracker: this.fileContextTracker,
+		})
+
 		const environmentDetails = await getEnvironmentDetails(this, includeFileDetails)
 
 		// Add environment details as its own text block, separate from tool
@@ -2010,7 +2021,7 @@ export class Cline extends EventEmitter<ClineEvents> {
 		if (recentlyModifiedFiles.length > 0) {
 			// TODO: We can track what file changes were made and only
 			// checkpoint those files, this will be save storage.
-			await this.checkpointSave()
+			await checkpointSave(this)
 		}
 
 		/*
@@ -2044,373 +2055,28 @@ export class Cline extends EventEmitter<ClineEvents> {
 		}
 	}
 
-	// Transform
-
-	public async parseUserContent(userContent: UserContent) {
-		// Process userContent array, which contains various block types:
-		// TextBlockParam, ImageBlockParam, ToolUseBlockParam, and ToolResultBlockParam.
-		// We need to apply parseMentions() to:
-		// 1. All TextBlockParam's text (first user message with task)
-		// 2. ToolResultBlockParam's content/context text arrays if it contains
-		// "<feedback>" (see formatToolDeniedFeedback, attemptCompletion,
-		// executeCommand, and consecutiveMistakeCount >= 3) or "<answer>"
-		// (see askFollowupQuestion), we place all user generated content in
-		// these tags so they can effectively be used as markers for when we
-		// should parse mentions).
-		return Promise.all(
-			userContent.map(async (block) => {
-				const shouldProcessMentions = (text: string) => text.includes("<task>") || text.includes("<feedback>")
-
-				if (block.type === "text") {
-					if (shouldProcessMentions(block.text)) {
-						return {
-							...block,
-							text: await parseMentions(
-								block.text,
-								this.cwd,
-								this.urlContentFetcher,
-								this.fileContextTracker,
-							),
-						}
-					}
-
-					return block
-				} else if (block.type === "tool_result") {
-					if (typeof block.content === "string") {
-						if (shouldProcessMentions(block.content)) {
-							return {
-								...block,
-								content: await parseMentions(
-									block.content,
-									this.cwd,
-									this.urlContentFetcher,
-									this.fileContextTracker,
-								),
-							}
-						}
-
-						return block
-					} else if (Array.isArray(block.content)) {
-						const parsedContent = await Promise.all(
-							block.content.map(async (contentBlock) => {
-								if (contentBlock.type === "text" && shouldProcessMentions(contentBlock.text)) {
-									return {
-										...contentBlock,
-										text: await parseMentions(
-											contentBlock.text,
-											this.cwd,
-											this.urlContentFetcher,
-											this.fileContextTracker,
-										),
-									}
-								}
-
-								return contentBlock
-							}),
-						)
-
-						return { ...block, content: parsedContent }
-					}
-
-					return block
-				}
-
-				return block
-			}),
-		)
-	}
-
 	// Checkpoints
 
-	private getCheckpointService() {
-		if (!this.enableCheckpoints) {
-			return undefined
-		}
-
-		if (this.checkpointService) {
-			return this.checkpointService
-		}
-
-		if (this.checkpointServiceInitializing) {
-			console.log("[Cline#getCheckpointService] checkpoint service is still initializing")
-			return undefined
-		}
-
-		const log = (message: string) => {
-			console.log(message)
-
-			try {
-				this.providerRef.deref()?.log(message)
-			} catch (err) {
-				// NO-OP
-			}
-		}
-
-		console.log("[Cline#getCheckpointService] initializing checkpoints service")
-
-		try {
-			const workspaceDir = getWorkspacePath()
-
-			if (!workspaceDir) {
-				log("[Cline#getCheckpointService] workspace folder not found, disabling checkpoints")
-				this.enableCheckpoints = false
-				return undefined
-			}
-
-			const globalStorageDir = this.providerRef.deref()?.context.globalStorageUri.fsPath
-
-			if (!globalStorageDir) {
-				log("[Cline#getCheckpointService] globalStorageDir not found, disabling checkpoints")
-				this.enableCheckpoints = false
-				return undefined
-			}
-
-			const options: CheckpointServiceOptions = {
-				taskId: this.taskId,
-				workspaceDir,
-				shadowDir: globalStorageDir,
-				log,
-			}
-
-			const service = RepoPerTaskCheckpointService.create(options)
-
-			this.checkpointServiceInitializing = true
-
-			service.on("initialize", () => {
-				log("[Cline#getCheckpointService] service initialized")
-
-				try {
-					const isCheckpointNeeded =
-						typeof this.clineMessages.find(({ say }) => say === "checkpoint_saved") === "undefined"
-
-					this.checkpointService = service
-					this.checkpointServiceInitializing = false
-
-					if (isCheckpointNeeded) {
-						log("[Cline#getCheckpointService] no checkpoints found, saving initial checkpoint")
-						this.checkpointSave()
-					}
-				} catch (err) {
-					log("[Cline#getCheckpointService] caught error in on('initialize'), disabling checkpoints")
-					this.enableCheckpoints = false
-				}
-			})
-
-			service.on("checkpoint", ({ isFirst, fromHash: from, toHash: to }) => {
-				try {
-					this.providerRef.deref()?.postMessageToWebview({ type: "currentCheckpointUpdated", text: to })
-
-					this.say("checkpoint_saved", to, undefined, undefined, { isFirst, from, to }).catch((err) => {
-						log("[Cline#getCheckpointService] caught unexpected error in say('checkpoint_saved')")
-						console.error(err)
-					})
-				} catch (err) {
-					log(
-						"[Cline#getCheckpointService] caught unexpected error in on('checkpoint'), disabling checkpoints",
-					)
-					console.error(err)
-					this.enableCheckpoints = false
-				}
-			})
-
-			log("[Cline#getCheckpointService] initializing shadow git")
-
-			service.initShadowGit().catch((err) => {
-				log(
-					`[Cline#getCheckpointService] caught unexpected error in initShadowGit, disabling checkpoints (${err.message})`,
-				)
-				console.error(err)
-				this.enableCheckpoints = false
-			})
-
-			return service
-		} catch (err) {
-			log("[Cline#getCheckpointService] caught unexpected error, disabling checkpoints")
-			this.enableCheckpoints = false
-			return undefined
-		}
+	public async checkpointSave() {
+		return checkpointSave(this)
 	}
 
-	private async getInitializedCheckpointService({
-		interval = 250,
-		timeout = 15_000,
-	}: { interval?: number; timeout?: number } = {}) {
-		const service = this.getCheckpointService()
-
-		if (!service || service.isInitialized) {
-			return service
-		}
-
-		try {
-			await pWaitFor(
-				() => {
-					console.log("[Cline#getCheckpointService] waiting for service to initialize")
-					return service.isInitialized
-				},
-				{ interval, timeout },
-			)
-
-			return service
-		} catch (err) {
-			return undefined
-		}
+	public async checkpointRestore(options: CheckpointRestoreOptions) {
+		return checkpointRestore(this, options)
 	}
 
-	public async checkpointDiff({
-		ts,
-		previousCommitHash,
-		commitHash,
-		mode,
-	}: {
-		ts: number
-		previousCommitHash?: string
-		commitHash: string
-		mode: "full" | "checkpoint"
-	}) {
-		const service = await this.getInitializedCheckpointService()
-
-		if (!service) {
-			return
-		}
-
-		telemetryService.captureCheckpointDiffed(this.taskId)
-
-		if (!previousCommitHash && mode === "checkpoint") {
-			const previousCheckpoint = this.clineMessages
-				.filter(({ say }) => say === "checkpoint_saved")
-				.sort((a, b) => b.ts - a.ts)
-				.find((message) => message.ts < ts)
-
-			previousCommitHash = previousCheckpoint?.text
-		}
-
-		try {
-			const changes = await service.getDiff({ from: previousCommitHash, to: commitHash })
-
-			if (!changes?.length) {
-				vscode.window.showInformationMessage("No changes found.")
-				return
-			}
-
-			await vscode.commands.executeCommand(
-				"vscode.changes",
-				mode === "full" ? "Changes since task started" : "Changes since previous checkpoint",
-				changes.map((change) => [
-					vscode.Uri.file(change.paths.absolute),
-					vscode.Uri.parse(`${DIFF_VIEW_URI_SCHEME}:${change.paths.relative}`).with({
-						query: Buffer.from(change.content.before ?? "").toString("base64"),
-					}),
-					vscode.Uri.parse(`${DIFF_VIEW_URI_SCHEME}:${change.paths.relative}`).with({
-						query: Buffer.from(change.content.after ?? "").toString("base64"),
-					}),
-				]),
-			)
-		} catch (err) {
-			this.providerRef.deref()?.log("[checkpointDiff] disabling checkpoints for this task")
-			this.enableCheckpoints = false
-		}
+	public async checkpointDiff(options: CheckpointDiffOptions) {
+		return checkpointDiff(this, options)
 	}
 
-	public async checkpointSave() {
-		const service = this.getCheckpointService()
-
-		if (!service) {
-			return
-		}
-
-		if (!service.isInitialized) {
-			this.providerRef
-				.deref()
-				?.log("[checkpointSave] checkpoints didn't initialize in time, disabling checkpoints for this task")
-
-			this.enableCheckpoints = false
-			return
-		}
-
-		telemetryService.captureCheckpointCreated(this.taskId)
-
-		// Start the checkpoint process in the background.
-		return service.saveCheckpoint(`Task: ${this.taskId}, Time: ${Date.now()}`).catch((err) => {
-			console.error("[Cline#checkpointSave] caught unexpected error, disabling checkpoints", err)
-			this.enableCheckpoints = false
-		})
-	}
-
-	public async checkpointRestore({
-		ts,
-		commitHash,
-		mode,
-	}: {
-		ts: number
-		commitHash: string
-		mode: "preview" | "restore"
-	}) {
-		const service = await this.getInitializedCheckpointService()
-
-		if (!service) {
-			return
-		}
-
-		const index = this.clineMessages.findIndex((m) => m.ts === ts)
-
-		if (index === -1) {
-			return
-		}
-
-		try {
-			await service.restoreCheckpoint(commitHash)
-
-			telemetryService.captureCheckpointRestored(this.taskId)
-
-			await this.providerRef.deref()?.postMessageToWebview({ type: "currentCheckpointUpdated", text: commitHash })
-
-			if (mode === "restore") {
-				await this.overwriteApiConversationHistory(
-					this.apiConversationHistory.filter((m) => !m.ts || m.ts < ts),
-				)
-
-				const deletedMessages = this.clineMessages.slice(index + 1)
-
-				const { totalTokensIn, totalTokensOut, totalCacheWrites, totalCacheReads, totalCost } = getApiMetrics(
-					combineApiRequests(combineCommandSequences(deletedMessages)),
-				)
-
-				await this.overwriteClineMessages(this.clineMessages.slice(0, index + 1))
-
-				// TODO: Verify that this is working as expected.
-				await this.say(
-					"api_req_deleted",
-					JSON.stringify({
-						tokensIn: totalTokensIn,
-						tokensOut: totalTokensOut,
-						cacheWrites: totalCacheWrites,
-						cacheReads: totalCacheReads,
-						cost: totalCost,
-					} satisfies ClineApiReqInfo),
-				)
-			}
+	// Metrics
 
-			// The task is already cancelled by the provider beforehand, but we
-			// need to re-init to get the updated messages.
-			//
-			// This was take from Cline's implementation of the checkpoints
-			// feature. The cline instance will hang if we don't cancel twice,
-			// so this is currently necessary, but it seems like a complicated
-			// and hacky solution to a problem that I don't fully understand.
-			// I'd like to revisit this in the future and try to improve the
-			// task flow and the communication between the webview and the
-			// Cline instance.
-			this.providerRef.deref()?.cancelTask()
-		} catch (err) {
-			this.providerRef.deref()?.log("[checkpointRestore] disabling checkpoints for this task")
-			this.enableCheckpoints = false
-		}
+	public combineMessages(messages: ClineMessage[]) {
+		return combineApiRequests(combineCommandSequences(messages))
 	}
 
-	// Metrics
-
 	public getTokenUsage() {
-		return getApiMetrics(combineApiRequests(combineCommandSequences(this.clineMessages.slice(1))))
+		return getApiMetrics(this.combineMessages(this.clineMessages.slice(1)))
 	}
 
 	public recordToolUsage(toolName: ToolName) {

+ 22 - 15
src/core/__tests__/Cline.test.ts

@@ -12,6 +12,19 @@ import { ClineProvider } from "../webview/ClineProvider"
 import { ApiConfiguration, ModelInfo } from "../../shared/api"
 import { ApiStreamChunk } from "../../api/transform/stream"
 import { ContextProxy } from "../config/ContextProxy"
+import { processUserContentMentions } from "../mentions/processUserContentMentions"
+
+jest.mock("../mentions", () => ({
+	parseMentions: jest.fn().mockImplementation((text) => {
+		return Promise.resolve(`processed: ${text}`)
+	}),
+	openMention: jest.fn(),
+	getLatestTerminalOutput: jest.fn(),
+}))
+
+jest.mock("../../integrations/misc/extract-text", () => ({
+	extractTextFromFile: jest.fn().mockResolvedValue("Mock file content"),
+}))
 
 jest.mock("../environment/getEnvironmentDetails", () => ({
 	getEnvironmentDetails: jest.fn().mockResolvedValue(""),
@@ -791,7 +804,7 @@ describe("Cline", () => {
 				await task.catch(() => {})
 			})
 
-			describe("parseUserContent", () => {
+			describe("processUserContentMentions", () => {
 				it("should process mentions in task and feedback tags", async () => {
 					const [cline, task] = Cline.create({
 						provider: mockProvider,
@@ -799,10 +812,6 @@ describe("Cline", () => {
 						task: "test task",
 					})
 
-					// Mock parseMentions to track calls
-					const mockParseMentions = jest.fn().mockImplementation((text) => `processed: ${text}`)
-					jest.spyOn(require("../../core/mentions"), "parseMentions").mockImplementation(mockParseMentions)
-
 					const userContent = [
 						{
 							type: "text",
@@ -834,30 +843,28 @@ describe("Cline", () => {
 						} as Anthropic.ToolResultBlockParam,
 					]
 
-					// Process the content
-					const processedContent = await cline.parseUserContent(userContent)
+					const processedContent = await processUserContentMentions({
+						userContent,
+						cwd: cline.cwd,
+						urlContentFetcher: cline.urlContentFetcher,
+						fileContextTracker: cline.fileContextTracker,
+					})
 
 					// Regular text should not be processed
 					expect((processedContent[0] as Anthropic.TextBlockParam).text).toBe("Regular text with @/some/path")
 
 					// Text within task tags should be processed
 					expect((processedContent[1] as Anthropic.TextBlockParam).text).toContain("processed:")
-					expect(mockParseMentions).toHaveBeenCalledWith(
+					expect((processedContent[1] as Anthropic.TextBlockParam).text).toContain(
 						"<task>Text with @/some/path in task tags</task>",
-						expect.any(String),
-						expect.any(Object),
-						expect.any(Object),
 					)
 
 					// Feedback tag content should be processed
 					const toolResult1 = processedContent[2] as Anthropic.ToolResultBlockParam
 					const content1 = Array.isArray(toolResult1.content) ? toolResult1.content[0] : toolResult1.content
 					expect((content1 as Anthropic.TextBlockParam).text).toContain("processed:")
-					expect(mockParseMentions).toHaveBeenCalledWith(
+					expect((content1 as Anthropic.TextBlockParam).text).toContain(
 						"<feedback>Check @/some/path</feedback>",
-						expect.any(String),
-						expect.any(Object),
-						expect.any(Object),
 					)
 
 					// Regular tool result should not be processed

+ 295 - 0
src/core/checkpoints/index.ts

@@ -0,0 +1,295 @@
+import pWaitFor from "p-wait-for"
+import * as vscode from "vscode"
+
+import { Cline } from "../Cline"
+
+import { getWorkspacePath } from "../../utils/path"
+
+import { ClineApiReqInfo } from "../../shared/ExtensionMessage"
+import { getApiMetrics } from "../../shared/getApiMetrics"
+
+import { DIFF_VIEW_URI_SCHEME } from "../../integrations/editor/DiffViewProvider"
+
+import { telemetryService } from "../../services/telemetry/TelemetryService"
+import { CheckpointServiceOptions, RepoPerTaskCheckpointService } from "../../services/checkpoints"
+
+export function getCheckpointService(cline: Cline) {
+	if (!cline.enableCheckpoints) {
+		return undefined
+	}
+
+	if (cline.checkpointService) {
+		return cline.checkpointService
+	}
+
+	if (cline.checkpointServiceInitializing) {
+		console.log("[Cline#getCheckpointService] checkpoint service is still initializing")
+		return undefined
+	}
+
+	const provider = cline.providerRef.deref()
+
+	const log = (message: string) => {
+		console.log(message)
+
+		try {
+			provider?.log(message)
+		} catch (err) {
+			// NO-OP
+		}
+	}
+
+	console.log("[Cline#getCheckpointService] initializing checkpoints service")
+
+	try {
+		const workspaceDir = getWorkspacePath()
+
+		if (!workspaceDir) {
+			log("[Cline#getCheckpointService] workspace folder not found, disabling checkpoints")
+			cline.enableCheckpoints = false
+			return undefined
+		}
+
+		const globalStorageDir = provider?.context.globalStorageUri.fsPath
+
+		if (!globalStorageDir) {
+			log("[Cline#getCheckpointService] globalStorageDir not found, disabling checkpoints")
+			cline.enableCheckpoints = false
+			return undefined
+		}
+
+		const options: CheckpointServiceOptions = {
+			taskId: cline.taskId,
+			workspaceDir,
+			shadowDir: globalStorageDir,
+			log,
+		}
+
+		const service = RepoPerTaskCheckpointService.create(options)
+
+		cline.checkpointServiceInitializing = true
+
+		service.on("initialize", () => {
+			log("[Cline#getCheckpointService] service initialized")
+
+			try {
+				const isCheckpointNeeded =
+					typeof cline.clineMessages.find(({ say }) => say === "checkpoint_saved") === "undefined"
+
+				cline.checkpointService = service
+				cline.checkpointServiceInitializing = false
+
+				if (isCheckpointNeeded) {
+					log("[Cline#getCheckpointService] no checkpoints found, saving initial checkpoint")
+					checkpointSave(cline)
+				}
+			} catch (err) {
+				log("[Cline#getCheckpointService] caught error in on('initialize'), disabling checkpoints")
+				cline.enableCheckpoints = false
+			}
+		})
+
+		service.on("checkpoint", ({ isFirst, fromHash: from, toHash: to }) => {
+			try {
+				provider?.postMessageToWebview({ type: "currentCheckpointUpdated", text: to })
+
+				cline.say("checkpoint_saved", to, undefined, undefined, { isFirst, from, to }).catch((err) => {
+					log("[Cline#getCheckpointService] caught unexpected error in say('checkpoint_saved')")
+					console.error(err)
+				})
+			} catch (err) {
+				log("[Cline#getCheckpointService] caught unexpected error in on('checkpoint'), disabling checkpoints")
+				console.error(err)
+				cline.enableCheckpoints = false
+			}
+		})
+
+		log("[Cline#getCheckpointService] initializing shadow git")
+
+		service.initShadowGit().catch((err) => {
+			log(
+				`[Cline#getCheckpointService] caught unexpected error in initShadowGit, disabling checkpoints (${err.message})`,
+			)
+
+			console.error(err)
+			cline.enableCheckpoints = false
+		})
+
+		return service
+	} catch (err) {
+		log("[Cline#getCheckpointService] caught unexpected error, disabling checkpoints")
+		cline.enableCheckpoints = false
+		return undefined
+	}
+}
+
+async function getInitializedCheckpointService(
+	cline: Cline,
+	{ interval = 250, timeout = 15_000 }: { interval?: number; timeout?: number } = {},
+) {
+	const service = getCheckpointService(cline)
+
+	if (!service || service.isInitialized) {
+		return service
+	}
+
+	try {
+		await pWaitFor(
+			() => {
+				console.log("[Cline#getCheckpointService] waiting for service to initialize")
+				return service.isInitialized
+			},
+			{ interval, timeout },
+		)
+
+		return service
+	} catch (err) {
+		return undefined
+	}
+}
+
+export async function checkpointSave(cline: Cline) {
+	const service = getCheckpointService(cline)
+
+	if (!service) {
+		return
+	}
+
+	if (!service.isInitialized) {
+		const provider = cline.providerRef.deref()
+		provider?.log("[checkpointSave] checkpoints didn't initialize in time, disabling checkpoints for this task")
+		cline.enableCheckpoints = false
+		return
+	}
+
+	telemetryService.captureCheckpointCreated(cline.taskId)
+
+	// Start the checkpoint process in the background.
+	return service.saveCheckpoint(`Task: ${cline.taskId}, Time: ${Date.now()}`).catch((err) => {
+		console.error("[Cline#checkpointSave] caught unexpected error, disabling checkpoints", err)
+		cline.enableCheckpoints = false
+	})
+}
+
+export type CheckpointRestoreOptions = {
+	ts: number
+	commitHash: string
+	mode: "preview" | "restore"
+}
+
+export async function checkpointRestore(cline: Cline, { ts, commitHash, mode }: CheckpointRestoreOptions) {
+	const service = await getInitializedCheckpointService(cline)
+
+	if (!service) {
+		return
+	}
+
+	const index = cline.clineMessages.findIndex((m) => m.ts === ts)
+
+	if (index === -1) {
+		return
+	}
+
+	const provider = cline.providerRef.deref()
+
+	try {
+		await service.restoreCheckpoint(commitHash)
+		telemetryService.captureCheckpointRestored(cline.taskId)
+		await provider?.postMessageToWebview({ type: "currentCheckpointUpdated", text: commitHash })
+
+		if (mode === "restore") {
+			await cline.overwriteApiConversationHistory(cline.apiConversationHistory.filter((m) => !m.ts || m.ts < ts))
+
+			const deletedMessages = cline.clineMessages.slice(index + 1)
+
+			const { totalTokensIn, totalTokensOut, totalCacheWrites, totalCacheReads, totalCost } = getApiMetrics(
+				cline.combineMessages(deletedMessages),
+			)
+
+			await cline.overwriteClineMessages(cline.clineMessages.slice(0, index + 1))
+
+			// TODO: Verify that this is working as expected.
+			await cline.say(
+				"api_req_deleted",
+				JSON.stringify({
+					tokensIn: totalTokensIn,
+					tokensOut: totalTokensOut,
+					cacheWrites: totalCacheWrites,
+					cacheReads: totalCacheReads,
+					cost: totalCost,
+				} satisfies ClineApiReqInfo),
+			)
+		}
+
+		// The task is already cancelled by the provider beforehand, but we
+		// need to re-init to get the updated messages.
+		//
+		// This was take from Cline's implementation of the checkpoints
+		// feature. The cline instance will hang if we don't cancel twice,
+		// so this is currently necessary, but it seems like a complicated
+		// and hacky solution to a problem that I don't fully understand.
+		// I'd like to revisit this in the future and try to improve the
+		// task flow and the communication between the webview and the
+		// Cline instance.
+		provider?.cancelTask()
+	} catch (err) {
+		provider?.log("[checkpointRestore] disabling checkpoints for this task")
+		cline.enableCheckpoints = false
+	}
+}
+
+export type CheckpointDiffOptions = {
+	ts: number
+	previousCommitHash?: string
+	commitHash: string
+	mode: "full" | "checkpoint"
+}
+
+export async function checkpointDiff(
+	cline: Cline,
+	{ ts, previousCommitHash, commitHash, mode }: CheckpointDiffOptions,
+) {
+	const service = await getInitializedCheckpointService(cline)
+
+	if (!service) {
+		return
+	}
+
+	telemetryService.captureCheckpointDiffed(cline.taskId)
+
+	if (!previousCommitHash && mode === "checkpoint") {
+		const previousCheckpoint = cline.clineMessages
+			.filter(({ say }) => say === "checkpoint_saved")
+			.sort((a, b) => b.ts - a.ts)
+			.find((message) => message.ts < ts)
+
+		previousCommitHash = previousCheckpoint?.text
+	}
+
+	try {
+		const changes = await service.getDiff({ from: previousCommitHash, to: commitHash })
+
+		if (!changes?.length) {
+			vscode.window.showInformationMessage("No changes found.")
+			return
+		}
+
+		await vscode.commands.executeCommand(
+			"vscode.changes",
+			mode === "full" ? "Changes since task started" : "Changes since previous checkpoint",
+			changes.map((change) => [
+				vscode.Uri.file(change.paths.absolute),
+				vscode.Uri.parse(`${DIFF_VIEW_URI_SCHEME}:${change.paths.relative}`).with({
+					query: Buffer.from(change.content.before ?? "").toString("base64"),
+				}),
+				vscode.Uri.parse(`${DIFF_VIEW_URI_SCHEME}:${change.paths.relative}`).with({
+					query: Buffer.from(change.content.after ?? "").toString("base64"),
+				}),
+			]),
+		)
+	} catch (err) {
+		const provider = cline.providerRef.deref()
+		provider?.log("[checkpointDiff] disabling checkpoints for this task")
+		cline.enableCheckpoints = false
+	}
+}

+ 10 - 4
src/core/mentions/index.ts

@@ -4,14 +4,17 @@ import * as path from "path"
 import * as vscode from "vscode"
 import { isBinaryFile } from "isbinaryfile"
 
-import { openFile } from "../../integrations/misc/open-file"
-import { UrlContentFetcher } from "../../services/browser/UrlContentFetcher"
 import { mentionRegexGlobal, unescapeSpaces } from "../../shared/context-mentions"
 
-import { extractTextFromFile } from "../../integrations/misc/extract-text"
-import { diagnosticsToProblemsString } from "../../integrations/diagnostics"
 import { getCommitInfo, getWorkingState } from "../../utils/git"
 import { getWorkspacePath } from "../../utils/path"
+
+import { openFile } from "../../integrations/misc/open-file"
+import { extractTextFromFile } from "../../integrations/misc/extract-text"
+import { diagnosticsToProblemsString } from "../../integrations/diagnostics"
+
+import { UrlContentFetcher } from "../../services/browser/UrlContentFetcher"
+
 import { FileContextTracker } from "../context-tracking/FileContextTracker"
 
 export async function openMention(mention?: string): Promise<void> {
@@ -273,3 +276,6 @@ export async function getLatestTerminalOutput(): Promise<string> {
 		await vscode.env.clipboard.writeText(originalClipboard)
 	}
 }
+
+// Export processUserContentMentions from its own file
+export { processUserContentMentions } from "./processUserContentMentions"

+ 81 - 0
src/core/mentions/processUserContentMentions.ts

@@ -0,0 +1,81 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { parseMentions } from "./index"
+import { UrlContentFetcher } from "../../services/browser/UrlContentFetcher"
+import { FileContextTracker } from "../context-tracking/FileContextTracker"
+
+/**
+ * Process mentions in user content, specifically within task and feedback tags
+ */
+export async function processUserContentMentions({
+	userContent,
+	cwd,
+	urlContentFetcher,
+	fileContextTracker,
+}: {
+	userContent: Anthropic.Messages.ContentBlockParam[]
+	cwd: string
+	urlContentFetcher: UrlContentFetcher
+	fileContextTracker: FileContextTracker
+}) {
+	// Process userContent array, which contains various block types:
+	// TextBlockParam, ImageBlockParam, ToolUseBlockParam, and ToolResultBlockParam.
+	// We need to apply parseMentions() to:
+	// 1. All TextBlockParam's text (first user message with task)
+	// 2. ToolResultBlockParam's content/context text arrays if it contains
+	// "<feedback>" (see formatToolDeniedFeedback, attemptCompletion,
+	// executeCommand, and consecutiveMistakeCount >= 3) or "<answer>"
+	// (see askFollowupQuestion), we place all user generated content in
+	// these tags so they can effectively be used as markers for when we
+	// should parse mentions).
+	return Promise.all(
+		userContent.map(async (block) => {
+			const shouldProcessMentions = (text: string) => text.includes("<task>") || text.includes("<feedback>")
+
+			if (block.type === "text") {
+				if (shouldProcessMentions(block.text)) {
+					return {
+						...block,
+						text: await parseMentions(block.text, cwd, urlContentFetcher, fileContextTracker),
+					}
+				}
+
+				return block
+			} else if (block.type === "tool_result") {
+				if (typeof block.content === "string") {
+					if (shouldProcessMentions(block.content)) {
+						return {
+							...block,
+							content: await parseMentions(block.content, cwd, urlContentFetcher, fileContextTracker),
+						}
+					}
+
+					return block
+				} else if (Array.isArray(block.content)) {
+					const parsedContent = await Promise.all(
+						block.content.map(async (contentBlock) => {
+							if (contentBlock.type === "text" && shouldProcessMentions(contentBlock.text)) {
+								return {
+									...contentBlock,
+									text: await parseMentions(
+										contentBlock.text,
+										cwd,
+										urlContentFetcher,
+										fileContextTracker,
+									),
+								}
+							}
+
+							return contentBlock
+						}),
+					)
+
+					return { ...block, content: parsedContent }
+				}
+
+				return block
+			}
+
+			return block
+		}),
+	)
+}