ソースを参照

Merge pull request #718 from MuriloFP/fix-context-calculation

fix: update context token calculation
Matt Rubens 11 ヶ月 前
コミット
4040e934b2
1 ファイル変更20 行追加23 行削除
  1. 20 23
      src/shared/getApiMetrics.ts

+ 20 - 23
src/shared/getApiMetrics.ts

@@ -6,7 +6,7 @@ interface ApiMetrics {
 	totalCacheWrites?: number
 	totalCacheReads?: number
 	totalCost: number
-	contextTokens: number // Total tokens in conversation (last message's tokensIn + tokensOut)
+	contextTokens: number // Total tokens in conversation (last message's tokensIn + tokensOut + cacheWrites + cacheReads)
 }
 
 /**
@@ -17,7 +17,7 @@ interface ApiMetrics {
  * It extracts and sums up the tokensIn, tokensOut, cacheWrites, cacheReads, and cost from these messages.
  *
  * @param messages - An array of ClineMessage objects to process.
- * @returns An ApiMetrics object containing totalTokensIn, totalTokensOut, totalCacheWrites, totalCacheReads, and totalCost.
+ * @returns An ApiMetrics object containing totalTokensIn, totalTokensOut, totalCacheWrites, totalCacheReads, totalCost, and contextTokens.
  *
  * @example
  * const messages = [
@@ -36,27 +36,30 @@ export function getApiMetrics(messages: ClineMessage[]): ApiMetrics {
 		contextTokens: 0,
 	}
 
-	// Find the last api_req_started message that has valid token information
+	// Helper function to get total tokens from a message
+	const getTotalTokensFromMessage = (message: ClineMessage): number => {
+		if (!message.text) return 0
+		try {
+			const { tokensIn, tokensOut, cacheWrites, cacheReads } = JSON.parse(message.text)
+			return (tokensIn || 0) + (tokensOut || 0) + (cacheWrites || 0) + (cacheReads || 0)
+		} catch {
+			return 0
+		}
+	}
+
+	// Find the last api_req_started message that has any tokens
 	const lastApiReq = [...messages].reverse().find((message) => {
-		if (message.type === "say" && message.say === "api_req_started" && message.text) {
-			try {
-				const parsedData = JSON.parse(message.text)
-				return typeof parsedData.tokensIn === "number" && typeof parsedData.tokensOut === "number"
-			} catch {
-				return false
-			}
+		if (message.type === "say" && message.say === "api_req_started") {
+			return getTotalTokensFromMessage(message) > 0
 		}
 		return false
 	})
 
-	// Keep track of the last valid context tokens
-	let lastValidContextTokens = 0
-
+	// Calculate running totals
 	messages.forEach((message) => {
 		if (message.type === "say" && message.say === "api_req_started" && message.text) {
 			try {
-				const parsedData = JSON.parse(message.text)
-				const { tokensIn, tokensOut, cacheWrites, cacheReads, cost } = parsedData
+				const { tokensIn, tokensOut, cacheWrites, cacheReads, cost } = JSON.parse(message.text)
 
 				if (typeof tokensIn === "number") {
 					result.totalTokensIn += tokensIn
@@ -74,15 +77,9 @@ export function getApiMetrics(messages: ClineMessage[]): ApiMetrics {
 					result.totalCost += cost
 				}
 
-				// Update last valid context tokens whenever we have valid input and output tokens
-				if (tokensIn > 0 && tokensOut > 0) {
-					lastValidContextTokens = tokensIn + tokensOut
-				}
-
-				// If this is the last api request, use its tokens for context size
+				// If this is the last api request with tokens, use its total for context size
 				if (message === lastApiReq) {
-					// Use the last valid context tokens if the current request doesn't have valid tokens
-					result.contextTokens = tokensIn > 0 && tokensOut > 0 ? tokensIn + tokensOut : lastValidContextTokens
+					result.contextTokens = getTotalTokensFromMessage(message)
 				}
 			} catch (error) {
 				console.error("Error parsing JSON:", error)