Browse Source

Merge pull request #1254 from lupuletic/added-new-thinking-support

Added support for Claude Sonnet 3.7 thinking via Vertex AI
Chris Estreich 1 year ago
parent
commit
54c687485b

+ 5 - 5
package-lock.json

@@ -10,7 +10,7 @@
 			"dependencies": {
 				"@anthropic-ai/bedrock-sdk": "^0.10.2",
 				"@anthropic-ai/sdk": "^0.37.0",
-				"@anthropic-ai/vertex-sdk": "^0.4.1",
+				"@anthropic-ai/vertex-sdk": "^0.7.0",
 				"@aws-sdk/client-bedrock-runtime": "^3.706.0",
 				"@google/generative-ai": "^0.18.0",
 				"@mistralai/mistralai": "^1.3.6",
@@ -150,11 +150,11 @@
 			"integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA=="
 		},
 		"node_modules/@anthropic-ai/vertex-sdk": {
-			"version": "0.4.3",
-			"resolved": "https://registry.npmjs.org/@anthropic-ai/vertex-sdk/-/vertex-sdk-0.4.3.tgz",
-			"integrity": "sha512-2Uef0C5P2Hx+T88RnUSRA3u4aZqmqnrRSOb2N64ozgKPiSUPTM5JlggAq2b32yWMj5d3MLYa6spJXKMmHXOcoA==",
+			"version": "0.7.0",
+			"resolved": "https://registry.npmjs.org/@anthropic-ai/vertex-sdk/-/vertex-sdk-0.7.0.tgz",
+			"integrity": "sha512-zNm3hUXgYmYDTyveIxOyxbcnh5VXFkrLo4bSnG6LAfGzW7k3k2iCNDSVKtR9qZrK2BCid7JtVu7jsEKaZ/9dSw==",
 			"dependencies": {
-				"@anthropic-ai/sdk": ">=0.14 <1",
+				"@anthropic-ai/sdk": ">=0.35 <1",
 				"google-auth-library": "^9.4.2"
 			}
 		},

+ 1 - 1
package.json

@@ -305,7 +305,7 @@
 	"dependencies": {
 		"@anthropic-ai/bedrock-sdk": "^0.10.2",
 		"@anthropic-ai/sdk": "^0.37.0",
-		"@anthropic-ai/vertex-sdk": "^0.4.1",
+		"@anthropic-ai/vertex-sdk": "^0.7.0",
 		"@aws-sdk/client-bedrock-runtime": "^3.706.0",
 		"@google/generative-ai": "^0.18.0",
 		"@mistralai/mistralai": "^1.3.6",

+ 250 - 0
src/api/providers/__tests__/vertex.test.ts

@@ -2,6 +2,7 @@
 
 import { Anthropic } from "@anthropic-ai/sdk"
 import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
+import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"
 
 import { VertexHandler } from "../vertex"
 import { ApiStreamChunk } from "../../transform/stream"
@@ -431,6 +432,138 @@ describe("VertexHandler", () => {
 		})
 	})
 
