فهرست منبع

fix: Tackle the race/state-condition issue by changing the code design for Gemini grounding sources (#7434)

Co-authored-by: daniel-lxs <[email protected]>
Co-authored-by: Matt Rubens <[email protected]>
Ton Hoang Nguyen (Bill) 5 ماه پیش
والد
کامیت
c206da4a26

+ 15 - 8
src/api/providers/__tests__/gemini-handler.spec.ts

@@ -85,7 +85,7 @@ describe("GeminiHandler backend support", () => {
 							groundingMetadata: {
 							groundingMetadata: {
 								groundingChunks: [
 								groundingChunks: [
 									{ web: null }, // Missing URI
 									{ web: null }, // Missing URI
-									{ web: { uri: "https://example.com" } }, // Valid
+									{ web: { uri: "https://example.com", title: "Example Site" } }, // Valid
 									{}, // Missing web property entirely
 									{}, // Missing web property entirely
 								],
 								],
 							},
 							},
@@ -105,13 +105,20 @@ describe("GeminiHandler backend support", () => {
 				messages.push(chunk)
 				messages.push(chunk)
 			}
 			}
 
 
-			// Should only include valid citations
-			const sourceMessage = messages.find((m) => m.type === "text" && m.text?.includes("[2]"))
-			expect(sourceMessage).toBeDefined()
-			if (sourceMessage && "text" in sourceMessage) {
-				expect(sourceMessage.text).toContain("https://example.com")
-				expect(sourceMessage.text).not.toContain("[1]")
-				expect(sourceMessage.text).not.toContain("[3]")
+			// Should have the text response
+			const textMessage = messages.find((m) => m.type === "text")
+			expect(textMessage).toBeDefined()
+			if (textMessage && "text" in textMessage) {
+				expect(textMessage.text).toBe("test response")
+			}
+
+			// Should have grounding chunk with only valid sources
+			const groundingMessage = messages.find((m) => m.type === "grounding")
+			expect(groundingMessage).toBeDefined()
+			if (groundingMessage && "sources" in groundingMessage) {
+				expect(groundingMessage.sources).toHaveLength(1)
+				expect(groundingMessage.sources[0].url).toBe("https://example.com")
+				expect(groundingMessage.sources[0].title).toBe("Example Site")
 			}
 			}
 		})
 		})
 
 

+ 23 - 13
src/api/providers/gemini.ts

@@ -15,7 +15,7 @@ import { safeJsonParse } from "../../shared/safeJsonParse"
 
 
 import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
 import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
 import { t } from "i18next"
 import { t } from "i18next"
-import type { ApiStream } from "../transform/stream"
+import type { ApiStream, GroundingSource } from "../transform/stream"
 import { getModelParams } from "../transform/model-params"
 import { getModelParams } from "../transform/model-params"
 
 
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -132,9 +132,9 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 			}
 			}
 
 
 			if (pendingGroundingMetadata) {
 			if (pendingGroundingMetadata) {
-				const citations = this.extractCitationsOnly(pendingGroundingMetadata)
-				if (citations) {
-					yield { type: "text", text: `\n\n${t("common:errors.gemini.sources")} ${citations}` }
+				const sources = this.extractGroundingSources(pendingGroundingMetadata)
+				if (sources.length > 0) {
+					yield { type: "grounding", sources }
 				}
 				}
 			}
 			}
 
 
@@ -175,28 +175,38 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 		return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params }
 		return { id: id.endsWith(":thinking") ? id.replace(":thinking", "") : id, info, ...params }
 	}
 	}
 
 
