Browse Source

Publish token usage metrics (#7637)

Chris Estreich 4 months ago
parent
commit
c25cfdeaef

+ 4 - 2
packages/cloud/src/bridge/ExtensionChannel.ts

@@ -187,6 +187,7 @@ export class ExtensionChannel extends BaseChannel<
 			{ from: RooCodeEventName.TaskUnpaused, to: ExtensionBridgeEventName.TaskUnpaused },
 			{ from: RooCodeEventName.TaskSpawned, to: ExtensionBridgeEventName.TaskSpawned },
 			{ from: RooCodeEventName.TaskUserMessage, to: ExtensionBridgeEventName.TaskUserMessage },
+			{ from: RooCodeEventName.TaskTokenUsageUpdated, to: ExtensionBridgeEventName.TaskTokenUsageUpdated },
 		] as const
 
 		eventMapping.forEach(({ from, to }) => {
@@ -229,11 +230,12 @@ export class ExtensionChannel extends BaseChannel<
 			task: task
 				? {
 						taskId: task.taskId,
+						parentTaskId: task.parentTaskId,
+						childTaskId: task.childTaskId,
 						taskStatus: task.taskStatus,
 						taskAsk: task?.taskAsk,
 						queuedMessages: task.queuedMessages,
-						parentTaskId: task.parentTaskId,
-						childTaskId: task.childTaskId,
+						tokenUsage: task.tokenUsage,
 						...task.metadata,
 					}
 				: { taskId: "", taskStatus: TaskStatus.None },

+ 2 - 1
packages/cloud/src/bridge/__tests__/ExtensionChannel.test.ts

@@ -120,6 +120,7 @@ describe("ExtensionChannel", () => {
 				RooCodeEventName.TaskUnpaused,
 				RooCodeEventName.TaskSpawned,
 				RooCodeEventName.TaskUserMessage,
+				RooCodeEventName.TaskTokenUsageUpdated,
 			]
 
 			// Check that on() was called for each event
@@ -254,7 +255,7 @@ describe("ExtensionChannel", () => {
 			}
 
 			// Listeners should still be the same count (not accumulated)
-			expect(eventListeners.size).toBe(14)
+			expect(eventListeners.size).toBe(15)
 
 			// Each event should have exactly 1 listener
 			eventListeners.forEach((listeners) => {

+ 1 - 1
packages/types/npm/package.metadata.json

@@ -1,6 +1,6 @@
 {
 	"name": "@roo-code/types",
-	"version": "1.73.0",
+	"version": "1.74.0",
 	"description": "TypeScript type definitions for Roo Code.",
 	"publishConfig": {
 		"access": "public",

+ 10 - 1
packages/types/src/cloud.ts

@@ -7,7 +7,7 @@ import { TaskStatus, taskMetadataSchema } from "./task.js"
 import { globalSettingsSchema } from "./global-settings.js"
 import { providerSettingsWithIdSchema } from "./provider-settings.js"
 import { mcpMarketplaceItemSchema } from "./marketplace.js"
-import { clineMessageSchema, queuedMessageSchema } from "./message.js"
+import { clineMessageSchema, queuedMessageSchema, tokenUsageSchema } from "./message.js"
 import { staticAppPropertiesSchema, gitPropertiesSchema } from "./telemetry.js"
 
 /**
@@ -363,6 +363,7 @@ const extensionTaskSchema = z.object({
 	queuedMessages: z.array(queuedMessageSchema).optional(),
 	parentTaskId: z.string().optional(),
 	childTaskId: z.string().optional(),
+	tokenUsage: tokenUsageSchema.optional(),
 	...taskMetadataSchema.shape,
 })
 
@@ -412,6 +413,8 @@ export enum ExtensionBridgeEventName {
 
 	TaskUserMessage = RooCodeEventName.TaskUserMessage,
 
+	TaskTokenUsageUpdated = RooCodeEventName.TaskTokenUsageUpdated,
+
 	ModeChanged = RooCodeEventName.ModeChanged,
 	ProviderProfileChanged = RooCodeEventName.ProviderProfileChanged,
 
@@ -494,6 +497,12 @@ export const extensionBridgeEventSchema = z.discriminatedUnion("type", [
 		timestamp: z.number(),
 	}),
 
+	z.object({
+		type: z.literal(ExtensionBridgeEventName.TaskTokenUsageUpdated),
+		instance: extensionInstanceSchema,
+		timestamp: z.number(),
+	}),
+
 	z.object({
 		type: z.literal(ExtensionBridgeEventName.ModeChanged),
 		instance: extensionInstanceSchema,

+ 3 - 0
packages/types/src/task.ts

@@ -75,6 +75,8 @@ export type TaskProviderEvents = {
 
 	[RooCodeEventName.TaskUserMessage]: [taskId: string]
 
+	[RooCodeEventName.TaskTokenUsageUpdated]: [taskId: string, tokenUsage: TokenUsage]
+
 	[RooCodeEventName.ModeChanged]: [mode: string]
 	[RooCodeEventName.ProviderProfileChanged]: [config: { name: string; provider?: string }]
 }
@@ -116,6 +118,7 @@ export interface TaskLike {
 	readonly taskStatus: TaskStatus
 	readonly taskAsk: ClineMessage | undefined
 	readonly queuedMessages: QueuedMessage[]
+	readonly tokenUsage: TokenUsage | undefined
 
 	on<K extends keyof TaskEvents>(event: K, listener: (...args: TaskEvents[K]) => void | Promise<void>): this
 	off<K extends keyof TaskEvents>(event: K, listener: (...args: TaskEvents[K]) => void | Promise<void>): this

+ 64 - 60
src/core/checkpoints/index.ts

@@ -17,17 +17,18 @@ import { DIFF_VIEW_URI_SCHEME } from "../../integrations/editor/DiffViewProvider
 import { CheckpointServiceOptions, RepoPerTaskCheckpointService } from "../../services/checkpoints"
 
 export async function getCheckpointService(
-	cline: Task,
+	task: Task,
 	{ interval = 250, timeout = 15_000 }: { interval?: number; timeout?: number } = {},
 ) {
-	if (!cline.enableCheckpoints) {
+	if (!task.enableCheckpoints) {
 		return undefined
 	}
-	if (cline.checkpointService) {
-		return cline.checkpointService
+
+	if (task.checkpointService) {
+		return task.checkpointService
 	}
 
-	const provider = cline.providerRef.deref()
+	const provider = task.providerRef.deref()
 
 	const log = (message: string) => {
 		console.log(message)
@@ -42,11 +43,11 @@ export async function getCheckpointService(
 	console.log("[Task#getCheckpointService] initializing checkpoints service")
 
 	try {
-		const workspaceDir = cline.cwd || getWorkspacePath()
+		const workspaceDir = task.cwd || getWorkspacePath()
 
 		if (!workspaceDir) {
 			log("[Task#getCheckpointService] workspace folder not found, disabling checkpoints")
-			cline.enableCheckpoints = false
+			task.enableCheckpoints = false
 			return undefined
 		}
 
@@ -54,48 +55,51 @@ export async function getCheckpointService(
 
 		if (!globalStorageDir) {
 			log("[Task#getCheckpointService] globalStorageDir not found, disabling checkpoints")
-			cline.enableCheckpoints = false
+			task.enableCheckpoints = false
 			return undefined
 		}
 
 		const options: CheckpointServiceOptions = {
-			taskId: cline.taskId,
+			taskId: task.taskId,
 			workspaceDir,
 			shadowDir: globalStorageDir,
 			log,
 		}
-		if (cline.checkpointServiceInitializing) {
+
+		if (task.checkpointServiceInitializing) {
 			await pWaitFor(
 				() => {
 					console.log("[Task#getCheckpointService] waiting for service to initialize")
-					return !!cline.checkpointService && !!cline?.checkpointService?.isInitialized
+					return !!task.checkpointService && !!task?.checkpointService?.isInitialized
 				},
 				{ interval, timeout },
 			)
-			if (!cline?.checkpointService) {
-				cline.enableCheckpoints = false
+			if (!task?.checkpointService) {
+				task.enableCheckpoints = false
 				return undefined
 			}
-			return cline.checkpointService
+			return task.checkpointService
 		}
-		if (!cline.enableCheckpoints) {
+
+		if (!task.enableCheckpoints) {
 			return undefined
 		}
+
 		const service = RepoPerTaskCheckpointService.create(options)
-		cline.checkpointServiceInitializing = true
-		await checkGitInstallation(cline, service, log, provider)
-		cline.checkpointService = service
+		task.checkpointServiceInitializing = true
+		await checkGitInstallation(task, service, log, provider)
+		task.checkpointService = service
 		return service
 	} catch (err) {
 		log(`[Task#getCheckpointService] ${err.message}`)
-		cline.enableCheckpoints = false
-		cline.checkpointServiceInitializing = false
+		task.enableCheckpoints = false
+		task.checkpointServiceInitializing = false
 		return undefined
 	}
 }
 
 async function checkGitInstallation(
-	cline: Task,
+	task: Task,
 	service: RepoPerTaskCheckpointService,
 	log: (message: string) => void,
 	provider: any,
@@ -105,8 +109,8 @@ async function checkGitInstallation(
 
 		if (!gitInstalled) {
 			log("[Task#getCheckpointService] Git is not installed, disabling checkpoints")
-			cline.enableCheckpoints = false
-			cline.checkpointServiceInitializing = false
+			task.enableCheckpoints = false
+			task.checkpointServiceInitializing = false
 
 			// Show user-friendly notification
 			const selection = await vscode.window.showWarningMessage(
@@ -124,56 +128,55 @@ async function checkGitInstallation(
 		// Git is installed, proceed with initialization
 		service.on("initialize", () => {
 			log("[Task#getCheckpointService] service initialized")
-			cline.checkpointServiceInitializing = false
+			task.checkpointServiceInitializing = false
 		})
 
 		service.on("checkpoint", ({ fromHash: from, toHash: to }) => {
 			try {
 				provider?.postMessageToWebview({ type: "currentCheckpointUpdated", text: to })
 
-				cline
-					.say("checkpoint_saved", to, undefined, undefined, { from, to }, undefined, {
-						isNonInteractive: true,
-					})
-					.catch((err) => {
-						log("[Task#getCheckpointService] caught unexpected error in say('checkpoint_saved')")
-						console.error(err)
-					})
+				task.say("checkpoint_saved", to, undefined, undefined, { from, to }, undefined, {
+					isNonInteractive: true,
+				}).catch((err) => {
+					log("[Task#getCheckpointService] caught unexpected error in say('checkpoint_saved')")
+					console.error(err)
+				})
 			} catch (err) {
 				log("[Task#getCheckpointService] caught unexpected error in on('checkpoint'), disabling checkpoints")
 				console.error(err)
-				cline.enableCheckpoints = false
+				task.enableCheckpoints = false
 			}
 		})
 
 		log("[Task#getCheckpointService] initializing shadow git")
+
 		try {
 			await service.initShadowGit()
 		} catch (err) {
 			log(`[Task#getCheckpointService] initShadowGit -> ${err.message}`)
-			cline.enableCheckpoints = false
+			task.enableCheckpoints = false
 		}
 	} catch (err) {
 		log(`[Task#getCheckpointService] Unexpected error during Git check: ${err.message}`)
 		console.error("Git check error:", err)
-		cline.enableCheckpoints = false
-		cline.checkpointServiceInitializing = false
+		task.enableCheckpoints = false
+		task.checkpointServiceInitializing = false
 	}
 }
 
-export async function checkpointSave(cline: Task, force = false) {
-	const service = await getCheckpointService(cline)
+export async function checkpointSave(task: Task, force = false) {
+	const service = await getCheckpointService(task)
 
 	if (!service) {
 		return
 	}
 
-	TelemetryService.instance.captureCheckpointCreated(cline.taskId)
+	TelemetryService.instance.captureCheckpointCreated(task.taskId)
 
 	// Start the checkpoint process in the background.
-	return service.saveCheckpoint(`Task: ${cline.taskId}, Time: ${Date.now()}`, { allowEmpty: force }).catch((err) => {
+	return service.saveCheckpoint(`Task: ${task.taskId}, Time: ${Date.now()}`, { allowEmpty: force }).catch((err) => {
 		console.error("[Task#checkpointSave] caught unexpected error, disabling checkpoints", err)
-		cline.enableCheckpoints = false
+		task.enableCheckpoints = false
 	})
 }
 
@@ -183,39 +186,39 @@ export type CheckpointRestoreOptions = {
 	mode: "preview" | "restore"
 }
 
-export async function checkpointRestore(cline: Task, { ts, commitHash, mode }: CheckpointRestoreOptions) {
-	const service = await getCheckpointService(cline)
+export async function checkpointRestore(task: Task, { ts, commitHash, mode }: CheckpointRestoreOptions) {
+	const service = await getCheckpointService(task)
 
 	if (!service) {
 		return
 	}
 
-	const index = cline.clineMessages.findIndex((m) => m.ts === ts)
+	const index = task.clineMessages.findIndex((m) => m.ts === ts)
 
 	if (index === -1) {
 		return
 	}
 
-	const provider = cline.providerRef.deref()
+	const provider = task.providerRef.deref()
 
 	try {
 		await service.restoreCheckpoint(commitHash)
-		TelemetryService.instance.captureCheckpointRestored(cline.taskId)
+		TelemetryService.instance.captureCheckpointRestored(task.taskId)
 		await provider?.postMessageToWebview({ type: "currentCheckpointUpdated", text: commitHash })
 
 		if (mode === "restore") {
-			await cline.overwriteApiConversationHistory(cline.apiConversationHistory.filter((m) => !m.ts || m.ts < ts))
+			await task.overwriteApiConversationHistory(task.apiConversationHistory.filter((m) => !m.ts || m.ts < ts))
 
-			const deletedMessages = cline.clineMessages.slice(index + 1)
+			const deletedMessages = task.clineMessages.slice(index + 1)
 
 			const { totalTokensIn, totalTokensOut, totalCacheWrites, totalCacheReads, totalCost } = getApiMetrics(
-				cline.combineMessages(deletedMessages),
+				task.combineMessages(deletedMessages),
 			)
 
-			await cline.overwriteClineMessages(cline.clineMessages.slice(0, index + 1))
+			await task.overwriteClineMessages(task.clineMessages.slice(0, index + 1))
 
 			// TODO: Verify that this is working as expected.
-			await cline.say(
+			await task.say(
 				"api_req_deleted",
 				JSON.stringify({
 					tokensIn: totalTokensIn,
@@ -230,17 +233,17 @@ export async function checkpointRestore(cline: Task, { ts, commitHash, mode }: C
 		// The task is already cancelled by the provider beforehand, but we
 		// need to re-init to get the updated messages.
 		//
-		// This was take from Cline's implementation of the checkpoints
-		// feature. The cline instance will hang if we don't cancel twice,
+		// This was taken from Cline's implementation of the checkpoints
+		// feature. The task instance will hang if we don't cancel twice,
 		// so this is currently necessary, but it seems like a complicated
 		// and hacky solution to a problem that I don't fully understand.
 		// I'd like to revisit this in the future and try to improve the
 		// task flow and the communication between the webview and the
-		// Cline instance.
+		// `Task` instance.
 		provider?.cancelTask()
 	} catch (err) {
 		provider?.log("[checkpointRestore] disabling checkpoints for this task")
-		cline.enableCheckpoints = false
+		task.enableCheckpoints = false
 	}
 }
 
@@ -251,20 +254,21 @@ export type CheckpointDiffOptions = {
 	mode: "full" | "checkpoint"
 }
 
-export async function checkpointDiff(cline: Task, { ts, previousCommitHash, commitHash, mode }: CheckpointDiffOptions) {
-	const service = await getCheckpointService(cline)
+export async function checkpointDiff(task: Task, { ts, previousCommitHash, commitHash, mode }: CheckpointDiffOptions) {
+	const service = await getCheckpointService(task)
 
 	if (!service) {
 		return
 	}
 
-	TelemetryService.instance.captureCheckpointDiffed(cline.taskId)
+	TelemetryService.instance.captureCheckpointDiffed(task.taskId)
 
 	let prevHash = commitHash
 	let nextHash: string | undefined
 
 	const checkpoints = typeof service.getCheckpoints === "function" ? service.getCheckpoints() : []
 	const idx = checkpoints.indexOf(commitHash)
+
 	if (idx !== -1 && idx < checkpoints.length - 1) {
 		nextHash = checkpoints[idx + 1]
 	} else {
@@ -293,8 +297,8 @@ export async function checkpointDiff(cline: Task, { ts, previousCommitHash, comm
 			]),
 		)
 	} catch (err) {
-		const provider = cline.providerRef.deref()
+		const provider = task.providerRef.deref()
 		provider?.log("[checkpointDiff] disabling checkpoints for this task")
-		cline.enableCheckpoints = false
+		task.enableCheckpoints = false
 	}
 }

+ 24 - 2
src/core/task/Task.ts

@@ -50,7 +50,7 @@ import { combineApiRequests } from "../../shared/combineApiRequests"
 import { combineCommandSequences } from "../../shared/combineCommandSequences"
 import { t } from "../../i18n"
 import { ClineApiReqCancelReason, ClineApiReqInfo } from "../../shared/ExtensionMessage"
-import { getApiMetrics } from "../../shared/getApiMetrics"
+import { getApiMetrics, hasTokenUsageChanged } from "../../shared/getApiMetrics"
 import { ClineAskResponse } from "../../shared/WebviewMessage"
 import { defaultModeSlug } from "../../shared/modes"
 import { DiffStrategy } from "../../shared/tools"
@@ -292,6 +292,10 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 	private lastUsedInstructions?: string
 	private skipPrevResponseIdOnce: boolean = false
 
+	// Token Usage Cache
+	private tokenUsageSnapshot?: TokenUsage
+	private tokenUsageSnapshotAt?: number
+
 	constructor({
 		provider,
 		apiConfiguration,
@@ -669,7 +673,11 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 				mode: this._taskMode || defaultModeSlug, // Use the task's own mode, not the current provider mode.
 			})
 
-			this.emit(RooCodeEventName.TaskTokenUsageUpdated, this.taskId, tokenUsage)
+			if (hasTokenUsageChanged(tokenUsage, this.tokenUsageSnapshot)) {
+				this.emit(RooCodeEventName.TaskTokenUsageUpdated, this.taskId, tokenUsage)
+				this.tokenUsageSnapshot = undefined
+				this.tokenUsageSnapshotAt = undefined
+			}
 
 			await this.providerRef.deref()?.updateTaskHistory(historyItem)
 		} catch (error) {
@@ -960,6 +968,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 		}
 
 		const { contextTokens: prevContextTokens } = this.getTokenUsage()
+
 		const {
 			messages,
 			summary,
@@ -2378,11 +2387,13 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 
 		const { contextTokens } = this.getTokenUsage()
 		const modelInfo = this.api.getModel().info
+
 		const maxTokens = getModelMaxOutputTokens({
 			modelId: this.api.getModel().id,
 			model: modelInfo,
 			settings: this.apiConfiguration,
 		})
+
 		const contextWindow = modelInfo.contextWindow
 
 		// Get the current profile ID using the helper method
@@ -2828,6 +2839,17 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 		return this.messageQueueService.messages
 	}
 
+	public get tokenUsage(): TokenUsage | undefined {
+		if (this.tokenUsageSnapshot && this.tokenUsageSnapshotAt) {
+			return this.tokenUsageSnapshot
+		}
+
+		this.tokenUsageSnapshot = this.getTokenUsage()
+		this.tokenUsageSnapshotAt = this.clineMessages.at(-1)?.ts
+
+		return this.tokenUsageSnapshot
+	}
+
 	public get cwd() {
 		return this.workspacePath
 	}

+ 7 - 2
src/core/webview/ClineProvider.ts

@@ -31,6 +31,7 @@ import {
 	type HistoryItem,
 	type CloudUserInfo,
 	type CreateTaskOptions,
+	type TokenUsage,
 	RooCodeEventName,
 	requestyDefaultModelId,
 	openRouterDefaultModelId,
@@ -184,10 +185,12 @@ export class ClineProvider
 			const onTaskInteractive = (taskId: string) => this.emit(RooCodeEventName.TaskInteractive, taskId)
 			const onTaskResumable = (taskId: string) => this.emit(RooCodeEventName.TaskResumable, taskId)
 			const onTaskIdle = (taskId: string) => this.emit(RooCodeEventName.TaskIdle, taskId)
-			const onTaskUserMessage = (taskId: string) => this.emit(RooCodeEventName.TaskUserMessage, taskId)
 			const onTaskPaused = (taskId: string) => this.emit(RooCodeEventName.TaskPaused, taskId)
 			const onTaskUnpaused = (taskId: string) => this.emit(RooCodeEventName.TaskUnpaused, taskId)
 			const onTaskSpawned = (taskId: string) => this.emit(RooCodeEventName.TaskSpawned, taskId)
+			const onTaskUserMessage = (taskId: string) => this.emit(RooCodeEventName.TaskUserMessage, taskId)
+			const onTaskTokenUsageUpdated = (taskId: string, tokenUsage: TokenUsage) =>
+				this.emit(RooCodeEventName.TaskTokenUsageUpdated, taskId, tokenUsage)
 
 			// Attach the listeners.
 			instance.on(RooCodeEventName.TaskStarted, onTaskStarted)
@@ -199,10 +202,11 @@ export class ClineProvider
 			instance.on(RooCodeEventName.TaskInteractive, onTaskInteractive)
 			instance.on(RooCodeEventName.TaskResumable, onTaskResumable)
 			instance.on(RooCodeEventName.TaskIdle, onTaskIdle)
-			instance.on(RooCodeEventName.TaskUserMessage, onTaskUserMessage)
 			instance.on(RooCodeEventName.TaskPaused, onTaskPaused)
 			instance.on(RooCodeEventName.TaskUnpaused, onTaskUnpaused)
 			instance.on(RooCodeEventName.TaskSpawned, onTaskSpawned)
+			instance.on(RooCodeEventName.TaskUserMessage, onTaskUserMessage)
+			instance.on(RooCodeEventName.TaskTokenUsageUpdated, onTaskTokenUsageUpdated)
 
 			// Store the cleanup functions for later removal.
 			this.taskEventListeners.set(instance, [
@@ -219,6 +223,7 @@ export class ClineProvider
 				() => instance.off(RooCodeEventName.TaskPaused, onTaskPaused),
 				() => instance.off(RooCodeEventName.TaskUnpaused, onTaskUnpaused),
 				() => instance.off(RooCodeEventName.TaskSpawned, onTaskSpawned),
+				() => instance.off(RooCodeEventName.TaskTokenUsageUpdated, onTaskTokenUsageUpdated),
 			])
 		}
 

+ 34 - 4
src/shared/getApiMetrics.ts

@@ -36,7 +36,7 @@ export function getApiMetrics(messages: ClineMessage[]) {
 		contextTokens: 0,
 	}
 
-	// Calculate running totals
+	// Calculate running totals.
 	messages.forEach((message) => {
 		if (message.type === "say" && message.say === "api_req_started" && message.text) {
 			try {
@@ -46,15 +46,19 @@ export function getApiMetrics(messages: ClineMessage[]) {
 				if (typeof tokensIn === "number") {
 					result.totalTokensIn += tokensIn
 				}
+
 				if (typeof tokensOut === "number") {
 					result.totalTokensOut += tokensOut
 				}
+
 				if (typeof cacheWrites === "number") {
 					result.totalCacheWrites = (result.totalCacheWrites ?? 0) + cacheWrites
 				}
+
 				if (typeof cacheReads === "number") {
 					result.totalCacheReads = (result.totalCacheReads ?? 0) + cacheReads
 				}
+
 				if (typeof cost === "number") {
 					result.totalCost += cost
 				}
@@ -66,20 +70,23 @@ export function getApiMetrics(messages: ClineMessage[]) {
 		}
 	})
 
-	// Calculate context tokens, from the last API request started or condense context message
+	// Calculate context tokens, from the last API request started or condense
+	// context message.
 	result.contextTokens = 0
+
 	for (let i = messages.length - 1; i >= 0; i--) {
 		const message = messages[i]
+
 		if (message.type === "say" && message.say === "api_req_started" && message.text) {
 			try {
 				const parsedText: ParsedApiReqStartedTextType = JSON.parse(message.text)
 				const { tokensIn, tokensOut, cacheWrites, cacheReads, apiProtocol } = parsedText
 
-				// Calculate context tokens based on API protocol
+				// Calculate context tokens based on API protocol.
 				if (apiProtocol === "anthropic") {
 					result.contextTokens = (tokensIn || 0) + (tokensOut || 0) + (cacheWrites || 0) + (cacheReads || 0)
 				} else {
-					// For OpenAI (or when protocol is not specified)
+					// For OpenAI (or when protocol is not specified).
 					result.contextTokens = (tokensIn || 0) + (tokensOut || 0)
 				}
 			} catch (error) {
@@ -96,3 +103,26 @@ export function getApiMetrics(messages: ClineMessage[]) {
 
 	return result
 }
+
+/**
+ * Check if token usage has changed by comparing relevant properties.
+ * @param current - Current token usage data
+ * @param snapshot - Previous snapshot to compare against
+ * @returns true if any relevant property has changed or snapshot is undefined
+ */
+export function hasTokenUsageChanged(current: TokenUsage, snapshot?: TokenUsage): boolean {
+	if (!snapshot) {
+		return true
+	}
+
+	const keysToCompare: (keyof TokenUsage)[] = [
+		"totalTokensIn",
+		"totalTokensOut",
+		"totalCacheWrites",
+		"totalCacheReads",
+		"totalCost",
+		"contextTokens",
+	]
+
+	return keysToCompare.some((key) => current[key] !== snapshot[key])
+}