+	describe("thinking functionality", () => {
+		const mockMessages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: "Hello",
+			},
+		]
+
+		const systemPrompt = "You are a helpful assistant"
+
+		it("should handle thinking content blocks and deltas", async () => {
+			const mockStream = [
+				{
+					type: "message_start",
+					message: {
+						usage: {
+							input_tokens: 10,
+							output_tokens: 0,
+						},
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 0,
+					content_block: {
+						type: "thinking",
+						thinking: "Let me think about this...",
+					},
+				},
+				{
+					type: "content_block_delta",
+					delta: {
+						type: "thinking_delta",
+						thinking: " I need to consider all options.",
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 1,
+					content_block: {
+						type: "text",
+						text: "Here's my answer:",
+					},
+				},
+			]
+
+			// Setup async iterator for mock stream
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const chunks: ApiStreamChunk[] = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Verify thinking content is processed correctly
+			const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
+			expect(reasoningChunks).toHaveLength(2)
+			expect(reasoningChunks[0].text).toBe("Let me think about this...")
+			expect(reasoningChunks[1].text).toBe(" I need to consider all options.")
+
+			// Verify text content is processed correctly
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(2) // One for the text block, one for the newline
+			expect(textChunks[0].text).toBe("\n")
+			expect(textChunks[1].text).toBe("Here's my answer:")
+		})
+
+		it("should handle multiple thinking blocks with line breaks", async () => {
+			const mockStream = [
+				{
+					type: "content_block_start",
+					index: 0,
+					content_block: {
+						type: "thinking",
+						thinking: "First thinking block",
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 1,
+					content_block: {
+						type: "thinking",
+						thinking: "Second thinking block",
+					},
+				},
+			]
+
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const chunks: ApiStreamChunk[] = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBe(3)
+			expect(chunks[0]).toEqual({
+				type: "reasoning",
+				text: "First thinking block",
+			})
+			expect(chunks[1]).toEqual({
+				type: "reasoning",
+				text: "\n",
+			})
+			expect(chunks[2]).toEqual({
+				type: "reasoning",
+				text: "Second thinking block",
+			})
+		})
+	})
+
 	describe("completePrompt", () => {
 		it("should complete prompt successfully", async () => {
 			const result = await handler.completePrompt("Test prompt")
@@ -500,4 +633,121 @@ describe("VertexHandler", () => {
 			expect(modelInfo.id).toBe("claude-3-7-sonnet@20250219") // Default model
 		})
 	})
+
+	describe("thinking model configuration", () => {
+		it("should configure thinking for models with :thinking suffix", () => {
+			const thinkingHandler = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 16384,
+				vertexThinking: 4096,
+			})
+
+			const modelInfo = thinkingHandler.getModel()
+
+			// Verify thinking configuration
+			expect(modelInfo.id).toBe("claude-3-7-sonnet@20250219")
+			expect(modelInfo.thinking).toBeDefined()
+			const thinkingConfig = modelInfo.thinking as { type: "enabled"; budget_tokens: number }
+			expect(thinkingConfig.type).toBe("enabled")
+			expect(thinkingConfig.budget_tokens).toBe(4096)
+			expect(modelInfo.temperature).toBe(1.0) // Thinking requires temperature 1.0
+		})
+
+		it("should calculate thinking budget correctly", () => {
+			// Test with explicit thinking budget
+			const handlerWithBudget = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 16384,
+				vertexThinking: 5000,
+			})
+
+			expect((handlerWithBudget.getModel().thinking as any).budget_tokens).toBe(5000)
+
+			// Test with default thinking budget (80% of max tokens)
+			const handlerWithDefaultBudget = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 10000,
+			})
+
+			expect((handlerWithDefaultBudget.getModel().thinking as any).budget_tokens).toBe(8000) // 80% of 10000
+
+			// Test with minimum thinking budget (should be at least 1024)
+			const handlerWithSmallMaxTokens = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 1000, // This would result in 800 tokens for thinking, but minimum is 1024
+			})
+
+			expect((handlerWithSmallMaxTokens.getModel().thinking as any).budget_tokens).toBe(1024)
+		})
+
+		it("should use anthropicThinking value if vertexThinking is not provided", () => {
+			const handler = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 16384,
+				anthropicThinking: 6000, // Should be used as fallback
+			})
+
+			expect((handler.getModel().thinking as any).budget_tokens).toBe(6000)
+		})
+
+		it("should pass thinking configuration to API", async () => {
+			const thinkingHandler = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 16384,
+				vertexThinking: 4096,
+			})
+
+			const mockCreate = jest.fn().mockImplementation(async (options) => {
+				if (!options.stream) {
+					return {
+						id: "test-completion",
+						content: [{ type: "text", text: "Test response" }],
+						role: "assistant",
+						model: options.model,
+						usage: {
+							input_tokens: 10,
+							output_tokens: 5,
+						},
+					}
+				}
+				return {
+					async *[Symbol.asyncIterator]() {
+						yield {
+							type: "message_start",
+							message: {
+								usage: {
+									input_tokens: 10,
+									output_tokens: 5,
+								},
+							},
+						}
+					},
+				}
+			})
+			;(thinkingHandler["client"].messages as any).create = mockCreate
+
+			await thinkingHandler
+				.createMessage("You are a helpful assistant", [{ role: "user", content: "Hello" }])
+				.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					thinking: { type: "enabled", budget_tokens: 4096 },
+					temperature: 1.0, // Thinking requires temperature 1.0
+				}),
+			)
+		})
+	})
 })