-	private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null {
+	private extractGroundingSources(groundingMetadata?: GroundingMetadata): GroundingSource[] {
 		const chunks = groundingMetadata?.groundingChunks
 		const chunks = groundingMetadata?.groundingChunks
 
 
 		if (!chunks) {
 		if (!chunks) {
-			return null
+			return []
 		}
 		}
 
 
-		const citationLinks = chunks
-			.map((chunk, i) => {
+		return chunks
+			.map((chunk): GroundingSource | null => {
 				const uri = chunk.web?.uri
 				const uri = chunk.web?.uri
+				const title = chunk.web?.title || uri || "Unknown Source"
+
 				if (uri) {
 				if (uri) {
-					return `[${i + 1}](${uri})`
+					return {
+						title,
+						url: uri,
+					}
 				}
 				}
 				return null
 				return null
 			})
 			})
-			.filter((link): link is string => link !== null)
+			.filter((source): source is GroundingSource => source !== null)
+	}
+
+	private extractCitationsOnly(groundingMetadata?: GroundingMetadata): string | null {
+		const sources = this.extractGroundingSources(groundingMetadata)
 
 
-		if (citationLinks.length > 0) {
-			return citationLinks.join(", ")
+		if (sources.length === 0) {
+			return null
 		}
 		}
 
 
-		return null
+		const citationLinks = sources.map((source, i) => `[${i + 1}](${source.url})`)
+		return citationLinks.join(", ")
 	}
 	}
 
 
 	async completePrompt(prompt: string): Promise<string> {
 	async completePrompt(prompt: string): Promise<string> {

+ 17 - 1
src/api/transform/stream.ts

@@ -1,6 +1,11 @@
 export type ApiStream = AsyncGenerator<ApiStreamChunk>
 export type ApiStream = AsyncGenerator<ApiStreamChunk>
 
 
-export type ApiStreamChunk = ApiStreamTextChunk | ApiStreamUsageChunk | ApiStreamReasoningChunk | ApiStreamError
+export type ApiStreamChunk =
+	| ApiStreamTextChunk
+	| ApiStreamUsageChunk
+	| ApiStreamReasoningChunk
+	| ApiStreamGroundingChunk
+	| ApiStreamError
 
 
 export interface ApiStreamError {
 export interface ApiStreamError {
 	type: "error"
 	type: "error"
@@ -27,3 +32,14 @@ export interface ApiStreamUsageChunk {
 	reasoningTokens?: number
 	reasoningTokens?: number
 	totalCost?: number
 	totalCost?: number
 }
 }
+
+export interface ApiStreamGroundingChunk {
+	type: "grounding"
+	sources: GroundingSource[]
+}
+
+export interface GroundingSource {
+	title: string
+	url: string
+	snippet?: string
+}

+ 20 - 2
src/core/task/Task.ts

@@ -41,7 +41,7 @@ import { CloudService, BridgeOrchestrator } from "@roo-code/cloud"
 
 
 // api
 // api
 import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "../../api"
 import { ApiHandler, ApiHandlerCreateMessageMetadata, buildApiHandler } from "../../api"
-import { ApiStream } from "../../api/transform/stream"
+import { ApiStream, GroundingSource } from "../../api/transform/stream"
 import { maybeRemoveImageBlocks } from "../../api/transform/image-cleaning"
 import { maybeRemoveImageBlocks } from "../../api/transform/image-cleaning"
 
 
 // shared
 // shared
@@ -1897,7 +1897,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 					this.didFinishAbortingStream = true
 					this.didFinishAbortingStream = true
 				}
 				}
 
 
-				// Reset streaming state.
+				// Reset streaming state for each new API request
 				this.currentStreamingContentIndex = 0
 				this.currentStreamingContentIndex = 0
 				this.currentStreamingDidCheckpoint = false
 				this.currentStreamingDidCheckpoint = false
 				this.assistantMessageContent = []
 				this.assistantMessageContent = []
@@ -1918,6 +1918,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 				const stream = this.attemptApiRequest()
 				const stream = this.attemptApiRequest()
 				let assistantMessage = ""
 				let assistantMessage = ""
 				let reasoningMessage = ""
 				let reasoningMessage = ""
+				let pendingGroundingSources: GroundingSource[] = []
 				this.isStreaming = true
 				this.isStreaming = true
 
 
 				try {
 				try {
@@ -1944,6 +1945,13 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 								cacheReadTokens += chunk.cacheReadTokens ?? 0
 								cacheReadTokens += chunk.cacheReadTokens ?? 0
 								totalCost = chunk.totalCost
 								totalCost = chunk.totalCost
 								break
 								break
+							case "grounding":
+								// Handle grounding sources separately from regular content
+								// to prevent state persistence issues - store them separately
+								if (chunk.sources && chunk.sources.length > 0) {
+									pendingGroundingSources.push(...chunk.sources)
+								}
+								break
 							case "text": {
 							case "text": {
 								assistantMessage += chunk.text
 								assistantMessage += chunk.text
 
 
@@ -2237,6 +2245,16 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 				let didEndLoop = false
 				let didEndLoop = false
 
 
 				if (assistantMessage.length > 0) {
 				if (assistantMessage.length > 0) {
+					// Display grounding sources to the user if they exist
+					if (pendingGroundingSources.length > 0) {
+						const citationLinks = pendingGroundingSources.map((source, i) => `[${i + 1}](${source.url})`)
+						const sourcesText = `${t("common:gemini.sources")} ${citationLinks.join(", ")}`
+
+						await this.say("text", sourcesText, undefined, false, undefined, undefined, {
+							isNonInteractive: true,
+						})
+					}
+
 					await this.addToApiConversationHistory({
 					await this.addToApiConversationHistory({
 						role: "assistant",
 						role: "assistant",
 						content: [{ type: "text", text: assistantMessage }],
 						content: [{ type: "text", text: assistantMessage }],

+ 226 - 0
src/core/task/__tests__/grounding-sources.test.ts

@@ -0,0 +1,226 @@
+import { describe, it, expect, vi, beforeEach, beforeAll } from "vitest"
+import type { ClineProvider } from "../../webview/ClineProvider"
+import type { ProviderSettings } from "@roo-code/types"
+
+// Mock vscode module before importing Task
+vi.mock("vscode", () => ({
+	workspace: {
+		createFileSystemWatcher: vi.fn(() => ({
+			onDidCreate: vi.fn(),
+			onDidChange: vi.fn(),
+			onDidDelete: vi.fn(),
+			dispose: vi.fn(),
+		})),
+		getConfiguration: vi.fn(() => ({
+			get: vi.fn(() => true),
+		})),
+		openTextDocument: vi.fn(),
+		applyEdit: vi.fn(),
+	},
+	RelativePattern: vi.fn((base, pattern) => ({ base, pattern })),
+	window: {
+		createOutputChannel: vi.fn(() => ({
+			appendLine: vi.fn(),
+			dispose: vi.fn(),
+		})),
+		createTextEditorDecorationType: vi.fn(() => ({
+			dispose: vi.fn(),
+		})),
+		showTextDocument: vi.fn(),
+		activeTextEditor: undefined,
+	},
+	Uri: {
+		file: vi.fn((path) => ({ fsPath: path })),
+		parse: vi.fn((str) => ({ toString: () => str })),
+	},
+	Range: vi.fn(),
+	Position: vi.fn(),
+	WorkspaceEdit: vi.fn(() => ({
+		replace: vi.fn(),
+		insert: vi.fn(),
+		delete: vi.fn(),
+	})),
+	ViewColumn: {
+		One: 1,
+		Two: 2,
+		Three: 3,
+	},
+}))
+
+// Mock other dependencies
+vi.mock("../../services/mcp/McpServerManager", () => ({
+	McpServerManager: {
+		getInstance: vi.fn().mockResolvedValue(null),
+	},
+}))
+
+vi.mock("../../integrations/terminal/TerminalRegistry", () => ({
+	TerminalRegistry: {
+		releaseTerminalsForTask: vi.fn(),
+	},
+}))
+
+vi.mock("@roo-code/telemetry", () => ({
+	TelemetryService: {
+		instance: {
+			captureTaskCreated: vi.fn(),
+			captureTaskRestarted: vi.fn(),
+			captureConversationMessage: vi.fn(),
+			captureLlmCompletion: vi.fn(),
+			captureConsecutiveMistakeError: vi.fn(),
+		},
+	},
+}))
+
+describe("Task grounding sources handling", () => {
+	let mockProvider: Partial<ClineProvider>
+	let mockApiConfiguration: ProviderSettings
+	let Task: any
+
+	beforeAll(async () => {
+		// Import Task after mocks are set up
+		const taskModule = await import("../Task")
+		Task = taskModule.Task
+	})
+
+	beforeEach(() => {
+		// Mock provider with necessary methods
+		mockProvider = {
+			postStateToWebview: vi.fn().mockResolvedValue(undefined),
+			getState: vi.fn().mockResolvedValue({
+				mode: "code",
+				experiments: {},
+			}),
+			context: {
+				globalStorageUri: { fsPath: "/test/storage" },
+				extensionPath: "/test/extension",
+			} as any,
+			log: vi.fn(),
+			updateTaskHistory: vi.fn().mockResolvedValue(undefined),
+			postMessageToWebview: vi.fn().mockResolvedValue(undefined),
+		}
+
+		mockApiConfiguration = {
+			apiProvider: "gemini",
+			geminiApiKey: "test-key",
+			enableGrounding: true,
+		} as ProviderSettings
+	})
+
+	it("should strip grounding sources from assistant message before persisting to API history", async () => {
+		// Create a task instance
+		const task = new Task({
+			provider: mockProvider as ClineProvider,
+			apiConfiguration: mockApiConfiguration,
+			task: "Test task",
+			startTask: false,
+		})
+
+		// Mock the API conversation history
+		task.apiConversationHistory = []
+
+		// Simulate an assistant message with grounding sources
+		const assistantMessageWithSources = `
+This is the main response content.
+
+[1] Example Source: https://example.com
+[2] Another Source: https://another.com
+
+Sources: [1](https://example.com), [2](https://another.com)
+		`.trim()
+
+		// Mock grounding sources
+		const mockGroundingSources = [
+			{ title: "Example Source", url: "https://example.com" },
+			{ title: "Another Source", url: "https://another.com" },
+		]
+
+		// Spy on addToApiConversationHistory to check what gets persisted
+		const addToApiHistorySpy = vi.spyOn(task as any, "addToApiConversationHistory")
+
+		// Simulate the logic from Task.ts that strips grounding sources
+		let cleanAssistantMessage = assistantMessageWithSources
+		if (mockGroundingSources.length > 0) {
+			cleanAssistantMessage = assistantMessageWithSources
+				.replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "") // e.g., "[1] Example Source: https://example.com"
+				.replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "") // e.g., "Sources: [1](url1), [2](url2)"
+				.trim()
+		}
+
+		// Add the cleaned message to API history
+		await (task as any).addToApiConversationHistory({
+			role: "assistant",
+			content: [{ type: "text", text: cleanAssistantMessage }],
+		})
+
+		// Verify that the cleaned message was added without grounding sources
+		expect(addToApiHistorySpy).toHaveBeenCalledWith({
+			role: "assistant",
+			content: [{ type: "text", text: "This is the main response content." }],
+		})
+
+		// Verify the API conversation history contains the cleaned message
+		expect(task.apiConversationHistory).toHaveLength(1)
+		expect(task.apiConversationHistory[0].content).toEqual([
+			{ type: "text", text: "This is the main response content." },
+		])
+	})
+
+	it("should not modify assistant message when no grounding sources are present", async () => {
+		const task = new Task({
+			provider: mockProvider as ClineProvider,
+			apiConfiguration: mockApiConfiguration,
+			task: "Test task",
+			startTask: false,
+		})
+
+		task.apiConversationHistory = []
+
+		const assistantMessage = "This is a regular response without any sources."
+		const mockGroundingSources: any[] = [] // No grounding sources
+
+		// Apply the same logic
+		let cleanAssistantMessage = assistantMessage
+		if (mockGroundingSources.length > 0) {
+			cleanAssistantMessage = assistantMessage
+				.replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "")
+				.replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "")
+				.trim()
+		}
+
+		await (task as any).addToApiConversationHistory({
+			role: "assistant",
+			content: [{ type: "text", text: cleanAssistantMessage }],
+		})
+
+		// Message should remain unchanged
+		expect(task.apiConversationHistory[0].content).toEqual([
+			{ type: "text", text: "This is a regular response without any sources." },
+		])
+	})
+
+	it("should handle various grounding source formats", () => {
+		const testCases = [
+			{
+				input: "[1] Source Title: https://example.com\n[2] Another: https://test.com\nMain content here",
+				expected: "Main content here",
+			},
+			{
+				input: "Content first\n\nSources: [1](https://example.com), [2](https://test.com)",
+				expected: "Content first",
+			},
+			{
+				input: "Mixed content\n[1] Inline Source: https://inline.com\nMore content\nSource: [1](https://inline.com)",
+				expected: "Mixed content\n\nMore content",
+			},
+		]
+
+		testCases.forEach(({ input, expected }) => {
+			const cleaned = input
+				.replace(/\[\d+\]\s+[^:\n]+:\s+https?:\/\/[^\s\n]+/g, "")
+				.replace(/Sources?:\s*[\s\S]*?(?=\n\n|\n$|$)/g, "")
+				.trim()
+			expect(cleaned).toBe(expected)
+		})
+	})
+})