Browse Source

Support Claude 3.7 Sonnet "Thinking" in OpenRouter

Chris Estreich 10 months ago
parent
commit
392a237985

+ 61 - 31
src/api/providers/openrouter.ts

@@ -52,10 +52,14 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			...convertToOpenAiMessages(messages),
 		]
 
+		const { id: modelId, info: modelInfo } = this.getModel()
+
 		// prompt caching: https://openrouter.ai/docs/prompt-caching
 		// this is specifically for claude models (some models may 'support prompt caching' automatically without this)
-		switch (this.getModel().id) {
+		switch (modelId) {
+			case "anthropic/claude-3.7-sonnet:thinking":
 			case "anthropic/claude-3.7-sonnet":
+			case "anthropic/claude-3.7-sonnet:beta":
 			case "anthropic/claude-3.5-sonnet":
 			case "anthropic/claude-3.5-sonnet:beta":
 			case "anthropic/claude-3.5-sonnet-20240620":
@@ -103,31 +107,25 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 				break
 		}
 
-		// Not sure how openrouter defaults max tokens when no value is provided, but the anthropic api requires this value and since they offer both 4096 and 8192 variants, we should ensure 8192.
-		// (models usually default to max tokens allowed)
-		let maxTokens: number | undefined
-		switch (this.getModel().id) {
-			case "anthropic/claude-3.7-sonnet":
-			case "anthropic/claude-3.5-sonnet":
-			case "anthropic/claude-3.5-sonnet:beta":
-			case "anthropic/claude-3.5-sonnet-20240620":
-			case "anthropic/claude-3.5-sonnet-20240620:beta":
-			case "anthropic/claude-3-5-haiku":
-			case "anthropic/claude-3-5-haiku:beta":
-			case "anthropic/claude-3-5-haiku-20241022":
-			case "anthropic/claude-3-5-haiku-20241022:beta":
-				maxTokens = 8_192
-				break
+		// Not sure how openrouter defaults max tokens when no value is
+		// provided, but the Anthropic API requires this value and since they
+		// offer both 4096 and 8192 variants, we should ensure 8192.
+		// (Models usually default to max tokens allowed.)
+		let maxTokens: number | undefined = undefined
+
+		if (modelId.startsWith("anthropic/claude-3.5")) {
+			maxTokens = modelInfo.maxTokens ?? 8_192
+		}
+
+		if (modelId.startsWith("anthropic/claude-3.7")) {
+			maxTokens = modelInfo.maxTokens ?? 16_384
 		}
 
 		let defaultTemperature = OPENROUTER_DEFAULT_TEMPERATURE
 		let topP: number | undefined = undefined
 
 		// Handle models based on deepseek-r1
-		if (
-			this.getModel().id.startsWith("deepseek/deepseek-r1") ||
-			this.getModel().id === "perplexity/sonar-reasoning"
-		) {
+		if (modelId.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning") {
 			// Recommended temperature for DeepSeek reasoning models
 			defaultTemperature = DEEP_SEEK_DEFAULT_TEMPERATURE
 			// DeepSeek highly recommends using user instead of system role
@@ -136,24 +134,37 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			topP = 0.95
 		}
 
+		let temperature = this.options.modelTemperature ?? defaultTemperature
+
+		if (modelInfo.thinking) {
+			temperature = 1.0
+		}
+
 		// https://openrouter.ai/docs/transforms
 		let fullResponseText = ""
-		const stream = await this.client.chat.completions.create({
-			model: this.getModel().id,
+
+		const completionParams: OpenRouterChatCompletionParams = {
+			model: modelId,
 			max_tokens: maxTokens,
-			temperature: this.options.modelTemperature ?? defaultTemperature,
+			temperature,
 			top_p: topP,
 			messages: openAiMessages,
 			stream: true,
 			include_reasoning: true,
 			// This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true.
 			...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] }),
-		} as OpenRouterChatCompletionParams)
+		}
+
+		console.log("OpenRouter completionParams:", completionParams)
+
+		const stream = await this.client.chat.completions.create(completionParams)
 
 		let genId: string | undefined
 
 		for await (const chunk of stream as unknown as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>) {
-			// openrouter returns an error object instead of the openai sdk throwing an error
+			console.log("OpenRouter chunk:", chunk)
+
+			// OpenRouter returns an error object instead of the OpenAI SDK throwing an error.
 			if ("error" in chunk) {
 				const error = chunk.error as { message?: string; code?: number }
 				console.error(`OpenRouter API Error: ${error?.code} - ${error?.message}`)
@@ -165,12 +176,14 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			}
 
 			const delta = chunk.choices[0]?.delta
+
 			if ("reasoning" in delta && delta.reasoning) {
 				yield {
 					type: "reasoning",
 					text: delta.reasoning,
 				} as ApiStreamChunk
 			}
+
 			if (delta?.content) {
 				fullResponseText += delta.content
 				yield {
@@ -178,6 +191,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 					text: delta.content,
 				} as ApiStreamChunk
 			}
+
 			// if (chunk.usage) {
 			// 	yield {
 			// 		type: "usage",
@@ -187,10 +201,12 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			// }
 		}
 
