
Support Gemini 2.5 Flash thinking (#2752)

Chris Estreich 8 months ago
Parent
Commit
5abea50cf1

+ 5 - 0
.changeset/shiny-poems-search.md

@@ -0,0 +1,5 @@
+---
+"roo-cline": patch
+---
+
+Support Gemini 2.5 Flash thinking mode

+ 30 - 5
package-lock.json

@@ -13,7 +13,7 @@
 				"@anthropic-ai/vertex-sdk": "^0.7.0",
 				"@aws-sdk/client-bedrock-runtime": "^3.779.0",
 				"@google-cloud/vertexai": "^1.9.3",
-				"@google/generative-ai": "^0.18.0",
+				"@google/genai": "^0.9.0",
 				"@mistralai/mistralai": "^1.3.6",
 				"@modelcontextprotocol/sdk": "^1.7.0",
 				"@types/clone-deep": "^4.0.4",
@@ -5781,14 +5781,39 @@
 				"node": ">=18.0.0"
 			}
 		},
-		"node_modules/@google/generative-ai": {
-			"version": "0.18.0",
-			"resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.18.0.tgz",
-			"integrity": "sha512-AhaIWSpk2tuhYHrBhUqC0xrWWznmYEja1/TRDIb+5kruBU5kUzMlFsXCQNO9PzyTZ4clUJ3CX/Rvy+Xm9x+w3g==",
+		"node_modules/@google/genai": {
+			"version": "0.9.0",
+			"resolved": "https://registry.npmjs.org/@google/genai/-/genai-0.9.0.tgz",
+			"integrity": "sha512-FD2RizYGInsvfjeaN6O+wQGpRnGVglS1XWrGQr8K7D04AfMmvPodDSw94U9KyFtsVLzWH9kmlPyFM+G4jbmkqg==",
+			"license": "Apache-2.0",
+			"dependencies": {
+				"google-auth-library": "^9.14.2",
+				"ws": "^8.18.0",
+				"zod": "^3.22.4",
+				"zod-to-json-schema": "^3.22.4"
+			},
 			"engines": {
 				"node": ">=18.0.0"
 			}
 		},
+		"node_modules/@google/genai/node_modules/zod": {
+			"version": "3.24.3",
+			"resolved": "https://registry.npmjs.org/zod/-/zod-3.24.3.tgz",
+			"integrity": "sha512-HhY1oqzWCQWuUqvBFnsyrtZRhyPeR7SUGv+C4+MsisMuVfSPx8HpwWqH8tRahSlt6M3PiFAcoeFhZAqIXTxoSg==",
+			"license": "MIT",
+			"funding": {
+				"url": "https://github.com/sponsors/colinhacks"
+			}
+		},
+		"node_modules/@google/genai/node_modules/zod-to-json-schema": {
+			"version": "3.24.5",
+			"resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.5.tgz",
+			"integrity": "sha512-/AuWwMP+YqiPbsJx5D6TfgRTc4kTLjsh5SOcd4bLsfUg2RcEXrFMJl1DGgdHy2aCfsIA/cr/1JM0xcB2GZji8g==",
+			"license": "ISC",
+			"peerDependencies": {
+				"zod": "^3.24.1"
+			}
+		},
 		"node_modules/@humanwhocodes/config-array": {
 			"version": "0.13.0",
 			"resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.13.0.tgz",

+ 1 - 1
package.json

@@ -405,7 +405,7 @@
 		"@anthropic-ai/vertex-sdk": "^0.7.0",
 		"@aws-sdk/client-bedrock-runtime": "^3.779.0",
 		"@google-cloud/vertexai": "^1.9.3",
-		"@google/generative-ai": "^0.18.0",
+		"@google/genai": "^0.9.0",
 		"@mistralai/mistralai": "^1.3.6",
 		"@modelcontextprotocol/sdk": "^1.7.0",
 		"@types/clone-deep": "^4.0.4",

+ 52 - 98
src/api/providers/__tests__/gemini.test.ts

@@ -1,45 +1,41 @@
-import { GeminiHandler } from "../gemini"
+// npx jest src/api/providers/__tests__/gemini.test.ts
+
 import { Anthropic } from "@anthropic-ai/sdk"
-import { GoogleGenerativeAI } from "@google/generative-ai"
-
-// Mock the Google Generative AI SDK
-jest.mock("@google/generative-ai", () => ({
-	GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
-		getGenerativeModel: jest.fn().mockReturnValue({
-			generateContentStream: jest.fn(),
-			generateContent: jest.fn().mockResolvedValue({
-				response: {
-					text: () => "Test response",
-				},
-			}),
-		}),
-	})),
-}))
+
+import { GeminiHandler } from "../gemini"
+import { geminiDefaultModelId } from "../../../shared/api"
+
+const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"
 
 describe("GeminiHandler", () => {
 	let handler: GeminiHandler
 
 	beforeEach(() => {
+		// Create mock functions
+		const mockGenerateContentStream = jest.fn()
+		const mockGenerateContent = jest.fn()
+		const mockGetGenerativeModel = jest.fn()
+
 		handler = new GeminiHandler({
 			apiKey: "test-key",
-			apiModelId: "gemini-2.0-flash-thinking-exp-1219",
+			apiModelId: GEMINI_20_FLASH_THINKING_NAME,
 			geminiApiKey: "test-key",
 		})
+
+		// Replace the client with our mock
+		handler["client"] = {
+			models: {
+				generateContentStream: mockGenerateContentStream,
+				generateContent: mockGenerateContent,
+				getGenerativeModel: mockGetGenerativeModel,
+			},
+		} as any
 	})
 
 	describe("constructor", () => {
 		it("should initialize with provided config", () => {
 			expect(handler["options"].geminiApiKey).toBe("test-key")
-			expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219")
-		})
-
-		it.skip("should throw if API key is missing", () => {
-			expect(() => {
-				new GeminiHandler({
-					apiModelId: "gemini-2.0-flash-thinking-exp-1219",
-					geminiApiKey: "",
-				})
-			}).toThrow("API key is required for Google Gemini")
+			expect(handler["options"].apiModelId).toBe(GEMINI_20_FLASH_THINKING_NAME)
 		})
 	})
 
@@ -58,25 +54,15 @@ describe("GeminiHandler", () => {
 		const systemPrompt = "You are a helpful assistant"
 
 		it("should handle text messages correctly", async () => {
-			// Mock the stream response
-			const mockStream = {
-				stream: [{ text: () => "Hello" }, { text: () => " world!" }],
-				response: {
-					usageMetadata: {
-						promptTokenCount: 10,
-						candidatesTokenCount: 5,
-					},
+			// Setup the mock implementation to return an async generator
+			;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({
+				[Symbol.asyncIterator]: async function* () {
+					yield { text: "Hello" }
+					yield { text: " world!" }
+					yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
 				},
-			}
-
-			// Setup the mock implementation
-			const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream)
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContentStream: mockGenerateContentStream,
 			})
 
-			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
-
 			const stream = handler.createMessage(systemPrompt, mockMessages)
 			const chunks = []
 
@@ -100,35 +86,21 @@ describe("GeminiHandler", () => {
 				outputTokens: 5,
 			})
 
-			// Verify the model configuration
-			expect(mockGetGenerativeModel).toHaveBeenCalledWith(
-				{
-					model: "gemini-2.0-flash-thinking-exp-1219",
-					systemInstruction: systemPrompt,
-				},
-				{
-					baseUrl: undefined,
-				},
-			)
-
-			// Verify generation config
-			expect(mockGenerateContentStream).toHaveBeenCalledWith(
+			// Verify the call to generateContentStream
+			expect(handler["client"].models.generateContentStream).toHaveBeenCalledWith(
 				expect.objectContaining({
-					generationConfig: {
+					model: GEMINI_20_FLASH_THINKING_NAME,
+					config: expect.objectContaining({
 						temperature: 0,
-					},
+						systemInstruction: systemPrompt,
+					}),
 				}),
 			)
 		})
 
 		it("should handle API errors", async () => {
 			const mockError = new Error("Gemini API error")
-			const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError)
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContentStream: mockGenerateContentStream,
-			})
-
-			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
+			;(handler["client"].models.generateContentStream as jest.Mock).mockRejectedValue(mockError)
 
 			const stream = handler.createMessage(systemPrompt, mockMessages)
 
@@ -136,35 +108,26 @@ describe("GeminiHandler", () => {
 				for await (const chunk of stream) {
 					// Should throw before yielding any chunks
 				}
-			}).rejects.toThrow("Gemini API error")
+			}).rejects.toThrow()
 		})
 	})
 
 	describe("completePrompt", () => {
 		it("should complete prompt successfully", async () => {
-			const mockGenerateContent = jest.fn().mockResolvedValue({
-				response: {
-					text: () => "Test response",
-				},
-			})
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContent: mockGenerateContent,
+			// Mock the response with text property
+			;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({
+				text: "Test response",
 			})
-			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
 
 			const result = await handler.completePrompt("Test prompt")
 			expect(result).toBe("Test response")
-			expect(mockGetGenerativeModel).toHaveBeenCalledWith(
-				{
-					model: "gemini-2.0-flash-thinking-exp-1219",
-				},
-				{
-					baseUrl: undefined,
-				},
-			)
-			expect(mockGenerateContent).toHaveBeenCalledWith({
+
+			// Verify the call to generateContent
+			expect(handler["client"].models.generateContent).toHaveBeenCalledWith({
+				model: GEMINI_20_FLASH_THINKING_NAME,
 				contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
-				generationConfig: {
+				config: {
+					httpOptions: undefined,
 					temperature: 0,
 				},
 			})
@@ -172,11 +135,7 @@ describe("GeminiHandler", () => {
 
 		it("should handle API errors", async () => {
 			const mockError = new Error("Gemini API error")
-			const mockGenerateContent = jest.fn().mockRejectedValue(mockError)
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContent: mockGenerateContent,
-			})
-			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
+			;(handler["client"].models.generateContent as jest.Mock).mockRejectedValue(mockError)
 
 			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
 				"Gemini completion error: Gemini API error",
@@ -184,15 +143,10 @@ describe("GeminiHandler", () => {
 		})
 
 		it("should handle empty response", async () => {
-			const mockGenerateContent = jest.fn().mockResolvedValue({
-				response: {
-					text: () => "",
-				},
-			})
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContent: mockGenerateContent,
+			// Mock the response with empty text
+			;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({
+				text: "",
 			})
-			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
 
 			const result = await handler.completePrompt("Test prompt")
 			expect(result).toBe("")
@@ -202,7 +156,7 @@ describe("GeminiHandler", () => {
 	describe("getModel", () => {
 		it("should return correct model info", () => {
 			const modelInfo = handler.getModel()
-			expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219")
+			expect(modelInfo.id).toBe(GEMINI_20_FLASH_THINKING_NAME)
 			expect(modelInfo.info).toBeDefined()
 			expect(modelInfo.info.maxTokens).toBe(8192)
 			expect(modelInfo.info.contextWindow).toBe(32_767)
@@ -214,7 +168,7 @@ describe("GeminiHandler", () => {
 				geminiApiKey: "test-key",
 			})
 			const modelInfo = invalidHandler.getModel()
-			expect(modelInfo.id).toBe("gemini-2.0-flash-001") // Default model
+			expect(modelInfo.id).toBe(geminiDefaultModelId) // Default model
 		})
 	})
 })
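
For orientation, a minimal sketch of how the mocked stream surfaces through handler.createMessage, reusing the handler, systemPrompt, and mockMessages fixtures from the test above and assuming the ApiStream chunk shapes introduced in src/api/providers/gemini.ts later in this diff (the inline comments show expected values, not output copied from a test run):

const stream = handler.createMessage(systemPrompt, mockMessages)
const chunks = []

for await (const chunk of stream) {
	chunks.push(chunk)
}

// Expected, given the mocked async generator:
//   { type: "text", text: "Hello" }
//   { type: "text", text: " world!" }
//   { type: "usage", inputTokens: 10, outputTokens: 5 }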

+ 6 - 10
src/api/providers/anthropic.ts

@@ -23,6 +23,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 
 		const apiKeyFieldName =
 			this.options.anthropicBaseUrl && this.options.anthropicUseAuthToken ? "authToken" : "apiKey"
+
 		this.client = new Anthropic({
 			baseURL: this.options.anthropicBaseUrl || undefined,
 			[apiKeyFieldName]: this.options.apiKey,
@@ -217,10 +218,10 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 	}
 
 	async completePrompt(prompt: string) {
-		let { id: modelId, temperature } = this.getModel()
+		let { id: model, temperature } = this.getModel()
 
 		const message = await this.client.messages.create({
-			model: modelId,
+			model,
 			max_tokens: ANTHROPIC_DEFAULT_MAX_TOKENS,
 			thinking: undefined,
 			temperature,
@@ -241,16 +242,11 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
 	override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
 		try {
 			// Use the current model
-			const actualModelId = this.getModel().id
+			const { id: model } = this.getModel()
 
 			const response = await this.client.messages.countTokens({
-				model: actualModelId,
-				messages: [
-					{
-						role: "user",
-						content: content,
-					},
-				],
+				model,
+				messages: [{ role: "user", content: content }],
 			})
 
 			return response.input_tokens

+ 104 - 51
src/api/providers/gemini.ts

@@ -1,89 +1,142 @@
-import { Anthropic } from "@anthropic-ai/sdk"
-import { GoogleGenerativeAI } from "@google/generative-ai"
+import type { Anthropic } from "@anthropic-ai/sdk"
+import {
+	GoogleGenAI,
+	ThinkingConfig,
+	type GenerateContentResponseUsageMetadata,
+	type GenerateContentParameters,
+} from "@google/genai"
+
 import { SingleCompletionHandler } from "../"
-import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../../shared/api"
-import { convertAnthropicMessageToGemini } from "../transform/gemini-format"
-import { ApiStream } from "../transform/stream"
+import type { ApiHandlerOptions, GeminiModelId, ModelInfo } from "../../shared/api"
+import { geminiDefaultModelId, geminiModels } from "../../shared/api"
+import { convertAnthropicContentToGemini, convertAnthropicMessageToGemini } from "../transform/gemini-format"
+import type { ApiStream } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
 
-const GEMINI_DEFAULT_TEMPERATURE = 0
-
 export class GeminiHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
-	private client: GoogleGenerativeAI
+	private client: GoogleGenAI
 
 	constructor(options: ApiHandlerOptions) {
 		super()
 		this.options = options
-		this.client = new GoogleGenerativeAI(options.geminiApiKey ?? "not-provided")
+		this.client = new GoogleGenAI({ apiKey: options.geminiApiKey ?? "not-provided" })
 	}
 
-	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const model = this.client.getGenerativeModel(
-			{
-				model: this.getModel().id,
-				systemInstruction: systemPrompt,
-			},
-			{
-				baseUrl: this.options.googleGeminiBaseUrl || undefined,
-			},
-		)
-		const result = await model.generateContentStream({
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const { id: model, thinkingConfig, maxOutputTokens } = this.getModel()
+
+		const params: GenerateContentParameters = {
+			model,
 			contents: messages.map(convertAnthropicMessageToGemini),
-			generationConfig: {
-				// maxOutputTokens: this.getModel().info.maxTokens,
-				temperature: this.options.modelTemperature ?? GEMINI_DEFAULT_TEMPERATURE,
+			config: {
+				thinkingConfig,
+				maxOutputTokens,
+				temperature: this.options.modelTemperature ?? 0,
+				systemInstruction: systemPrompt,
 			},
-		})
+		}
 
-		for await (const chunk of result.stream) {
-			yield {
-				type: "text",
-				text: chunk.text(),
+		const result = await this.client.models.generateContentStream(params)
+
+		let lastUsageMetadata: GenerateContentResponseUsageMetadata | undefined
+
+		for await (const chunk of result) {
+			if (chunk.text) {
+				yield { type: "text", text: chunk.text }
+			}
+
+			if (chunk.usageMetadata) {
+				lastUsageMetadata = chunk.usageMetadata
 			}
 		}
 
-		const response = await result.response
-		yield {
-			type: "usage",
-			inputTokens: response.usageMetadata?.promptTokenCount ?? 0,
-			outputTokens: response.usageMetadata?.candidatesTokenCount ?? 0,
+		if (lastUsageMetadata) {
+			yield {
+				type: "usage",
+				inputTokens: lastUsageMetadata.promptTokenCount ?? 0,
+				outputTokens: lastUsageMetadata.candidatesTokenCount ?? 0,
+			}
 		}
 	}
 
-	override getModel(): { id: GeminiModelId; info: ModelInfo } {
-		const modelId = this.options.apiModelId
-		if (modelId && modelId in geminiModels) {
-			const id = modelId as GeminiModelId
-			return { id, info: geminiModels[id] }
+	override getModel(): {
+		id: GeminiModelId
+		info: ModelInfo
+		thinkingConfig?: ThinkingConfig
+		maxOutputTokens?: number
+	} {
+		let id = this.options.apiModelId ? (this.options.apiModelId as GeminiModelId) : geminiDefaultModelId
+		let info: ModelInfo = geminiModels[id]
+		let thinkingConfig: ThinkingConfig | undefined = undefined
+		let maxOutputTokens: number | undefined = undefined
+
+		const thinkingSuffix = ":thinking"
+
+		if (id?.endsWith(thinkingSuffix)) {
+			id = id.slice(0, -thinkingSuffix.length) as GeminiModelId
+			info = geminiModels[id]
+
+			thinkingConfig = this.options.modelMaxThinkingTokens
+				? { thinkingBudget: this.options.modelMaxThinkingTokens }
+				: undefined
+
+			maxOutputTokens = this.options.modelMaxTokens ?? info.maxTokens ?? undefined
+		}
+
+		if (!info) {
+			id = geminiDefaultModelId
+			info = geminiModels[geminiDefaultModelId]
+			thinkingConfig = undefined
+			maxOutputTokens = undefined
 		}
-		return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
+
+		return { id, info, thinkingConfig, maxOutputTokens }
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
-			const model = this.client.getGenerativeModel(
-				{
-					model: this.getModel().id,
-				},
-				{
-					baseUrl: this.options.googleGeminiBaseUrl || undefined,
-				},
-			)
+			const { id: model } = this.getModel()
 
-			const result = await model.generateContent({
+			const result = await this.client.models.generateContent({
+				model,
 				contents: [{ role: "user", parts: [{ text: prompt }] }],
-				generationConfig: {
-					temperature: this.options.modelTemperature ?? GEMINI_DEFAULT_TEMPERATURE,
+				config: {
+					httpOptions: this.options.googleGeminiBaseUrl
+						? { baseUrl: this.options.googleGeminiBaseUrl }
+						: undefined,
+					temperature: this.options.modelTemperature ?? 0,
 				},
 			})
 
-			return result.response.text()
+			return result.text ?? ""
 		} catch (error) {
 			if (error instanceof Error) {
 				throw new Error(`Gemini completion error: ${error.message}`)
 			}
+
 			throw error
 		}
 	}
+
+	override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
+		try {
+			const { id: model } = this.getModel()
+
+			const response = await this.client.models.countTokens({
+				model,
+				contents: convertAnthropicContentToGemini(content),
+			})
+
+			if (response.totalTokens === undefined) {
+				console.warn("Gemini token counting returned undefined, using fallback")
+				return super.countTokens(content)
+			}
+
+			return response.totalTokens
+		} catch (error) {
+			console.warn("Gemini token counting failed, using fallback", error)
+			return super.countTokens(content)
+		}
+	}
 }
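
A hedged usage sketch of the new handler, with hypothetical option values and an assumed relative import path: the ":thinking" model id is stripped by getModel(), which attaches a thinkingConfig and max output budget that createMessage then forwards in config.

import type { Anthropic } from "@anthropic-ai/sdk"
import { GeminiHandler } from "./gemini"

const handler = new GeminiHandler({
	apiModelId: "gemini-2.5-flash-preview-04-17:thinking",
	geminiApiKey: process.env.GEMINI_API_KEY,
	modelMaxTokens: 32_768,
	modelMaxThinkingTokens: 8_192,
})

const { id, thinkingConfig, maxOutputTokens } = handler.getModel()
// id === "gemini-2.5-flash-preview-04-17"
// thinkingConfig === { thinkingBudget: 8192 }
// maxOutputTokens === 32768

const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

for await (const chunk of handler.createMessage("You are a helpful assistant", messages)) {
	if (chunk.type === "text") process.stdout.write(chunk.text)
	if (chunk.type === "usage") console.log(chunk.inputTokens, chunk.outputTokens)
}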

+ 45 - 50
src/api/transform/gemini-format.ts

@@ -1,76 +1,71 @@
 import { Anthropic } from "@anthropic-ai/sdk"
-import { Content, FunctionCallPart, FunctionResponsePart, InlineDataPart, Part, TextPart } from "@google/generative-ai"
+import { Content, Part } from "@google/genai"
 
-function convertAnthropicContentToGemini(content: Anthropic.Messages.MessageParam["content"]): Part[] {
+export function convertAnthropicContentToGemini(content: string | Anthropic.ContentBlockParam[]): Part[] {
 	if (typeof content === "string") {
-		return [{ text: content } as TextPart]
+		return [{ text: content }]
 	}
 
-	return content.flatMap((block) => {
+	return content.flatMap((block): Part | Part[] => {
 		switch (block.type) {
 			case "text":
-				return { text: block.text } as TextPart
+				return { text: block.text }
 			case "image":
 				if (block.source.type !== "base64") {
 					throw new Error("Unsupported image source type")
 				}
-				return {
-					inlineData: {
-						data: block.source.data,
-						mimeType: block.source.media_type,
-					},
-				} as InlineDataPart
+
+				return { inlineData: { data: block.source.data, mimeType: block.source.media_type } }
 			case "tool_use":
 				return {
 					functionCall: {
 						name: block.name,
-						args: block.input,
+						args: block.input as Record<string, unknown>,
 					},
-				} as FunctionCallPart
-			case "tool_result":
-				const name = block.tool_use_id.split("-")[0]
+				}
+			case "tool_result": {
 				if (!block.content) {
 					return []
 				}
+
+				// Extract tool name from tool_use_id (e.g., "calculator-123" -> "calculator")
+				const toolName = block.tool_use_id.split("-")[0]
+
 				if (typeof block.content === "string") {
 					return {
-						functionResponse: {
-							name,
-							response: {
-								name,
-								content: block.content,
-							},
-						},
-					} as FunctionResponsePart
-				} else {
-					// The only case when tool_result could be array is when the tool failed and we're providing ie user feedback potentially with images
-					const textParts = block.content.filter((part) => part.type === "text")
-					const imageParts = block.content.filter((part) => part.type === "image")
-					const text = textParts.length > 0 ? textParts.map((part) => part.text).join("\n\n") : ""
-					const imageText = imageParts.length > 0 ? "\n\n(See next part for image)" : ""
-					return [
-						{
-							functionResponse: {
-								name,
-								response: {
-									name,
-									content: text + imageText,
-								},
-							},
-						} as FunctionResponsePart,
-						...imageParts.map(
-							(part) =>
-								({
-									inlineData: {
-										data: part.source.data,
-										mimeType: part.source.media_type,
-									},
-								}) as InlineDataPart,
-						),
-					]
+						functionResponse: { name: toolName, response: { name: toolName, content: block.content } },
+					}
+				}
+
+				if (!Array.isArray(block.content)) {
+					return []
 				}
+
+				const textParts: string[] = []
+				const imageParts: Part[] = []
+
+				for (const item of block.content) {
+					if (item.type === "text") {
+						textParts.push(item.text)
+					} else if (item.type === "image" && item.source.type === "base64") {
+						const { data, media_type } = item.source
+						imageParts.push({ inlineData: { data, mimeType: media_type } })
+					}
+				}
+
+				// Create content text with a note about images if present
+				const contentText =
+					textParts.join("\n\n") + (imageParts.length > 0 ? "\n\n(See next part for image)" : "")
+
+				// Return function response followed by any images
+				return [
+					{ functionResponse: { name: toolName, response: { name: toolName, content: contentText } } },
+					...imageParts,
+				]
+			}
 			default:
-				throw new Error(`Unsupported content block type: ${(block as any).type}`)
+				// Currently unsupported: "thinking" | "redacted_thinking" | "document"
+				throw new Error(`Unsupported content block type: ${block.type}`)
 		}
 	})
 }
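
A small hedged example of the tool_result branch above (the tool_use_id, feedback text, and image data are made up for illustration, and the import path is assumed): a failed tool call reported as text plus a screenshot becomes a functionResponse part followed by an inlineData part.

import { convertAnthropicContentToGemini } from "./gemini-format"

const parts = convertAnthropicContentToGemini([
	{
		type: "tool_result",
		tool_use_id: "calculator-123",
		content: [
			{ type: "text", text: "Wrong result, see screenshot" },
			{ type: "image", source: { type: "base64", media_type: "image/png", data: "<base64 png>" } },
		],
	},
])

// parts[0] => functionResponse named "calculator" whose content ends with
//             "\n\n(See next part for image)"
// parts[1] => { inlineData: { data: "<base64 png>", mimeType: "image/png" } }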

+ 5 - 0
src/exports/roo-code.d.ts

@@ -31,6 +31,7 @@ type ProviderSettings = {
 	glamaModelInfo?:
 		| ({
 				maxTokens?: (number | null) | undefined
+				maxThinkingTokens?: (number | null) | undefined
 				contextWindow: number
 				supportsImages?: boolean | undefined
 				supportsComputerUse?: boolean | undefined
@@ -53,6 +54,7 @@ type ProviderSettings = {
 	openRouterModelInfo?:
 		| ({
 				maxTokens?: (number | null) | undefined
+				maxThinkingTokens?: (number | null) | undefined
 				contextWindow: number
 				supportsImages?: boolean | undefined
 				supportsComputerUse?: boolean | undefined
@@ -95,6 +97,7 @@ type ProviderSettings = {
 	openAiCustomModelInfo?:
 		| ({
 				maxTokens?: (number | null) | undefined
+				maxThinkingTokens?: (number | null) | undefined
 				contextWindow: number
 				supportsImages?: boolean | undefined
 				supportsComputerUse?: boolean | undefined
@@ -140,6 +143,7 @@ type ProviderSettings = {
 	unboundModelInfo?:
 		| ({
 				maxTokens?: (number | null) | undefined
+				maxThinkingTokens?: (number | null) | undefined
 				contextWindow: number
 				supportsImages?: boolean | undefined
 				supportsComputerUse?: boolean | undefined
@@ -161,6 +165,7 @@ type ProviderSettings = {
 	requestyModelInfo?:
 		| ({
 				maxTokens?: (number | null) | undefined
+				maxThinkingTokens?: (number | null) | undefined
 				contextWindow: number
 				supportsImages?: boolean | undefined
 				supportsComputerUse?: boolean | undefined

+ 5 - 0
src/exports/types.ts

@@ -32,6 +32,7 @@ type ProviderSettings = {
 	glamaModelInfo?:
 		| ({
 				maxTokens?: (number | null) | undefined
+				maxThinkingTokens?: (number | null) | undefined
 				contextWindow: number
 				supportsImages?: boolean | undefined
 				supportsComputerUse?: boolean | undefined
@@ -54,6 +55,7 @@ type ProviderSettings = {
 	openRouterModelInfo?:
 		| ({
 				maxTokens?: (number | null) | undefined
+				maxThinkingTokens?: (number | null) | undefined
 				contextWindow: number
 				supportsImages?: boolean | undefined
 				supportsComputerUse?: boolean | undefined
@@ -96,6 +98,7 @@ type ProviderSettings = {
 	openAiCustomModelInfo?:
 		| ({
 				maxTokens?: (number | null) | undefined
+				maxThinkingTokens?: (number | null) | undefined
 				contextWindow: number
 				supportsImages?: boolean | undefined
 				supportsComputerUse?: boolean | undefined
@@ -141,6 +144,7 @@ type ProviderSettings = {
 	unboundModelInfo?:
 		| ({
 				maxTokens?: (number | null) | undefined
+				maxThinkingTokens?: (number | null) | undefined
 				contextWindow: number
 				supportsImages?: boolean | undefined
 				supportsComputerUse?: boolean | undefined
@@ -162,6 +166,7 @@ type ProviderSettings = {
 	requestyModelInfo?:
 		| ({
 				maxTokens?: (number | null) | undefined
+				maxThinkingTokens?: (number | null) | undefined
 				contextWindow: number
 				supportsImages?: boolean | undefined
 				supportsComputerUse?: boolean | undefined

+ 1 - 0
src/schemas/index.ts

@@ -99,6 +99,7 @@ export type ReasoningEffort = z.infer<typeof reasoningEffortsSchema>
 
 export const modelInfoSchema = z.object({
 	maxTokens: z.number().nullish(),
+	maxThinkingTokens: z.number().nullish(),
 	contextWindow: z.number(),
 	supportsImages: z.boolean().optional(),
 	supportsComputerUse: z.boolean().optional(),

+ 22 - 0
src/shared/api.ts

@@ -485,6 +485,16 @@ export const vertexModels = {
 		inputPrice: 0.15,
 		outputPrice: 0.6,
 	},
+	"gemini-2.5-flash-preview-04-17:thinking": {
+		maxTokens: 65_535,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0.15,
+		outputPrice: 0.6,
+		thinking: true,
+		maxThinkingTokens: 24_576,
+	},
 	"gemini-2.5-flash-preview-04-17": {
 		maxTokens: 65_535,
 		contextWindow: 1_048_576,
@@ -492,6 +502,7 @@ export const vertexModels = {
 		supportsPromptCache: false,
 		inputPrice: 0.15,
 		outputPrice: 0.6,
+		thinking: false,
 	},
 	"gemini-2.5-pro-preview-03-25": {
 		maxTokens: 65_535,
@@ -640,6 +651,16 @@ export const openAiModelInfoSaneDefaults: ModelInfo = {
 export type GeminiModelId = keyof typeof geminiModels
 export const geminiDefaultModelId: GeminiModelId = "gemini-2.0-flash-001"
 export const geminiModels = {
+	"gemini-2.5-flash-preview-04-17:thinking": {
+		maxTokens: 65_535,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0.15,
+		outputPrice: 0.6,
+		thinking: true,
+		maxThinkingTokens: 24_576,
+	},
 	"gemini-2.5-flash-preview-04-17": {
 		maxTokens: 65_535,
 		contextWindow: 1_048_576,
@@ -647,6 +668,7 @@ export const geminiModels = {
 		supportsPromptCache: false,
 		inputPrice: 0.15,
 		outputPrice: 0.6,
+		thinking: false,
 	},
 	"gemini-2.5-pro-exp-03-25": {
 		maxTokens: 65_535,
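
For reference, a small sketch reading the new table entries (import path assumed for a module alongside src/shared/api.ts): the ":thinking" variant carries its own thinking budget cap, while the plain entry is marked thinking: false.

import { geminiModels, geminiDefaultModelId } from "./api"

const thinking = geminiModels["gemini-2.5-flash-preview-04-17:thinking"]
console.log(thinking.thinking, thinking.maxThinkingTokens) // true 24576

const plain = geminiModels["gemini-2.5-flash-preview-04-17"]
console.log(plain.thinking) // false

console.log(geminiDefaultModelId) // "gemini-2.0-flash-001"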

+ 29 - 28
webview-ui/src/components/settings/ThinkingBudget.tsx

@@ -1,10 +1,13 @@
-import { useEffect, useMemo } from "react"
+import { useEffect } from "react"
 import { useAppTranslation } from "@/i18n/TranslationContext"
 
 import { Slider } from "@/components/ui"
 
 import { ApiConfiguration, ModelInfo } from "../../../../src/shared/api"
 
+const DEFAULT_MAX_OUTPUT_TOKENS = 16_384
+const DEFAULT_MAX_THINKING_TOKENS = 8_192
+
 interface ThinkingBudgetProps {
 	apiConfiguration: ApiConfiguration
 	setApiConfigurationField: <K extends keyof ApiConfiguration>(field: K, value: ApiConfiguration[K]) => void
@@ -13,57 +16,55 @@ interface ThinkingBudgetProps {
 
 export const ThinkingBudget = ({ apiConfiguration, setApiConfigurationField, modelInfo }: ThinkingBudgetProps) => {
 	const { t } = useAppTranslation()
-	const tokens = apiConfiguration?.modelMaxTokens || 16_384
-	const tokensMin = 8192
-	const tokensMax = modelInfo?.maxTokens || 64_000
 
-	// Get the appropriate thinking tokens based on provider
-	const thinkingTokens = useMemo(() => {
-		const value = apiConfiguration?.modelMaxThinkingTokens
-		return value || Math.min(Math.floor(0.8 * tokens), 8192)
-	}, [apiConfiguration, tokens])
+	const isThinkingModel = modelInfo && modelInfo.thinking && modelInfo.maxTokens
+
+	const customMaxOutputTokens = apiConfiguration.modelMaxTokens || DEFAULT_MAX_OUTPUT_TOKENS
+	const customMaxThinkingTokens = apiConfiguration.modelMaxThinkingTokens || DEFAULT_MAX_THINKING_TOKENS
 
-	const thinkingTokensMin = 1024
-	const thinkingTokensMax = Math.floor(0.8 * tokens)
+	// Dynamically expand or shrink the max thinking budget based on the custom
+	// max output tokens so that there's always a 20% buffer.
+	const modelMaxThinkingTokens = modelInfo?.maxThinkingTokens
+		? Math.min(modelInfo.maxThinkingTokens, Math.floor(0.8 * customMaxOutputTokens))
+		: Math.floor(0.8 * customMaxOutputTokens)
 
+	// If the custom max thinking tokens would exceed their limit because the
+	// custom max output tokens were reduced, then we need to shrink them
+	// appropriately.
 	useEffect(() => {
-		if (thinkingTokens > thinkingTokensMax) {
-			setApiConfigurationField("modelMaxThinkingTokens", thinkingTokensMax)
+		if (isThinkingModel && customMaxThinkingTokens > modelMaxThinkingTokens) {
+			setApiConfigurationField("modelMaxThinkingTokens", modelMaxThinkingTokens)
 		}
-	}, [thinkingTokens, thinkingTokensMax, setApiConfigurationField])
-
-	if (!modelInfo?.thinking) {
-		return null
-	}
+	}, [isThinkingModel, customMaxThinkingTokens, modelMaxThinkingTokens, setApiConfigurationField])
 
-	return (
+	return isThinkingModel ? (
 		<>
 			<div className="flex flex-col gap-1">
 				<div className="font-medium">{t("settings:thinkingBudget.maxTokens")}</div>
 				<div className="flex items-center gap-1">
 					<Slider
-						min={tokensMin}
-						max={tokensMax}
+						min={8192}
+						max={modelInfo.maxTokens!}
 						step={1024}
-						value={[tokens]}
+						value={[customMaxOutputTokens]}
 						onValueChange={([value]) => setApiConfigurationField("modelMaxTokens", value)}
 					/>
-					<div className="w-12 text-sm text-center">{tokens}</div>
+					<div className="w-12 text-sm text-center">{customMaxOutputTokens}</div>
 				</div>
 			</div>
 			<div className="flex flex-col gap-1">
 				<div className="font-medium">{t("settings:thinkingBudget.maxThinkingTokens")}</div>
 				<div className="flex items-center gap-1">
 					<Slider
-						min={thinkingTokensMin}
-						max={thinkingTokensMax}
+						min={1024}
+						max={modelMaxThinkingTokens}
 						step={1024}
-						value={[thinkingTokens]}
+						value={[customMaxThinkingTokens]}
 						onValueChange={([value]) => setApiConfigurationField("modelMaxThinkingTokens", value)}
 					/>
-					<div className="w-12 text-sm text-center">{thinkingTokens}</div>
+					<div className="w-12 text-sm text-center">{customMaxThinkingTokens}</div>
 				</div>
 			</div>
 		</>
-	)
+	) : null
 }
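
A quick worked example of the 80% buffer logic above, using the Gemini 2.5 Flash caps from this PR (the results are computed by hand, not taken from a run):

const modelMaxThinkingTokensCap = 24_576 // modelInfo.maxThinkingTokens

for (const customMaxOutputTokens of [16_384, 65_535]) {
	const sliderMax = Math.min(modelMaxThinkingTokensCap, Math.floor(0.8 * customMaxOutputTokens))
	console.log(customMaxOutputTokens, sliderMax)
}

// 16384 -> 13107 (80% of the output budget is the binding limit)
// 65535 -> 24576 (the model's own thinking cap is the binding limit)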