+ 91 - 19
src/api/providers/vertex.ts

@@ -2,6 +2,7 @@ import { Anthropic } from "@anthropic-ai/sdk"
 import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
 import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
 import { ApiHandler, SingleCompletionHandler } from "../"
+import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"
 import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
 
@@ -70,15 +71,25 @@ interface VertexMessageStreamEvent {
 	usage?: {
 		output_tokens: number
 	}
-	content_block?: {
-		type: "text"
-		text: string
-	}
+	content_block?:
+		| {
+				type: "text"
+				text: string
+		  }
+		| {
+				type: "thinking"
+				thinking: string
+		  }
 	index?: number
-	delta?: {
-		type: "text_delta"
-		text: string
-	}
+	delta?:
+		| {
+				type: "text_delta"
+				text: string
+		  }
+		| {
+				type: "thinking_delta"
+				thinking: string
+		  }
 }
 
 // https://docs.anthropic.com/en/api/claude-on-vertex-ai
@@ -145,6 +156,7 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const model = this.getModel()
+		let { id, info, temperature, maxTokens, thinking } = model
 		const useCache = model.info.supportsPromptCache
 
 		// Find indices of user messages that we want to cache
@@ -158,9 +170,10 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 
 		// Create the stream with appropriate caching configuration
 		const params = {
-			model: model.id,
-			max_tokens: model.info.maxTokens || 8192,
-			temperature: this.options.modelTemperature ?? 0,
+			model: id,
+			max_tokens: maxTokens,
+			temperature,
+			thinking,
 			// Cache the system prompt if caching is enabled
 			system: useCache
 				? [
@@ -220,6 +233,19 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 							}
 							break
 						}
+						case "thinking": {
+							if (chunk.index! > 0) {
+								yield {
+									type: "reasoning",
+									text: "\n",
+								}
+							}
+							yield {
+								type: "reasoning",
+								text: (chunk.content_block as any).thinking,
+							}
+							break
+						}
 					}
 					break
 				}
@@ -232,6 +258,13 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 							}
 							break
 						}
+						case "thinking_delta": {
+							yield {
+								type: "reasoning",
+								text: (chunk.delta as any).thinking,
+							}
+							break
+						}
 					}
 					break
 				}
@@ -239,24 +272,63 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}
 
