Просмотр исходного кода

Added tests for Claude Sonnet Thinking

Catalin Lupuleti 10 месяцев назад
Родитель
Commit
5eba1d53fb

+ 250 - 0
src/api/providers/__tests__/vertex.test.ts

@@ -2,6 +2,7 @@
 
 import { Anthropic } from "@anthropic-ai/sdk"
 import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
+import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"
 
 import { VertexHandler } from "../vertex"
 import { ApiStreamChunk } from "../../transform/stream"
@@ -431,6 +432,138 @@ describe("VertexHandler", () => {
 		})
 	})
 
+	// Tests for Claude extended-thinking support in VertexHandler.createMessage:
+	// "thinking" content blocks and "thinking_delta" events must surface as
+	// "reasoning" stream chunks, with a "\n" separator emitted between blocks.
+	describe("thinking functionality", () => {
+		const mockMessages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: "Hello",
+			},
+		]
+
+		const systemPrompt = "You are a helpful assistant"
+
+		it("should handle thinking content blocks and deltas", async () => {
+			// Simulated raw Anthropic stream: usage, a thinking block, a thinking
+			// delta, then a text block at a later index.
+			const mockStream = [
+				{
+					type: "message_start",
+					message: {
+						usage: {
+							input_tokens: 10,
+							output_tokens: 0,
+						},
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 0,
+					content_block: {
+						type: "thinking",
+						thinking: "Let me think about this...",
+					},
+				},
+				{
+					type: "content_block_delta",
+					delta: {
+						type: "thinking_delta",
+						thinking: " I need to consider all options.",
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 1,
+					content_block: {
+						type: "text",
+						text: "Here's my answer:",
+					},
+				},
+			]
+
+			// Setup async iterator for mock stream
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const chunks: ApiStreamChunk[] = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Verify thinking content is processed correctly
+			const reasoningChunks = chunks.filter((chunk) => chunk.type === "reasoning")
+			expect(reasoningChunks).toHaveLength(2)
+			expect(reasoningChunks[0].text).toBe("Let me think about this...")
+			expect(reasoningChunks[1].text).toBe(" I need to consider all options.")
+
+			// Verify text content is processed correctly
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(2) // One for the text block, one for the newline
+			expect(textChunks[0].text).toBe("\n")
+			expect(textChunks[1].text).toBe("Here's my answer:")
+		})
+
+		// A second thinking block (index > 0) should be preceded by a "\n"
+		// reasoning chunk acting as a separator between consecutive blocks.
+		it("should handle multiple thinking blocks with line breaks", async () => {
+			const mockStream = [
+				{
+					type: "content_block_start",
+					index: 0,
+					content_block: {
+						type: "thinking",
+						thinking: "First thinking block",
+					},
+				},
+				{
+					type: "content_block_start",
+					index: 1,
+					content_block: {
+						type: "thinking",
+						thinking: "Second thinking block",
+					},
+				},
+			]
+
+			const asyncIterator = {
+				async *[Symbol.asyncIterator]() {
+					for (const chunk of mockStream) {
+						yield chunk
+					}
+				},
+			}
+
+			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
+			;(handler["client"].messages as any).create = mockCreate
+
+			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const chunks: ApiStreamChunk[] = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBe(3)
+			expect(chunks[0]).toEqual({
+				type: "reasoning",
+				text: "First thinking block",
+			})
+			expect(chunks[1]).toEqual({
+				type: "reasoning",
+				text: "\n",
+			})
+			expect(chunks[2]).toEqual({
+				type: "reasoning",
+				text: "Second thinking block",
+			})
+		})
+	})
+
 	describe("completePrompt", () => {
 		it("should complete prompt successfully", async () => {
 			const result = await handler.completePrompt("Test prompt")
@@ -500,4 +633,121 @@ describe("VertexHandler", () => {
 			expect(modelInfo.id).toBe("claude-3-7-sonnet@20250219") // Default model
 		})
 	})
+
+	// Tests for the ":thinking" model-id suffix: getModel() must strip the
+	// suffix, expose an enabled thinking config with a budget, and force
+	// temperature to 1.0; createMessage must forward that config to the API.
+	describe("thinking model configuration", () => {
+		it("should configure thinking for models with :thinking suffix", () => {
+			const thinkingHandler = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 16384,
+				vertexThinking: 4096,
+			})
+
+			const modelInfo = thinkingHandler.getModel()
+
+			// Verify thinking configuration
+			expect(modelInfo.id).toBe("claude-3-7-sonnet@20250219")
+			expect(modelInfo.thinking).toBeDefined()
+			const thinkingConfig = modelInfo.thinking as { type: "enabled"; budget_tokens: number }
+			expect(thinkingConfig.type).toBe("enabled")
+			expect(thinkingConfig.budget_tokens).toBe(4096)
+			expect(modelInfo.temperature).toBe(1.0) // Thinking requires temperature 1.0
+		})
+
+		// Budget resolution order exercised below: explicit vertexThinking,
+		// else 80% of modelMaxTokens, floored at 1024.
+		it("should calculate thinking budget correctly", () => {
+			// Test with explicit thinking budget
+			const handlerWithBudget = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 16384,
+				vertexThinking: 5000,
+			})
+
+			expect((handlerWithBudget.getModel().thinking as any).budget_tokens).toBe(5000)
+
+			// Test with default thinking budget (80% of max tokens)
+			const handlerWithDefaultBudget = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 10000,
+			})
+
+			expect((handlerWithDefaultBudget.getModel().thinking as any).budget_tokens).toBe(8000) // 80% of 10000
+
+			// Test with minimum thinking budget (should be at least 1024)
+			const handlerWithSmallMaxTokens = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 1000, // This would result in 800 tokens for thinking, but minimum is 1024
+			})
+
+			expect((handlerWithSmallMaxTokens.getModel().thinking as any).budget_tokens).toBe(1024)
+		})
+
+		it("should use anthropicThinking value if vertexThinking is not provided", () => {
+			const handler = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 16384,
+				anthropicThinking: 6000, // Should be used as fallback
+			})
+
+			expect((handler.getModel().thinking as any).budget_tokens).toBe(6000)
+		})
+
+		it("should pass thinking configuration to API", async () => {
+			const thinkingHandler = new VertexHandler({
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+				modelMaxTokens: 16384,
+				vertexThinking: 4096,
+			})
+
+			// Mock create() for both call shapes: a plain completion when
+			// options.stream is falsy, an async-iterable stream otherwise.
+			const mockCreate = jest.fn().mockImplementation(async (options) => {
+				if (!options.stream) {
+					return {
+						id: "test-completion",
+						content: [{ type: "text", text: "Test response" }],
+						role: "assistant",
+						model: options.model,
+						usage: {
+							input_tokens: 10,
+							output_tokens: 5,
+						},
+					}
+				}
+				return {
+					async *[Symbol.asyncIterator]() {
+						yield {
+							type: "message_start",
+							message: {
+								usage: {
+									input_tokens: 10,
+									output_tokens: 5,
+								},
+							},
+						}
+					},
+				}
+			})
+			;(thinkingHandler["client"].messages as any).create = mockCreate
+
+			// Pull a single chunk so the request is issued before asserting.
+			await thinkingHandler
+				.createMessage("You are a helpful assistant", [{ role: "user", content: "Hello" }])
+				.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					thinking: { type: "enabled", budget_tokens: 4096 },
+					temperature: 1.0, // Thinking requires temperature 1.0
+				}),
+			)
+		})
+	})
 })

