Просмотр исходного кода

fix: prevent cascading truncation loop by only truncating visible messages (#9844)

Hannes Rudolph 4 недель назад
Родитель
Коммит
a1d392f5cc

+ 27 - 21
src/core/context-management/__tests__/context-management.spec.ts

@@ -65,14 +65,12 @@ describe("Context Management", () => {
 
 			// With 2 messages after the first, 0.5 fraction means remove 1 message
 			// But 1 is odd, so it rounds down to 0 (to make it even)
-			// Result should have messages + truncation marker
-			expect(result.messages.length).toBe(4) // First message + truncation marker + 2 remaining messages
+			// No truncation happens, so no marker is inserted
+			expect(result.messages.length).toBe(3) // Original messages unchanged
+			expect(result.messagesRemoved).toBe(0)
 			expect(result.messages[0]).toEqual(messages[0])
-			// messages[1] is the truncation marker
-			expect(result.messages[1].isTruncationMarker).toBe(true)
-			// Original messages[1] and messages[2] are at indices 2 and 3 now
-			expect(result.messages[2].content).toEqual(messages[1].content)
-			expect(result.messages[3].content).toEqual(messages[2].content)
+			expect(result.messages[1]).toEqual(messages[1])
+			expect(result.messages[2]).toEqual(messages[2])
 		})
 
 		it("should remove the specified fraction of messages (rounded to even number)", () => {
@@ -92,11 +90,16 @@ describe("Context Management", () => {
 			expect(result.messages.length).toBe(6) // 5 original + 1 marker
 			expect(result.messagesRemoved).toBe(2)
 			expect(result.messages[0]).toEqual(messages[0])
-			expect(result.messages[1].isTruncationMarker).toBe(true)
-			// Messages 2 and 3 (indices 1 and 2 from original) should be tagged
+
+			// Messages at indices 1 and 2 from original should be tagged
+			expect(result.messages[1].truncationParent).toBe(result.truncationId)
 			expect(result.messages[2].truncationParent).toBe(result.truncationId)
-			expect(result.messages[3].truncationParent).toBe(result.truncationId)
-			// Messages 4 and 5 (indices 3 and 4 from original) should NOT be tagged
+
+			// Marker should be at index 3 (at the boundary, after truncated messages)
+			expect(result.messages[3].isTruncationMarker).toBe(true)
+			expect(result.messages[3].role).toBe("user")
+
+			// Messages at indices 3 and 4 from original should NOT be tagged (now at indices 4 and 5)
 			expect(result.messages[4].truncationParent).toBeUndefined()
 			expect(result.messages[5].truncationParent).toBeUndefined()
 		})
@@ -117,9 +120,8 @@ describe("Context Management", () => {
 			const result = truncateConversation(messages, 0.3, taskId)
 
 			expect(result.messagesRemoved).toBe(0) // No messages removed
-			// Should still have truncation marker inserted
-			expect(result.messages.length).toBe(8) // 7 original + 1 marker
-			expect(result.messages[1].isTruncationMarker).toBe(true)
+			// When nothing is truncated, no marker is inserted
+			expect(result.messages.length).toBe(7) // Original messages unchanged
 		})
 
 		it("should handle edge case with fracToRemove = 0", () => {
@@ -132,9 +134,8 @@ describe("Context Management", () => {
 			const result = truncateConversation(messages, 0, taskId)
 
 			expect(result.messagesRemoved).toBe(0)
-			// Should have original messages + truncation marker
-			expect(result.messages.length).toBe(4)
-			expect(result.messages[1].isTruncationMarker).toBe(true)
+			// When nothing is truncated, no marker is inserted
+			expect(result.messages.length).toBe(3) // Original messages unchanged
 		})
 
 		it("should handle edge case with fracToRemove = 1", () => {
@@ -153,11 +154,16 @@ describe("Context Management", () => {
 			// Should have all original messages + truncation marker
 			expect(result.messages.length).toBe(5) // 4 original + 1 marker
 			expect(result.messages[0]).toEqual(messages[0])
-			expect(result.messages[1].isTruncationMarker).toBe(true)
-			// Messages at indices 2 and 3 should be tagged (original indices 1 and 2)
+
+			// Messages at indices 1 and 2 should be tagged
+			expect(result.messages[1].truncationParent).toBe(result.truncationId)
 			expect(result.messages[2].truncationParent).toBe(result.truncationId)
-			expect(result.messages[3].truncationParent).toBe(result.truncationId)
-			// Last message should NOT be tagged
+
+			// Marker should be at index 3 (at the boundary)
+			expect(result.messages[3].isTruncationMarker).toBe(true)
+			expect(result.messages[3].role).toBe("user")
+
+			// Last message should NOT be tagged (now at index 4)
 			expect(result.messages[4].truncationParent).toBeUndefined()
 		})
 	})

+ 60 - 25
src/core/context-management/__tests__/truncation.spec.ts

@@ -33,39 +33,41 @@ describe("Non-Destructive Sliding Window Truncation", () => {
 		it("should tag messages with truncationParent instead of deleting", () => {
 			const result = truncateConversation(messages, 0.5, "test-task-id")
 
-			// All messages should still be present
+			// All messages should still be present plus the truncation marker
 			expect(result.messages.length).toBe(messages.length + 1) // +1 for truncation marker
 
 			// Calculate expected messages to remove: floor((11-1) * 0.5) = 5, rounded to even = 4
 			const expectedMessagesToRemove = 4
 
-			// Messages 1-4 should be tagged with truncationParent
-			for (let i = 1; i <= expectedMessagesToRemove; i++) {
-				// Account for truncation marker inserted at position 1
-				const msgIndex = i < 1 ? i : i + 1
-				expect(result.messages[msgIndex].truncationParent).toBeDefined()
-				expect(result.messages[msgIndex].truncationParent).toBe(result.truncationId)
+			// Find which messages have truncationParent set
+			const taggedMessages = result.messages.filter((msg) => msg.truncationParent)
+			expect(taggedMessages.length).toBe(expectedMessagesToRemove)
+
+			// All tagged messages should point to the truncationId
+			for (const msg of taggedMessages) {
+				expect(msg.truncationParent).toBe(result.truncationId)
 			}
 
 			// First message should not be tagged
 			expect(result.messages[0].truncationParent).toBeUndefined()
 
-			// Remaining messages should not be tagged
-			for (let i = expectedMessagesToRemove + 2; i < result.messages.length; i++) {
-				expect(result.messages[i].truncationParent).toBeUndefined()
-			}
+			// Marker should not have truncationParent
+			const marker = result.messages.find((msg) => msg.isTruncationMarker)
+			expect(marker?.truncationParent).toBeUndefined()
 		})
 
 		it("should insert truncation marker with truncationId", () => {
 			const result = truncateConversation(messages, 0.5, "test-task-id")
 
-			// Truncation marker should be at index 1 (after first message)
-			const marker = result.messages[1]
-			expect(marker.isTruncationMarker).toBe(true)
-			expect(marker.truncationId).toBeDefined()
-			expect(marker.truncationId).toBe(result.truncationId)
-			expect(marker.role).toBe("assistant")
-			expect(marker.content).toContain("Sliding window truncation")
+			// Truncation marker should be at the boundary (after truncated messages)
+			// With 4 messages truncated (indices 1-4), marker should be at index 5
+			const marker = result.messages.find((msg) => msg.isTruncationMarker)
+			expect(marker).toBeDefined()
+			expect(marker!.isTruncationMarker).toBe(true)
+			expect(marker!.truncationId).toBeDefined()
+			expect(marker!.truncationId).toBe(result.truncationId)
+			expect(marker!.role).toBe("user")
+			expect(marker!.content).toContain("Sliding window truncation")
 		})
 
 		it("should return truncationId and messagesRemoved", () => {
@@ -367,10 +369,10 @@ describe("Non-Destructive Sliding Window Truncation", () => {
 			// No messages should be tagged (messagesToRemove = 0)
 			const taggedMessages = result.messages.filter((msg) => msg.truncationParent)
 			expect(taggedMessages.length).toBe(0)
+			expect(result.messagesRemoved).toBe(0)
 
-			// Should still have truncation marker
-			const marker = result.messages.find((msg) => msg.isTruncationMarker)
-			expect(marker).toBeDefined()
+			// When nothing is truncated, no marker is inserted
+			expect(result.messages).toEqual(messages)
 		})
 
 		it("should handle truncateConversation with very few messages", () => {
@@ -381,10 +383,43 @@ describe("Non-Destructive Sliding Window Truncation", () => {
 
 			const result = truncateConversation(fewMessages, 0.5, "test-task-id")
 
-			// Should not crash and should still create marker
-			expect(result.messages.length).toBeGreaterThan(0)
-			const marker = result.messages.find((msg) => msg.isTruncationMarker)
-			expect(marker).toBeDefined()
+			// With only 1 message after first, 0.5 fraction = 0.5, floored to 0, rounded to even = 0
+			// So no messages should be removed and no marker inserted
+			expect(result.messages.length).toBe(2)
+			expect(result.messagesRemoved).toBe(0)
+		})
+
+		it("should handle truncating all visible messages except first", () => {
+			// This tests the edge case where visibleIndices[messagesToRemove + 1] would be undefined
+			// 3 messages total: first is preserved, 2 others can be truncated
+			const threeMessages: ApiMessage[] = [
+				{ role: "user", content: "Initial", ts: 1000 },
+				{ role: "assistant", content: "Response 1", ts: 1100 },
+				{ role: "user", content: "Message 2", ts: 1200 },
+			]
+
+			// With fracToRemove = 1.0:
+			// visibleCount = 3
+			// rawMessagesToRemove = floor((3-1) * 1.0) = 2
+			// messagesToRemove = 2 (already even)
+			// This truncates ALL messages except the first
+			const result = truncateConversation(threeMessages, 1.0, "test-task-id")
+
+			expect(result.messagesRemoved).toBe(2)
+			// Should have 3 original messages + 1 marker = 4
+			expect(result.messages.length).toBe(4)
+
+			// First message should be untouched
+			expect(result.messages[0].truncationParent).toBeUndefined()
+			expect(result.messages[0].content).toBe("Initial")
+
+			// Messages at indices 1 and 2 should be tagged
+			expect(result.messages[1].truncationParent).toBe(result.truncationId)
+			expect(result.messages[2].truncationParent).toBe(result.truncationId)
+
+			// Marker should be at the end (index 3)
+			expect(result.messages[3].isTruncationMarker).toBe(true)
+			expect(result.messages[3].role).toBe("user")
 		})
 
 		it("should handle empty condenseParent and truncationParent gracefully", () => {

+ 42 - 7
src/core/context-management/index.ts

@@ -67,29 +67,64 @@ export function truncateConversation(messages: ApiMessage[], fracToRemove: numbe
 	TelemetryService.instance.captureSlidingWindowTruncation(taskId)
 
 	const truncationId = crypto.randomUUID()
-	const rawMessagesToRemove = Math.floor((messages.length - 1) * fracToRemove)
+
+	// Filter to only visible messages (those not already truncated)
+	// We need to track original indices to correctly tag messages in the full array
+	const visibleIndices: number[] = []
+	messages.forEach((msg, index) => {
+		if (!msg.truncationParent && !msg.isTruncationMarker) {
+			visibleIndices.push(index)
+		}
+	})
+
+	// Calculate how many visible messages to truncate (excluding first visible message)
+	const visibleCount = visibleIndices.length
+	const rawMessagesToRemove = Math.floor((visibleCount - 1) * fracToRemove)
 	const messagesToRemove = rawMessagesToRemove - (rawMessagesToRemove % 2)
 
+	if (messagesToRemove <= 0) {
+		// Nothing to truncate
+		return {
+			messages,
+			truncationId,
+			messagesRemoved: 0,
+		}
+	}
+
+	// Get the indices of visible messages to truncate (skip first visible, take next N)
+	const indicesToTruncate = new Set(visibleIndices.slice(1, messagesToRemove + 1))
+
 	// Tag messages that are being "truncated" (hidden from API calls)
 	const taggedMessages = messages.map((msg, index) => {
-		if (index > 0 && index <= messagesToRemove) {
+		if (indicesToTruncate.has(index)) {
 			return { ...msg, truncationParent: truncationId }
 		}
 		return msg
 	})
 
-	// Insert truncation marker after first message (so we know a truncation happened)
-	const firstKeptTs = messages[messagesToRemove + 1]?.ts ?? Date.now()
+	// Find the actual boundary - the index right after the last truncated message
+	const lastTruncatedVisibleIndex = visibleIndices[messagesToRemove] // Last visible message being truncated
+	// If all visible messages except the first are truncated, insert marker at the end
+	const firstKeptVisibleIndex = visibleIndices[messagesToRemove + 1] ?? taggedMessages.length
+
+	// Insert truncation marker at the actual boundary (between last truncated and first kept)
+	const firstKeptTs = messages[firstKeptVisibleIndex]?.ts ?? Date.now()
 	const truncationMarker: ApiMessage = {
-		role: "assistant",
+		role: "user",
 		content: `[Sliding window truncation: ${messagesToRemove} messages hidden to reduce context]`,
 		ts: firstKeptTs - 1,
 		isTruncationMarker: true,
 		truncationId,
 	}
 
-	// Insert marker after first message
-	const result = [taggedMessages[0], truncationMarker, ...taggedMessages.slice(1)]
+	// Insert marker at the boundary position
+	// Find where to insert: right before the first kept visible message
+	const insertPosition = firstKeptVisibleIndex
+	const result = [
+		...taggedMessages.slice(0, insertPosition),
+		truncationMarker,
+		...taggedMessages.slice(insertPosition),
+	]
 
 	return {
 		messages: result,