Saoud Rizwan 1 год назад
Родитель
Сommit
9b1b9c10a1

+ 83 - 52
src/core/ClaudeDev.ts

@@ -21,7 +21,7 @@ import { ApiConfiguration } from "../shared/api"
 import { findLastIndex } from "../shared/array"
 import { combineApiRequests } from "../shared/combineApiRequests"
 import { combineCommandSequences } from "../shared/combineCommandSequences"
-import { ClaudeAsk, ClaudeMessage, ClaudeSay, ClaudeSayTool } from "../shared/ExtensionMessage"
+import { ClaudeApiReqInfo, ClaudeAsk, ClaudeMessage, ClaudeSay, ClaudeSayTool } from "../shared/ExtensionMessage"
 import { getApiMetrics } from "../shared/getApiMetrics"
 import { HistoryItem } from "../shared/HistoryItem"
 import { ToolName } from "../shared/Tool"
@@ -69,6 +69,7 @@ export class ClaudeDev {
 	private consecutiveMistakeCount: number = 0
 	private providerRef: WeakRef<ClaudeDevProvider>
 	private abort: boolean = false
+	didFinishAborting = false
 	private diffViewProvider: DiffViewProvider
 
 	// streaming
@@ -381,19 +382,6 @@ export class ClaudeDev {
 	private async resumeTaskFromHistory() {
 		const modifiedClaudeMessages = await this.getSavedClaudeMessages()
 
-		// Need to modify claude messages for good ux, i.e. if the last message is an api_request_started, then remove it otherwise the user will think the request is still loading
-		const lastApiReqStartedIndex = modifiedClaudeMessages.reduce(
-			(lastIndex, m, index) => (m.type === "say" && m.say === "api_req_started" ? index : lastIndex),
-			-1
-		)
-		const lastApiReqFinishedIndex = modifiedClaudeMessages.reduce(
-			(lastIndex, m, index) => (m.type === "say" && m.say === "api_req_finished" ? index : lastIndex),
-			-1
-		)
-		if (lastApiReqStartedIndex > lastApiReqFinishedIndex && lastApiReqStartedIndex !== -1) {
-			modifiedClaudeMessages.splice(lastApiReqStartedIndex, 1)
-		}
-
 		// Remove any resume messages that may have been added before
 		const lastRelevantMessageIndex = findLastIndex(
 			modifiedClaudeMessages,
@@ -403,6 +391,23 @@ export class ClaudeDev {
 			modifiedClaudeMessages.splice(lastRelevantMessageIndex + 1)
 		}
 
+		// if the last message is an api_req_started it means there was no partial content streamed, so we remove it
+		if (modifiedClaudeMessages.at(-1)?.say === "api_req_started") {
+			modifiedClaudeMessages.pop()
+		}
+		// since we don't use api_req_finished anymore, we need to check if the last api_req_started has a cost value, if it doesn't and it's not cancelled, then we remove it since it indicates an api request without any partial content streamed
+		// const lastApiReqStartedIndex = findLastIndex(
+		// 	modifiedClaudeMessages,
+		// 	(m) => m.type === "say" && m.say === "api_req_started"
+		// )
+		// if (lastApiReqStartedIndex !== -1) {
+		// 	const lastApiReqStarted = modifiedClaudeMessages[lastApiReqStartedIndex]
+		// 	const { cost, cancelled }: ClaudeApiReqInfo = JSON.parse(lastApiReqStarted.text || "{}")
+		// 	if (cost === undefined || cancelled) {
+		// 		modifiedClaudeMessages.splice(lastApiReqStartedIndex, 1)
+		// 	}
+		// }
+
 		await this.overwriteClaudeMessages(modifiedClaudeMessages)
 		this.claudeMessages = await this.getSavedClaudeMessages()
 
@@ -698,13 +703,9 @@ export class ClaudeDev {
 			if (previousApiReqIndex >= 0) {
 				const previousRequest = this.claudeMessages[previousApiReqIndex]
 				if (previousRequest && previousRequest.text) {
-					const {
-						tokensIn,
-						tokensOut,
-						cacheWrites,
-						cacheReads,
-					}: { tokensIn?: number; tokensOut?: number; cacheWrites?: number; cacheReads?: number } =
-						JSON.parse(previousRequest.text)
+					const { tokensIn, tokensOut, cacheWrites, cacheReads }: ClaudeApiReqInfo = JSON.parse(
+						previousRequest.text
+					)
 					const totalTokens = (tokensIn || 0) + (tokensOut || 0) + (cacheWrites || 0) + (cacheReads || 0)
 					const contextWindow = this.api.getModel().info.contextWindow
 					const maxAllowedSize = Math.max(contextWindow - 40_000, contextWindow * 0.8)
@@ -1584,7 +1585,7 @@ export class ClaudeDev {
 			request: userContent
 				.map((block) => formatContentBlockToMarkdown(block, this.apiConversationHistory))
 				.join("\n\n"),
-		})
+		} satisfies ClaudeApiReqInfo)
 		await this.saveClaudeMessages()
 		await this.providerRef.deref()?.postStateToWebview()
 
@@ -1596,6 +1597,29 @@ export class ClaudeDev {
 			let outputTokens = 0
 			let totalCost: number | undefined
 
+			// update api_req_started. we can't use api_req_finished anymore since it's a unique case where it could come after a streaming message (ie in the middle of being updated or executed)
+			// fortunately api_req_finished was always parsed out for the gui anyways, so it remains solely for legacy purposes to keep track of prices in tasks from history
+			// (it's worth removing a few months from now)
+			const updateApiReqMsg = (cancelled?: boolean) => {
+				this.claudeMessages[lastApiReqIndex].text = JSON.stringify({
+					...JSON.parse(this.claudeMessages[lastApiReqIndex].text || "{}"),
+					tokensIn: inputTokens,
+					tokensOut: outputTokens,
+					cacheWrites: cacheWriteTokens,
+					cacheReads: cacheReadTokens,
+					cost:
+						totalCost ??
+						calculateApiCost(
+							this.api.getModel().info,
+							inputTokens,
+							outputTokens,
+							cacheWriteTokens,
+							cacheReadTokens
+						),
+					cancelled,
+				} satisfies ClaudeApiReqInfo)
+			}
+
 			// reset streaming state
 			this.currentStreamingContentIndex = 0
 			this.assistantMessageContent = []
@@ -1624,6 +1648,42 @@ export class ClaudeDev {
 						this.presentAssistantMessage()
 						break
 				}
+
+				if (this.abort) {
+					console.log("aborting stream...")
+					if (this.diffViewProvider.isEditing) {
+						await this.diffViewProvider.revertChanges() // closes diff view
+					}
+
+					// if last message is a partial we need to save it
+					const lastMessage = this.claudeMessages.at(-1)
+					if (lastMessage && lastMessage.partial) {
+						lastMessage.ts = Date.now()
+						lastMessage.partial = false
+						// instead of streaming partialMessage events, we do a save and post like normal to persist to disk
+						console.log("saving messages...", lastMessage)
+						// await this.saveClaudeMessages()
+					}
+
+					//
+					await this.addToApiConversationHistory({
+						role: "assistant",
+						content: [{ type: "text", text: assistantMessage + "\n\n[Response interrupted by user]" }],
+					})
+
+					// update api_req_started to have cancelled and cost, so that we can display the cost of the partial stream
+					updateApiReqMsg(true)
+					await this.saveClaudeMessages()
+
+					// signals to provider that it can retrieve the saved messages from disk, as abortTask can not be awaited on in nature
+					this.didFinishAborting = true
+					break // aborts the stream
+				}
+			}
+
+			// need to call here in case the stream was aborted
+			if (this.abort) {
+				throw new Error("ClaudeDev instance aborted")
 			}
 
 			this.didCompleteReadingStream = true
@@ -1637,36 +1697,7 @@ export class ClaudeDev {
 				this.presentAssistantMessage() // if there is content to update then it will complete and update this.userMessageContentReady to true, which we pwaitfor before making the next request
 			}
 
-			// let inputTokens = response.usage.input_tokens
-			// let outputTokens = response.usage.output_tokens
-			// let cacheCreationInputTokens =
-			// 	(response as Anthropic.Beta.PromptCaching.Messages.PromptCachingBetaMessage).usage
-			// 		.cache_creation_input_tokens || undefined
-			// let cacheReadInputTokens =
-			// 	(response as Anthropic.Beta.PromptCaching.Messages.PromptCachingBetaMessage).usage
-			// 		.cache_read_input_tokens || undefined
-			// @ts-ignore-next-line
-			// let totalCost = response.usage.total_cost
-
-			// update api_req_started. we can't use api_req_finished anymore since it's a unique case where it could come after a streaming message (ie in the middle of being updated or executed)
-			// fortunately api_req_finished was always parsed out for the gui anyways, so it remains solely for legacy purposes to keep track of prices in tasks from history
-			// (it's worth removing a few months from now)
-			this.claudeMessages[lastApiReqIndex].text = JSON.stringify({
-				...JSON.parse(this.claudeMessages[lastApiReqIndex].text),
-				tokensIn: inputTokens,
-				tokensOut: outputTokens,
-				cacheWrites: cacheWriteTokens,
-				cacheReads: cacheReadTokens,
-				cost:
-					totalCost ??
-					calculateApiCost(
-						this.api.getModel().info,
-						inputTokens,
-						outputTokens,
-						cacheWriteTokens,
-						cacheReadTokens
-					),
-			})
+			updateApiReqMsg()
 			await this.saveClaudeMessages()
 			await this.providerRef.deref()?.postStateToWebview()
 

+ 13 - 0
src/core/webview/ClaudeDevProvider.ts

@@ -19,6 +19,7 @@ import WorkspaceTracker from "../../integrations/workspace/WorkspaceTracker"
 import { openMention } from "../mentions"
 import { fileExistsAtPath } from "../../utils/fs"
 import { buildApiHandler } from "../../api"
+import pWaitFor from "p-wait-for"
 
 /*
 https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default/weather-webview/src/providers/WeatherViewProvider.ts
@@ -441,6 +442,18 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 						break
 					case "openMention":
 						openMention(message.text)
+						break
+					case "cancelTask":
+						if (this.claudeDev) {
+							const { historyItem } = await this.getTaskWithId(this.claudeDev.taskId)
+							this.claudeDev.abortTask()
+							await pWaitFor(() => this.claudeDev === undefined || this.claudeDev.didFinishAborting, {
+								timeout: 3_000,
+							})
+							await this.initClaudeDevWithHistoryItem(historyItem) // clears task again, so we need to abortTask manually above
+							await this.postStateToWebview()
+						}
+
 						break
 					// Add more switch case statements here as more webview message commands
 					// are created within the webview context (i.e. inside media/main.js)

+ 10 - 0
src/shared/ExtensionMessage.ts

@@ -87,3 +87,13 @@ export interface ClaudeSayTool {
 	regex?: string
 	filePattern?: string
 }
+
+export interface ClaudeApiReqInfo {
+	request?: string
+	tokensIn?: number
+	tokensOut?: number
+	cacheWrites?: number
+	cacheReads?: number
+	cost?: number
+	cancelled?: boolean
+}

+ 1 - 0
src/shared/WebviewMessage.ts

@@ -20,6 +20,7 @@ export interface WebviewMessage {
 		| "openImage"
 		| "openFile"
 		| "openMention"
+		| "cancelTask"
 	text?: string
 	askResponse?: ClaudeAskResponse
 	apiConfiguration?: ApiConfiguration

+ 21 - 9
webview-ui/src/components/chat/ChatRow.tsx

@@ -2,7 +2,7 @@ import { VSCodeBadge, VSCodeProgressRing } from "@vscode/webview-ui-toolkit/reac
 import deepEqual from "fast-deep-equal"
 import React, { memo, useMemo } from "react"
 import ReactMarkdown from "react-markdown"
-import { ClaudeMessage, ClaudeSayTool } from "../../../../src/shared/ExtensionMessage"
+import { ClaudeApiReqInfo, ClaudeMessage, ClaudeSayTool } from "../../../../src/shared/ExtensionMessage"
 import { COMMAND_OUTPUT_STRING } from "../../../../src/shared/combineCommandSequences"
 import { vscode } from "../../utils/vscode"
 import CodeAccordian, { removeLeadingNonAlphanumeric } from "../common/CodeAccordian"
@@ -37,11 +37,12 @@ const ChatRow = memo(
 export default ChatRow
 
 const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessage, isLast }: ChatRowProps) => {
-	const cost = useMemo(() => {
+	const [cost, apiReqCancelled] = useMemo(() => {
 		if (message.text != null && message.say === "api_req_started") {
-			return JSON.parse(message.text).cost
+			const info: ClaudeApiReqInfo = JSON.parse(message.text)
+			return [info.cost, info.cancelled]
 		}
-		return undefined
+		return [undefined, undefined]
 	}, [message.text, message.say])
 	const apiRequestFailedMessage =
 		isLast && lastModifiedMessage?.ask === "api_req_failed" // if request is retried then the latest message is a api_req_retried
@@ -54,6 +55,7 @@ const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessa
 	const normalColor = "var(--vscode-foreground)"
 	const errorColor = "var(--vscode-errorForeground)"
 	const successColor = "var(--vscode-charts-green)"
+	const cancelledColor = "var(--vscode-descriptionForeground)"
 
 	const [icon, title] = useMemo(() => {
 		switch (type) {
@@ -94,9 +96,15 @@ const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessa
 			case "api_req_started":
 				return [
 					cost != null ? (
-						<span
-							className="codicon codicon-check"
-							style={{ color: successColor, marginBottom: "-1.5px" }}></span>
+						apiReqCancelled ? (
+							<span
+								className="codicon codicon-error"
+								style={{ color: cancelledColor, marginBottom: "-1.5px" }}></span>
+						) : (
+							<span
+								className="codicon codicon-check"
+								style={{ color: successColor, marginBottom: "-1.5px" }}></span>
+						)
 					) : apiRequestFailedMessage ? (
 						<span
 							className="codicon codicon-error"
@@ -105,7 +113,11 @@ const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessa
 						<ProgressIndicator />
 					),
 					cost != null ? (
-						<span style={{ color: normalColor, fontWeight: "bold" }}>API Request</span>
+						apiReqCancelled ? (
+							<span style={{ color: normalColor, fontWeight: "bold" }}>API Request Cancelled</span>
+						) : (
+							<span style={{ color: normalColor, fontWeight: "bold" }}>API Request</span>
+						)
 					) : apiRequestFailedMessage ? (
 						<span style={{ color: errorColor, fontWeight: "bold" }}>API Request Failed</span>
 					) : (
@@ -122,7 +134,7 @@ const ChatRowContent = ({ message, isExpanded, onToggleExpand, lastModifiedMessa
 			default:
 				return [null, null]
 		}
-	}, [type, cost, apiRequestFailedMessage, isCommandExecuting])
+	}, [type, cost, apiRequestFailedMessage, isCommandExecuting, apiReqCancelled])
 
 	const headerStyle: React.CSSProperties = {
 		display: "flex",

+ 39 - 7
webview-ui/src/components/chat/ChatView.tsx

@@ -14,6 +14,7 @@ import ChatRow from "./ChatRow"
 import ChatTextArea from "./ChatTextArea"
 import HistoryPreview from "../history/HistoryPreview"
 import TaskHeader from "./TaskHeader"
+import { findLast } from "../../../../src/shared/array"
 
 interface ChatViewProps {
 	isHidden: boolean
@@ -182,6 +183,24 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
 		}
 	}, [messages.length])
 
+	const isStreaming = useMemo(() => {
+		const isLastMessagePartial = modifiedMessages.at(-1)?.partial === true
+		if (isLastMessagePartial) {
+			return true
+		} else {
+			const lastApiReqStarted = findLast(modifiedMessages, (message) => message.say === "api_req_started")
+			if (lastApiReqStarted && lastApiReqStarted.text != null && lastApiReqStarted.say === "api_req_started") {
+				const cost = JSON.parse(lastApiReqStarted.text).cost
+				if (cost === undefined) {
+					// api request has not finished yet
+					return true
+				}
+			}
+		}
+
+		return false
+	}, [modifiedMessages])
+
 	const handleSendMessage = useCallback(
 		(text: string, images: string[]) => {
 			text = text.trim()
@@ -251,6 +270,11 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
 	}, [claudeAsk, startNewTask])
 
 	const handleSecondaryButtonClick = useCallback(() => {
+		if (isStreaming) {
+			vscode.postMessage({ type: "cancelTask" })
+			return
+		}
+
 		switch (claudeAsk) {
 			case "api_req_failed":
 			case "mistake_limit_reached":
@@ -267,7 +291,7 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
 		setEnableButtons(false)
 		// setPrimaryButtonText(undefined)
 		// setSecondaryButtonText(undefined)
-	}, [claudeAsk, startNewTask])
+	}, [claudeAsk, startNewTask, isStreaming])
 
 	const handleTaskCloseButtonClick = useCallback(() => {
 		startNewTask()
@@ -544,11 +568,16 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
 					/>
 					<div
 						style={{
-							opacity: primaryButtonText || secondaryButtonText ? (enableButtons ? 1 : 0.5) : 0,
+							opacity:
+								primaryButtonText || secondaryButtonText || isStreaming
+									? enableButtons || isStreaming
+										? 1
+										: 0.5
+									: 0,
 							display: "flex",
 							padding: "10px 15px 0px 15px",
 						}}>
-						{primaryButtonText && (
+						{primaryButtonText && !isStreaming && (
 							<VSCodeButton
 								appearance="primary"
 								disabled={!enableButtons}
@@ -560,13 +589,16 @@ const ChatView = ({ isHidden, showAnnouncement, hideAnnouncement, showHistoryVie
 								{primaryButtonText}
 							</VSCodeButton>
 						)}
-						{secondaryButtonText && (
+						{(secondaryButtonText || isStreaming) && (
 							<VSCodeButton
 								appearance="secondary"
-								disabled={!enableButtons}
-								style={{ flex: 1, marginLeft: "6px" }}
+								disabled={!enableButtons && !isStreaming}
+								style={{
+									flex: isStreaming ? 2 : 1,
+									marginLeft: isStreaming ? 0 : "6px",
+								}}
 								onClick={handleSecondaryButtonClick}>
-								{secondaryButtonText}
+								{isStreaming ? "Cancel" : secondaryButtonText}
 							</VSCodeButton>
 						)}
 					</div>