+ 56 - 1
webview-ui/src/components/settings/__tests__/ApiOptions.test.tsx

@@ -46,6 +46,21 @@ jest.mock("../TemperatureControl", () => ({
 	),
 }))
 
+// Mock ThinkingBudget component
+// Mock ThinkingBudget component
+// Renders a stub only when modelInfo.thinking is set, mirroring the real
+// component's visibility rule, and exposes the provider plus the
+// provider-specific budget field (vertexThinking vs anthropicThinking).
+jest.mock("../ThinkingBudget", () => ({
+	ThinkingBudget: ({ apiConfiguration, setApiConfigurationField, modelInfo, provider }: any) =>
+		modelInfo?.thinking ? (
+			<div data-testid="thinking-budget" data-provider={provider}>
+				<input
+					data-testid="thinking-tokens"
+					value={
+						provider === "vertex" ? apiConfiguration?.vertexThinking : apiConfiguration?.anthropicThinking
+					}
+				/>
+			</div>
+		) : null,
+}))
+
 describe("ApiOptions", () => {
 	const renderApiOptions = (props = {}) => {
 		render(
@@ -72,5 +87,45 @@ describe("ApiOptions", () => {
 		expect(screen.queryByTestId("temperature-control")).not.toBeInTheDocument()
 	})
 
-	//TODO: More test cases needed
+	// Verifies that ApiOptions includes the (mocked) ThinkingBudget control
+	// for thinking-capable Anthropic/Vertex model ids and omits it otherwise.
+	describe("thinking functionality", () => {
+		it("should show ThinkingBudget for Anthropic models that support thinking", () => {
+			renderApiOptions({
+				apiConfiguration: {
+					apiProvider: "anthropic",
+					apiModelId: "claude-3-7-sonnet-20250219:thinking",
+				},
+			})
+
+			expect(screen.getByTestId("thinking-budget")).toBeInTheDocument()
+			expect(screen.getByTestId("thinking-budget")).toHaveAttribute("data-provider", "anthropic")
+		})
+
+		it("should show ThinkingBudget for Vertex models that support thinking", () => {
+			renderApiOptions({
+				apiConfiguration: {
+					apiProvider: "vertex",
+					apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				},
+			})
+
+			expect(screen.getByTestId("thinking-budget")).toBeInTheDocument()
+			expect(screen.getByTestId("thinking-budget")).toHaveAttribute("data-provider", "vertex")
+		})
+
+		it("should not show ThinkingBudget for models that don't support thinking", () => {
+			renderApiOptions({
+				apiConfiguration: {
+					apiProvider: "anthropic",
+					apiModelId: "claude-3-opus-20240229",
+					// NOTE(review): ApiOptions presumably derives modelInfo from
+					// apiModelId; confirm this inline override is actually consumed
+					// rather than the test passing only because the model id itself
+					// lacks the :thinking suffix.
+					modelInfo: { thinking: false }, // Non-thinking model
+				},
+			})
+
+			expect(screen.queryByTestId("thinking-budget")).not.toBeInTheDocument()
+		})
+
+		// Note: We don't need to test the actual ThinkingBudget component functionality here
+		// since we have separate tests for that component. We just need to verify that
+		// it's included in the ApiOptions component when appropriate.
+	})
 })

+ 145 - 0
webview-ui/src/components/settings/__tests__/ThinkingBudget.test.tsx

@@ -0,0 +1,145 @@
+import React from "react"
+import { render, screen, fireEvent } from "@testing-library/react"
+import { ThinkingBudget } from "../ThinkingBudget"
+import { ApiProvider, ModelInfo } from "../../../../../src/shared/api"
+
+// Mock Slider component
+jest.mock("@/components/ui", () => ({
+	Slider: ({ value, onValueChange, min, max }: any) => (
+		<input
+			type="range"
+			data-testid="slider"
+			min={min}
+			max={max}
+			value={value[0]}
+			onChange={(e) => onValueChange([parseInt(e.target.value)])}
+		/>
+	),
+}))
+
+// Unit tests for the ThinkingBudget settings control: visibility rule,
+// provider-specific config field selection, budget clamping (80% of max
+// tokens, 1024 floor), defaults, and slider wiring. Slider index 0 drives
+// modelMaxTokens; slider index 1 drives the thinking-budget field.
+describe("ThinkingBudget", () => {
+	const mockModelInfo: ModelInfo = {
+		thinking: true,
+		maxTokens: 16384,
+		contextWindow: 200000,
+		supportsPromptCache: true,
+		supportsImages: true,
+	}
+	const defaultProps = {
+		apiConfiguration: {},
+		setApiConfigurationField: jest.fn(),
+		modelInfo: mockModelInfo,
+		provider: "anthropic" as ApiProvider,
+	}
+
+	beforeEach(() => {
+		jest.clearAllMocks()
+	})
+
+	it("should render nothing when model doesn't support thinking", () => {
+		// The explicit fields restate the spread defaults; only thinking: false
+		// differs from mockModelInfo here.
+		const { container } = render(
+			<ThinkingBudget
+				{...defaultProps}
+				modelInfo={{
+					...mockModelInfo,
+					thinking: false,
+					maxTokens: 16384,
+					contextWindow: 200000,
+					supportsPromptCache: true,
+					supportsImages: true,
+				}}
+			/>,
+		)
+
+		expect(container.firstChild).toBeNull()
+	})
+
+	it("should render sliders when model supports thinking", () => {
+		render(<ThinkingBudget {...defaultProps} />)
+
+		expect(screen.getAllByTestId("slider")).toHaveLength(2)
+	})
+
+	it("should use anthropicThinking field for Anthropic provider", () => {
+		const setApiConfigurationField = jest.fn()
+
+		render(
+			<ThinkingBudget
+				{...defaultProps}
+				apiConfiguration={{ anthropicThinking: 4096 }}
+				setApiConfigurationField={setApiConfigurationField}
+				provider="anthropic"
+			/>,
+		)
+
+		const sliders = screen.getAllByTestId("slider")
+		fireEvent.change(sliders[1], { target: { value: "5000" } })
+
+		expect(setApiConfigurationField).toHaveBeenCalledWith("anthropicThinking", 5000)
+	})
+
+	it("should use vertexThinking field for Vertex provider", () => {
+		const setApiConfigurationField = jest.fn()
+
+		render(
+			<ThinkingBudget
+				{...defaultProps}
+				apiConfiguration={{ vertexThinking: 4096 }}
+				setApiConfigurationField={setApiConfigurationField}
+				provider="vertex"
+			/>,
+		)
+
+		const sliders = screen.getAllByTestId("slider")
+		fireEvent.change(sliders[1], { target: { value: "5000" } })
+
+		expect(setApiConfigurationField).toHaveBeenCalledWith("vertexThinking", 5000)
+	})
+
+	it("should cap thinking tokens at 80% of max tokens", () => {
+		const setApiConfigurationField = jest.fn()
+
+		render(
+			<ThinkingBudget
+				{...defaultProps}
+				apiConfiguration={{ modelMaxTokens: 10000, anthropicThinking: 9000 }}
+				setApiConfigurationField={setApiConfigurationField}
+			/>,
+		)
+
+		// Effect should trigger and cap the value
+		expect(setApiConfigurationField).toHaveBeenCalledWith("anthropicThinking", 8000) // 80% of 10000
+	})
+
+	it("should use default thinking tokens if not provided", () => {
+		render(<ThinkingBudget {...defaultProps} apiConfiguration={{ modelMaxTokens: 10000 }} />)
+
+		// Default is 80% of max tokens, capped at 8192
+		const sliders = screen.getAllByTestId("slider")
+		expect(sliders[1]).toHaveValue("8000") // 80% of 10000
+	})
+
+	it("should use min thinking tokens of 1024", () => {
+		render(<ThinkingBudget {...defaultProps} apiConfiguration={{ modelMaxTokens: 1000 }} />)
+
+		const sliders = screen.getAllByTestId("slider")
+		expect(sliders[1].getAttribute("min")).toBe("1024")
+	})
+
+	it("should update max tokens when slider changes", () => {
+		const setApiConfigurationField = jest.fn()
+
+		render(
+			<ThinkingBudget
+				{...defaultProps}
+				apiConfiguration={{ modelMaxTokens: 10000 }}
+				setApiConfigurationField={setApiConfigurationField}
+			/>,
+		)
+
+		const sliders = screen.getAllByTestId("slider")
+		fireEvent.change(sliders[0], { target: { value: "12000" } })
+
+		expect(setApiConfigurationField).toHaveBeenCalledWith("modelMaxTokens", 12000)
+	})
+})