-	getModel(): { id: VertexModelId; info: ModelInfo } {
+	getModel(): {
+		id: VertexModelId
+		info: ModelInfo
+		temperature: number
+		maxTokens: number
+		thinking?: BetaThinkingConfigParam
+	} {
 		const modelId = this.options.apiModelId
+		let temperature = this.options.modelTemperature ?? 0
+		let thinking: BetaThinkingConfigParam | undefined = undefined
+
 		if (modelId && modelId in vertexModels) {
 			const id = modelId as VertexModelId
-			return { id, info: vertexModels[id] }
+			const info: ModelInfo = vertexModels[id]
+
+			// The `:thinking` variant is a virtual identifier for thinking-enabled models
+			// Similar to how it's handled in the Anthropic provider
+			let actualId = id
+			if (id.endsWith(":thinking")) {
+				actualId = id.replace(":thinking", "") as VertexModelId
+			}
+
+			const maxTokens = this.options.modelMaxTokens || info.maxTokens || 8192
+
+			if (info.thinking) {
+				temperature = 1.0 // Thinking requires temperature 1.0
+				const maxBudgetTokens = Math.floor(maxTokens * 0.8)
+				const budgetTokens = Math.max(
+					Math.min(
+						this.options.vertexThinking ?? this.options.anthropicThinking ?? maxBudgetTokens,
+						maxBudgetTokens,
+					),
+					1024,
+				)
+				thinking = { type: "enabled", budget_tokens: budgetTokens }
+			}
+
+			return { id: actualId, info, temperature, maxTokens, thinking }
 		}
-		return { id: vertexDefaultModelId, info: vertexModels[vertexDefaultModelId] }
+
+		const id = vertexDefaultModelId
+		const info = vertexModels[id]
+		const maxTokens = this.options.modelMaxTokens || info.maxTokens || 8192
+
+		return { id, info, temperature, maxTokens, thinking }
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
-			const model = this.getModel()
-			const useCache = model.info.supportsPromptCache
+			let { id, info, temperature, maxTokens, thinking } = this.getModel()
+			const useCache = info.supportsPromptCache
 
 			const params = {
-				model: model.id,
-				max_tokens: model.info.maxTokens || 8192,
-				temperature: this.options.modelTemperature ?? 0,
+				model: id,
+				max_tokens: maxTokens,
+				temperature,
+				thinking,
 				system: "", // No system prompt needed for single completions
 				messages: [
 					{

+ 5 - 0
src/core/webview/ClineProvider.ts

@@ -1652,6 +1652,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			lmStudioBaseUrl,
 			anthropicBaseUrl,
 			anthropicThinking,
+			vertexThinking,
 			geminiApiKey,
 			openAiNativeApiKey,
 			deepSeekApiKey,
@@ -1701,6 +1702,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.updateGlobalState("lmStudioBaseUrl", lmStudioBaseUrl),
 			this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl),
 			this.updateGlobalState("anthropicThinking", anthropicThinking),
+			this.updateGlobalState("vertexThinking", vertexThinking),
 			this.storeSecret("geminiApiKey", geminiApiKey),
 			this.storeSecret("openAiNativeApiKey", openAiNativeApiKey),
 			this.storeSecret("deepSeekApiKey", deepSeekApiKey),
@@ -2158,6 +2160,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			lmStudioBaseUrl,
 			anthropicBaseUrl,
 			anthropicThinking,
+			vertexThinking,
 			geminiApiKey,
 			openAiNativeApiKey,
 			deepSeekApiKey,
@@ -2242,6 +2245,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getGlobalState("lmStudioBaseUrl") as Promise<string | undefined>,
 			this.getGlobalState("anthropicBaseUrl") as Promise<string | undefined>,
 			this.getGlobalState("anthropicThinking") as Promise<number | undefined>,
+			this.getGlobalState("vertexThinking") as Promise<number | undefined>,
 			this.getSecret("geminiApiKey") as Promise<string | undefined>,
 			this.getSecret("openAiNativeApiKey") as Promise<string | undefined>,
 			this.getSecret("deepSeekApiKey") as Promise<string | undefined>,
@@ -2343,6 +2347,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				lmStudioBaseUrl,
 				anthropicBaseUrl,
 				anthropicThinking,
+				vertexThinking,
 				geminiApiKey,
 				openAiNativeApiKey,
 				deepSeekApiKey,

+ 14 - 0
src/shared/api.ts

@@ -41,6 +41,7 @@ export interface ApiHandlerOptions {
 	awsUseProfile?: boolean
 	vertexProjectId?: string
 	vertexRegion?: string
+	vertexThinking?: number
 	openAiBaseUrl?: string
 	openAiApiKey?: string
 	openAiModelId?: string
@@ -436,6 +437,18 @@ export const openRouterDefaultModelInfo: ModelInfo = {
 export type VertexModelId = keyof typeof vertexModels
 export const vertexDefaultModelId: VertexModelId = "claude-3-7-sonnet@20250219"
 export const vertexModels = {
+	"claude-3-7-sonnet@20250219:thinking": {
+		maxTokens: 64000,
+		contextWindow: 200_000,
+		supportsImages: true,
+		supportsComputerUse: true,
+		supportsPromptCache: true,
+		inputPrice: 3.0,
+		outputPrice: 15.0,
+		cacheWritesPrice: 3.75,
+		cacheReadsPrice: 0.3,
+		thinking: true,
+	},
 	"claude-3-7-sonnet@20250219": {
 		maxTokens: 8192,
 		contextWindow: 200_000,
@@ -446,6 +459,7 @@ export const vertexModels = {
 		outputPrice: 15.0,
 		cacheWritesPrice: 3.75,
 		cacheReadsPrice: 0.3,
+		thinking: false,
 	},
 	"claude-3-5-sonnet-v2@20241022": {
 		maxTokens: 8192,

+ 1 - 0
src/shared/globalState.ts

@@ -24,6 +24,7 @@ export type GlobalStateKey =
 	| "awsUseProfile"
 	| "vertexProjectId"
 	| "vertexRegion"
+	| "vertexThinking"
 	| "lastShownAnnouncementId"
 	| "customInstructions"
 	| "alwaysAllowReadOnly"

+ 3 - 0
webview-ui/src/components/settings/ApiOptions.tsx

@@ -7,6 +7,7 @@ import * as vscodemodels from "vscode"
 import {
 	ApiConfiguration,
 	ModelInfo,
+	ApiProvider,
 	anthropicDefaultModelId,
 	anthropicModels,
 	azureOpenAiDefaultApiVersion,
@@ -1380,9 +1381,11 @@ const ApiOptions = ({
 						/>
 					</div>
 					<ThinkingBudget
+						key={`${selectedProvider}-${selectedModelId}`}
 						apiConfiguration={apiConfiguration}
 						setApiConfigurationField={setApiConfigurationField}
 						modelInfo={selectedModelInfo}
+						provider={selectedProvider as ApiProvider}
 					/>
 					<ModelInfoView
 						selectedModelId={selectedModelId}

+ 22 - 8
webview-ui/src/components/settings/ThinkingBudget.tsx

@@ -1,5 +1,5 @@
-import { useEffect } from "react"
-
+import { useEffect, useMemo } from "react"
+import { ApiProvider } from "../../../../src/shared/api"
 import { Slider } from "@/components/ui"
 
 import { ApiConfiguration, ModelInfo } from "../../../../src/shared/api"
@@ -8,24 +8,38 @@ interface ThinkingBudgetProps {
 	apiConfiguration: ApiConfiguration
 	setApiConfigurationField: <K extends keyof ApiConfiguration>(field: K, value: ApiConfiguration[K]) => void
 	modelInfo?: ModelInfo
+	provider?: ApiProvider
 }
 
-export const ThinkingBudget = ({ apiConfiguration, setApiConfigurationField, modelInfo }: ThinkingBudgetProps) => {
+export const ThinkingBudget = ({
+	apiConfiguration,
+	setApiConfigurationField,
+	modelInfo,
+	provider,
+}: ThinkingBudgetProps) => {
+	const isVertexProvider = provider === "vertex"
+	const budgetField = isVertexProvider ? "vertexThinking" : "anthropicThinking"
+
 	const tokens = apiConfiguration?.modelMaxTokens || modelInfo?.maxTokens || 64_000
 	const tokensMin = 8192
 	const tokensMax = modelInfo?.maxTokens || 64_000
 
-	const thinkingTokens = apiConfiguration?.anthropicThinking || 8192
+	// Get the appropriate thinking tokens based on provider
+	const thinkingTokens = useMemo(() => {
+		const value = isVertexProvider ? apiConfiguration?.vertexThinking : apiConfiguration?.anthropicThinking
+		return value || Math.min(Math.floor(0.8 * tokens), 8192)
+	}, [apiConfiguration, isVertexProvider, tokens])
+
 	const thinkingTokensMin = 1024
 	const thinkingTokensMax = Math.floor(0.8 * tokens)
 
 	useEffect(() => {
 		if (thinkingTokens > thinkingTokensMax) {
-			setApiConfigurationField("anthropicThinking", thinkingTokensMax)
+			setApiConfigurationField(budgetField, thinkingTokensMax)
 		}
-	}, [thinkingTokens, thinkingTokensMax, setApiConfigurationField])
+	}, [thinkingTokens, thinkingTokensMax, setApiConfigurationField, budgetField])
 
-	if (!modelInfo || !modelInfo.thinking) {
+	if (!modelInfo?.thinking) {
 		return null
 	}
 
@@ -52,7 +66,7 @@ export const ThinkingBudget = ({ apiConfiguration, setApiConfigurationField, mod
 						max={thinkingTokensMax}
 						step={1024}
 						value={[thinkingTokens]}
-						onValueChange={([value]) => setApiConfigurationField("anthropicThinking", value)}
+						onValueChange={([value]) => setApiConfigurationField(budgetField, value)}
 					/>
 					<div className="w-12 text-sm text-center">{thinkingTokens}</div>
 				</div>

+ 56 - 1
webview-ui/src/components/settings/__tests__/ApiOptions.test.tsx

@@ -46,6 +46,21 @@ jest.mock("../TemperatureControl", () => ({
 	),
 }))
 
+// Mock ThinkingBudget component
+jest.mock("../ThinkingBudget", () => ({
+	ThinkingBudget: ({ apiConfiguration, setApiConfigurationField, modelInfo, provider }: any) =>
+		modelInfo?.thinking ? (
+			<div data-testid="thinking-budget" data-provider={provider}>
+				<input
+					data-testid="thinking-tokens"
+					value={
+						provider === "vertex" ? apiConfiguration?.vertexThinking : apiConfiguration?.anthropicThinking
+					}
+				/>
+			</div>
+		) : null,
+}))
+
 describe("ApiOptions", () => {
 	const renderApiOptions = (props = {}) => {
 		render(
@@ -72,5 +87,45 @@ describe("ApiOptions", () => {
 		expect(screen.queryByTestId("temperature-control")).not.toBeInTheDocument()
 	})
 
-	//TODO: More test cases needed
+	describe("thinking functionality", () => {
+		it("should show ThinkingBudget for Anthropic models that support thinking", () => {
+			renderApiOptions({
+				apiConfiguration: {
+					apiProvider: "anthropic",
+					apiModelId: "claude-3-7-sonnet-20250219:thinking",
+				},
+			})
+
+			expect(screen.getByTestId("thinking-budget")).toBeInTheDocument()
+			expect(screen.getByTestId("thinking-budget")).toHaveAttribute("data-provider", "anthropic")
+		})
+
+		it("should show ThinkingBudget for Vertex models that support thinking", () => {
+			renderApiOptions({
+				apiConfiguration: {
+					apiProvider: "vertex",
+					apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				},
+			})
+
+			expect(screen.getByTestId("thinking-budget")).toBeInTheDocument()
+			expect(screen.getByTestId("thinking-budget")).toHaveAttribute("data-provider", "vertex")
+		})
+
+		it("should not show ThinkingBudget for models that don't support thinking", () => {
+			renderApiOptions({
+				apiConfiguration: {
+					apiProvider: "anthropic",
+					apiModelId: "claude-3-opus-20240229",
+					modelInfo: { thinking: false }, // Non-thinking model
+				},
+			})
+
+			expect(screen.queryByTestId("thinking-budget")).not.toBeInTheDocument()
+		})
+
+		// Note: We don't need to test the actual ThinkingBudget component functionality here
+		// since we have separate tests for that component. We just need to verify that
+		// it's included in the ApiOptions component when appropriate.
+	})
 })

+ 145 - 0
webview-ui/src/components/settings/__tests__/ThinkingBudget.test.tsx

@@ -0,0 +1,145 @@
+import React from "react"
+import { render, screen, fireEvent } from "@testing-library/react"
+import { ThinkingBudget } from "../ThinkingBudget"
+import { ApiProvider, ModelInfo } from "../../../../../src/shared/api"
+
+// Mock Slider component
+jest.mock("@/components/ui", () => ({
+	Slider: ({ value, onValueChange, min, max }: any) => (
+		<input
+			type="range"
+			data-testid="slider"
+			min={min}
+			max={max}
+			value={value[0]}
+			onChange={(e) => onValueChange([parseInt(e.target.value)])}
+		/>
+	),
+}))
+
+describe("ThinkingBudget", () => {
+	const mockModelInfo: ModelInfo = {
+		thinking: true,
+		maxTokens: 16384,
+		contextWindow: 200000,
+		supportsPromptCache: true,
+		supportsImages: true,
+	}
+	const defaultProps = {
+		apiConfiguration: {},
+		setApiConfigurationField: jest.fn(),
+		modelInfo: mockModelInfo,
+		provider: "anthropic" as ApiProvider,
+	}
+
+	beforeEach(() => {
+		jest.clearAllMocks()
+	})
+
+	it("should render nothing when model doesn't support thinking", () => {
+		const { container } = render(
+			<ThinkingBudget
+				{...defaultProps}
+				modelInfo={{
+					...mockModelInfo,
+					thinking: false,
+					maxTokens: 16384,
+					contextWindow: 200000,
+					supportsPromptCache: true,
+					supportsImages: true,
+				}}
+			/>,
+		)
+
+		expect(container.firstChild).toBeNull()
+	})
+
+	it("should render sliders when model supports thinking", () => {
+		render(<ThinkingBudget {...defaultProps} />)
+
+		expect(screen.getAllByTestId("slider")).toHaveLength(2)
+	})
+
+	it("should use anthropicThinking field for Anthropic provider", () => {
+		const setApiConfigurationField = jest.fn()
+
+		render(
+			<ThinkingBudget
+				{...defaultProps}
+				apiConfiguration={{ anthropicThinking: 4096 }}
+				setApiConfigurationField={setApiConfigurationField}
+				provider="anthropic"
+			/>,
+		)
+
+		const sliders = screen.getAllByTestId("slider")
+		fireEvent.change(sliders[1], { target: { value: "5000" } })
+
+		expect(setApiConfigurationField).toHaveBeenCalledWith("anthropicThinking", 5000)
+	})
+
+	it("should use vertexThinking field for Vertex provider", () => {
+		const setApiConfigurationField = jest.fn()
+
+		render(
+			<ThinkingBudget
+				{...defaultProps}
+				apiConfiguration={{ vertexThinking: 4096 }}
+				setApiConfigurationField={setApiConfigurationField}
+				provider="vertex"
+			/>,
+		)
+
+		const sliders = screen.getAllByTestId("slider")
+		fireEvent.change(sliders[1], { target: { value: "5000" } })
+
+		expect(setApiConfigurationField).toHaveBeenCalledWith("vertexThinking", 5000)
+	})
+
+	it("should cap thinking tokens at 80% of max tokens", () => {
+		const setApiConfigurationField = jest.fn()
+
+		render(
+			<ThinkingBudget
+				{...defaultProps}
+				apiConfiguration={{ modelMaxTokens: 10000, anthropicThinking: 9000 }}
+				setApiConfigurationField={setApiConfigurationField}
+			/>,
+		)
+
+		// Effect should trigger and cap the value
+		expect(setApiConfigurationField).toHaveBeenCalledWith("anthropicThinking", 8000) // 80% of 10000
+	})
+
+	it("should use default thinking tokens if not provided", () => {
+		render(<ThinkingBudget {...defaultProps} apiConfiguration={{ modelMaxTokens: 10000 }} />)
+
+		// Default is 80% of max tokens, capped at 8192
+		const sliders = screen.getAllByTestId("slider")
+		expect(sliders[1]).toHaveValue("8000") // 80% of 10000
+	})
+
+	it("should use min thinking tokens of 1024", () => {
+		render(<ThinkingBudget {...defaultProps} apiConfiguration={{ modelMaxTokens: 1000 }} />)
+
+		const sliders = screen.getAllByTestId("slider")
+		expect(sliders[1].getAttribute("min")).toBe("1024")
+	})
+
+	it("should update max tokens when slider changes", () => {
+		const setApiConfigurationField = jest.fn()
+
+		render(
+			<ThinkingBudget
+				{...defaultProps}
+				apiConfiguration={{ modelMaxTokens: 10000 }}
+				setApiConfigurationField={setApiConfigurationField}
+			/>,
+		)
+
+		const sliders = screen.getAllByTestId("slider")
+		fireEvent.change(sliders[0], { target: { value: "12000" } })
+
+		expect(setApiConfigurationField).toHaveBeenCalledWith("modelMaxTokens", 12000)
+	})
+})