-		// retry fetching generation details
+		// Retry fetching generation details.
 		let attempt = 0
+
 		while (attempt++ < 10) {
 			await delay(200) // FIXME: necessary delay to ensure generation endpoint is ready
+
 			try {
 				const response = await axios.get(`https://openrouter.ai/api/v1/generation?id=${genId}`, {
 					headers: {
@@ -201,6 +217,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 
 				const generation = response.data?.data
 				console.log("OpenRouter generation details:", response.data)
+
 				yield {
 					type: "usage",
 					// cacheWriteTokens: 0,
@@ -211,6 +228,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 					totalCost: generation?.total_cost || 0,
 					fullResponseText,
 				} as OpenRouterApiStreamUsageChunk
+
 				return
 			} catch (error) {
 				// ignore if fails
@@ -218,13 +236,13 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			}
 		}
 	}
-	getModel(): { id: string; info: ModelInfo } {
+
+	getModel() {
 		const modelId = this.options.openRouterModelId
 		const modelInfo = this.options.openRouterModelInfo
-		if (modelId && modelInfo) {
-			return { id: modelId, info: modelInfo }
-		}
-		return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
+		return modelId && modelInfo
+			? { id: modelId, info: modelInfo }
+			: { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
@@ -247,6 +265,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			if (error instanceof Error) {
 				throw new Error(`OpenRouter completion error: ${error.message}`)
 			}
+
 			throw error
 		}
 	}
@@ -268,14 +287,23 @@ export async function getOpenRouterModels() {
 				inputPrice: parseApiPrice(rawModel.pricing?.prompt),
 				outputPrice: parseApiPrice(rawModel.pricing?.completion),
 				description: rawModel.description,
+				thinking: rawModel.id === "anthropic/claude-3.7-sonnet:thinking",
 			}
 
 			switch (rawModel.id) {
+				case "anthropic/claude-3.7-sonnet:thinking":
 				case "anthropic/claude-3.7-sonnet":
 				case "anthropic/claude-3.7-sonnet:beta":
+					modelInfo.maxTokens = 16_384
+					modelInfo.supportsComputerUse = true
+					modelInfo.supportsPromptCache = true
+					modelInfo.cacheWritesPrice = 3.75
+					modelInfo.cacheReadsPrice = 0.3
+					break
 				case "anthropic/claude-3.5-sonnet":
 				case "anthropic/claude-3.5-sonnet:beta":
 					// NOTE: This needs to be synced with api.ts/openrouter default model info.
+					modelInfo.maxTokens = 8_192
 					modelInfo.supportsComputerUse = true
 					modelInfo.supportsPromptCache = true
 					modelInfo.cacheWritesPrice = 3.75
@@ -283,6 +311,7 @@ export async function getOpenRouterModels() {
 					break
 				case "anthropic/claude-3.5-sonnet-20240620":
 				case "anthropic/claude-3.5-sonnet-20240620:beta":
+					modelInfo.maxTokens = 8_192
 					modelInfo.supportsPromptCache = true
 					modelInfo.cacheWritesPrice = 3.75
 					modelInfo.cacheReadsPrice = 0.3
@@ -295,6 +324,7 @@ export async function getOpenRouterModels() {
 				case "anthropic/claude-3.5-haiku:beta":
 				case "anthropic/claude-3.5-haiku-20241022":
 				case "anthropic/claude-3.5-haiku-20241022:beta":
+					modelInfo.maxTokens = 8_192
 					modelInfo.supportsPromptCache = true
 					modelInfo.cacheWritesPrice = 1.25
 					modelInfo.cacheReadsPrice = 0.1

+ 8 - 0
src/shared/api.ts

@@ -89,6 +89,13 @@ export interface ModelInfo {
 	cacheReadsPrice?: number
 	description?: string
 	reasoningEffort?: "low" | "medium" | "high"
+	thinking?: boolean
+}
+
+export const THINKING_BUDGET = {
+	step: 1024,
+	min: 1024,
+	default: 8 * 1024,
 }
 
 // Anthropic
@@ -106,6 +113,7 @@ export const anthropicModels = {
 		outputPrice: 15.0, // $15 per million output tokens
 		cacheWritesPrice: 3.75, // $3.75 per million tokens
 		cacheReadsPrice: 0.3, // $0.30 per million tokens
+		thinking: true,
 	},
 	"claude-3-5-sonnet-20241022": {
 		maxTokens: 8192,

+ 15 - 6
webview-ui/src/components/settings/ApiOptions.tsx

@@ -33,6 +33,7 @@ import {
 	unboundDefaultModelInfo,
 	requestyDefaultModelId,
 	requestyDefaultModelInfo,
+	THINKING_BUDGET,
 } from "../../../../src/shared/api"
 import { ExtensionMessage } from "../../../../src/shared/ExtensionMessage"
 
@@ -1270,12 +1271,20 @@ const ApiOptions = ({
 					</>
 				)}
 
-			{selectedProvider === "anthropic" && selectedModelId === "claude-3-7-sonnet-20250219" && (
+			{selectedModelInfo && selectedModelInfo.thinking && (
 				<div className="flex flex-col gap-2 mt-2">
 					<Checkbox
 						checked={!!anthropicThinkingBudget}
 						onChange={(checked) =>
-							setApiConfigurationField("anthropicThinking", checked ? 16_384 : undefined)
+							setApiConfigurationField(
+								"anthropicThinking",
+								checked
+									? Math.min(
+											THINKING_BUDGET.default,
+											selectedModelInfo.maxTokens ?? THINKING_BUDGET.default,
+										)
+									: undefined,
+							)
 						}>
 						Thinking?
 					</Checkbox>
@@ -1286,13 +1295,13 @@ const ApiOptions = ({
 							</div>
 							<div className="flex items-center gap-2">
 								<Slider
-									min={1024}
-									max={anthropicModels["claude-3-7-sonnet-20250219"].maxTokens - 1}
-									step={1024}
+									min={THINKING_BUDGET.min}
+									max={(selectedModelInfo.maxTokens ?? THINKING_BUDGET.default) - 1}
+									step={THINKING_BUDGET.step}
 									value={[anthropicThinkingBudget]}
 									onValueChange={(value) => setApiConfigurationField("anthropicThinking", value[0])}
 								/>
-								<div className="w-10">{anthropicThinkingBudget}</div>
+								<div className="w-12">{anthropicThinkingBudget}</div>
 							</div>
 						</>
 					)}

+ 2 - 2
webview-ui/src/components/ui/slider.tsx

@@ -11,8 +11,8 @@ const Slider = React.forwardRef<
 		ref={ref}
 		className={cn("relative flex w-full touch-none select-none items-center", className)}
 		{...props}>
-		<SliderPrimitive.Track className="relative h-1 w-full grow overflow-hidden bg-primary/20">
-			<SliderPrimitive.Range className="absolute h-full bg-primary" />
+		<SliderPrimitive.Track className="relative w-full h-[8px] grow overflow-hidden bg-vscode-button-secondaryBackground border border-[#767676] dark:border-[#858585] rounded-sm">
+			<SliderPrimitive.Range className="absolute h-full bg-vscode-button-background" />
 		</SliderPrimitive.Track>
 		<SliderPrimitive.Thumb className="block h-3 w-3 rounded-full border border-primary/50 bg-primary shadow transition-colors cursor-pointer focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50" />
 	</SliderPrimitive.Root>