Browse source

Support Claude 3.7 Sonnet "Thinking" in OpenRouter

Chris Estreich, 10 months ago
commit 392a237985

+ 61 - 31
src/api/providers/openrouter.ts

@@ -52,10 +52,14 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			...convertToOpenAiMessages(messages),
 		]

+		const { id: modelId, info: modelInfo } = this.getModel()
+
 		// prompt caching: https://openrouter.ai/docs/prompt-caching
 		// this is specifically for claude models (some models may 'support prompt caching' automatically without this)
-		switch (this.getModel().id) {
+		switch (modelId) {
+			case "anthropic/claude-3.7-sonnet:thinking":
 			case "anthropic/claude-3.7-sonnet":
+			case "anthropic/claude-3.7-sonnet:beta":
 			case "anthropic/claude-3.5-sonnet":
 			case "anthropic/claude-3.5-sonnet:beta":
 			case "anthropic/claude-3.5-sonnet-20240620":
@@ -103,31 +107,25 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 				break
 		}

-		// Not sure how openrouter defaults max tokens when no value is provided, but the anthropic api requires this value and since they offer both 4096 and 8192 variants, we should ensure 8192.
-		// (models usually default to max tokens allowed)
-		let maxTokens: number | undefined
-		switch (this.getModel().id) {
-			case "anthropic/claude-3.7-sonnet":
-			case "anthropic/claude-3.5-sonnet":
-			case "anthropic/claude-3.5-sonnet:beta":
-			case "anthropic/claude-3.5-sonnet-20240620":
-			case "anthropic/claude-3.5-sonnet-20240620:beta":
-			case "anthropic/claude-3-5-haiku":
-			case "anthropic/claude-3-5-haiku:beta":
-			case "anthropic/claude-3-5-haiku-20241022":
-			case "anthropic/claude-3-5-haiku-20241022:beta":
-				maxTokens = 8_192
-				break
+		// Not sure how openrouter defaults max tokens when no value is
+		// provided, but the Anthropic API requires this value and since they
+		// offer both 4096 and 8192 variants, we should ensure 8192.
+		// (Models usually default to max tokens allowed.)
+		let maxTokens: number | undefined = undefined
+
+		if (modelId.startsWith("anthropic/claude-3.5")) {
+			maxTokens = modelInfo.maxTokens ?? 8_192
+		}
+
+		if (modelId.startsWith("anthropic/claude-3.7")) {
+			maxTokens = modelInfo.maxTokens ?? 16_384
 		}

 		let defaultTemperature = OPENROUTER_DEFAULT_TEMPERATURE
 		let topP: number | undefined = undefined

 		// Handle models based on deepseek-r1
-		if (
-			this.getModel().id.startsWith("deepseek/deepseek-r1") ||
-			this.getModel().id === "perplexity/sonar-reasoning"
-		) {
+		if (modelId.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning") {
 			// Recommended temperature for DeepSeek reasoning models
 			defaultTemperature = DEEP_SEEK_DEFAULT_TEMPERATURE
 			// DeepSeek highly recommends using user instead of system role
@@ -136,24 +134,37 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			topP = 0.95
 		}

+		let temperature = this.options.modelTemperature ?? defaultTemperature
+
+		if (modelInfo.thinking) {
+			temperature = 1.0
+		}
+
 		// https://openrouter.ai/docs/transforms
 		let fullResponseText = ""
-		const stream = await this.client.chat.completions.create({
-			model: this.getModel().id,
+
+		const completionParams: OpenRouterChatCompletionParams = {
+			model: modelId,
 			max_tokens: maxTokens,
-			temperature: this.options.modelTemperature ?? defaultTemperature,
+			temperature,
 			top_p: topP,
 			messages: openAiMessages,
 			stream: true,
 			include_reasoning: true,
 			// This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true.
 			...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] }),
-		} as OpenRouterChatCompletionParams)
+		}
+
+		console.log("OpenRouter completionParams:", completionParams)
+
+		const stream = await this.client.chat.completions.create(completionParams)

 		let genId: string | undefined

 		for await (const chunk of stream as unknown as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>) {
-			// openrouter returns an error object instead of the openai sdk throwing an error
+			console.log("OpenRouter chunk:", chunk)
+
+			// OpenRouter returns an error object instead of the OpenAI SDK throwing an error.
 			if ("error" in chunk) {
 				const error = chunk.error as { message?: string; code?: number }
 				console.error(`OpenRouter API Error: ${error?.code} - ${error?.message}`)
@@ -165,12 +176,14 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			}

 			const delta = chunk.choices[0]?.delta
+
 			if ("reasoning" in delta && delta.reasoning) {
 				yield {
 					type: "reasoning",
 					text: delta.reasoning,
 				} as ApiStreamChunk
 			}
+
 			if (delta?.content) {
 				fullResponseText += delta.content
 				yield {
@@ -178,6 +191,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 					text: delta.content,
 				} as ApiStreamChunk
 			}
+
 			// if (chunk.usage) {
 			// 	yield {
 			// 		type: "usage",
@@ -187,10 +201,12 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			// }
 		}

-		// retry fetching generation details
+		// Retry fetching generation details.
 		let attempt = 0
+
 		while (attempt++ < 10) {
 			await delay(200) // FIXME: necessary delay to ensure generation endpoint is ready
+
 			try {
 				const response = await axios.get(`https://openrouter.ai/api/v1/generation?id=${genId}`, {
 					headers: {
@@ -201,6 +217,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {

 				const generation = response.data?.data
 				console.log("OpenRouter generation details:", response.data)
+
 				yield {
 					type: "usage",
 					// cacheWriteTokens: 0,
@@ -211,6 +228,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 					totalCost: generation?.total_cost || 0,
 					fullResponseText,
 				} as OpenRouterApiStreamUsageChunk
+
 				return
 			} catch (error) {
 				// ignore if fails
@@ -218,13 +236,13 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			}
 		}
 	}
-	getModel(): { id: string; info: ModelInfo } {
+
+	getModel() {
 		const modelId = this.options.openRouterModelId
 		const modelInfo = this.options.openRouterModelInfo
-		if (modelId && modelInfo) {
-			return { id: modelId, info: modelInfo }
-		}
-		return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
+		return modelId && modelInfo
+			? { id: modelId, info: modelInfo }
+			: { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
 	}

 	async completePrompt(prompt: string): Promise<string> {
@@ -247,6 +265,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			if (error instanceof Error) {
 				throw new Error(`OpenRouter completion error: ${error.message}`)
 			}
+
 			throw error
 		}
 	}
@@ -268,14 +287,23 @@ export async function getOpenRouterModels() {
 				inputPrice: parseApiPrice(rawModel.pricing?.prompt),
 				outputPrice: parseApiPrice(rawModel.pricing?.completion),
 				description: rawModel.description,
+				thinking: rawModel.id === "anthropic/claude-3.7-sonnet:thinking",
 			}

 			switch (rawModel.id) {
+				case "anthropic/claude-3.7-sonnet:thinking":
 				case "anthropic/claude-3.7-sonnet":
 				case "anthropic/claude-3.7-sonnet:beta":
+					modelInfo.maxTokens = 16_384
+					modelInfo.supportsComputerUse = true
+					modelInfo.supportsPromptCache = true
+					modelInfo.cacheWritesPrice = 3.75
+					modelInfo.cacheReadsPrice = 0.3
+					break
 				case "anthropic/claude-3.5-sonnet":
 				case "anthropic/claude-3.5-sonnet:beta":
 					// NOTE: This needs to be synced with api.ts/openrouter default model info.
+					modelInfo.maxTokens = 8_192
 					modelInfo.supportsComputerUse = true
 					modelInfo.supportsPromptCache = true
 					modelInfo.cacheWritesPrice = 3.75
@@ -283,6 +311,7 @@ export async function getOpenRouterModels() {
 					break
 				case "anthropic/claude-3.5-sonnet-20240620":
 				case "anthropic/claude-3.5-sonnet-20240620:beta":
+					modelInfo.maxTokens = 8_192
 					modelInfo.supportsPromptCache = true
 					modelInfo.cacheWritesPrice = 3.75
 					modelInfo.cacheReadsPrice = 0.3
@@ -295,6 +324,7 @@ export async function getOpenRouterModels() {
 				case "anthropic/claude-3.5-haiku:beta":
 				case "anthropic/claude-3.5-haiku:beta":
 				case "anthropic/claude-3.5-haiku-20241022":
 				case "anthropic/claude-3.5-haiku-20241022":
 				case "anthropic/claude-3.5-haiku-20241022:beta":
 				case "anthropic/claude-3.5-haiku-20241022:beta":
+					modelInfo.maxTokens = 8_192
 					modelInfo.supportsPromptCache = true
 					modelInfo.supportsPromptCache = true
 					modelInfo.cacheWritesPrice = 1.25
 					modelInfo.cacheWritesPrice = 1.25
 					modelInfo.cacheReadsPrice = 0.1
 					modelInfo.cacheReadsPrice = 0.1
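
In summary, the openrouter.ts changes swap the per-model max-token switch for two prefix checks and pin temperature to 1.0 whenever the selected model is flagged as thinking-capable (Anthropic's extended thinking requires temperature 1). A minimal sketch of the resulting parameter selection, with ModelInfo reduced to just the fields used here and a placeholder standing in for OPENROUTER_DEFAULT_TEMPERATURE:

```typescript
// Sketch only: simplified ModelInfo, placeholder default temperature.
interface ModelInfo {
	maxTokens?: number
	thinking?: boolean
}

const DEFAULT_TEMPERATURE = 0 // stands in for OPENROUTER_DEFAULT_TEMPERATURE

function selectCompletionParams(modelId: string, modelInfo: ModelInfo, userTemperature?: number) {
	// Claude 3.5 variants default to 8k output tokens, 3.7 variants to 16k.
	let maxTokens: number | undefined = undefined

	if (modelId.startsWith("anthropic/claude-3.5")) {
		maxTokens = modelInfo.maxTokens ?? 8_192
	}

	if (modelId.startsWith("anthropic/claude-3.7")) {
		maxTokens = modelInfo.maxTokens ?? 16_384
	}

	// Thinking models override the user's temperature with 1.0.
	const temperature = modelInfo.thinking ? 1.0 : (userTemperature ?? DEFAULT_TEMPERATURE)

	return { maxTokens, temperature }
}

// selectCompletionParams("anthropic/claude-3.7-sonnet:thinking", { thinking: true })
// -> { maxTokens: 16_384, temperature: 1 }
```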

+ 8 - 0
src/shared/api.ts

@@ -89,6 +89,13 @@ export interface ModelInfo {
 	cacheReadsPrice?: number
 	description?: string
 	reasoningEffort?: "low" | "medium" | "high"
+	thinking?: boolean
+}
+
+export const THINKING_BUDGET = {
+	step: 1024,
+	min: 1024,
+	default: 8 * 1024,
 }

 // Anthropic
@@ -106,6 +113,7 @@ export const anthropicModels = {
 		outputPrice: 15.0, // $15 per million output tokens
 		cacheWritesPrice: 3.75, // $3.75 per million tokens
 		cacheReadsPrice: 0.3, // $0.30 per million tokens
+		thinking: true,
 	},
 	"claude-3-5-sonnet-20241022": {
 		maxTokens: 8192,
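
THINKING_BUDGET centralizes the slider bounds that were previously hard-coded in the settings UI. As a hypothetical illustration (clampThinkingBudget is not part of this commit), a budget chosen elsewhere could be snapped to the same bounds like so:

```typescript
// Mirrors THINKING_BUDGET from src/shared/api.ts.
const THINKING_BUDGET = { step: 1024, min: 1024, default: 8 * 1024 }

// Hypothetical helper: snap a requested budget to the slider's step and
// keep it within [min, modelMaxTokens - 1], matching the UI's bounds.
function clampThinkingBudget(requested: number, modelMaxTokens: number): number {
	const stepped = Math.round(requested / THINKING_BUDGET.step) * THINKING_BUDGET.step
	return Math.max(THINKING_BUDGET.min, Math.min(stepped, modelMaxTokens - 1))
}

// clampThinkingBudget(10_000, 16_384) === 10_240
```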

+ 15 - 6
webview-ui/src/components/settings/ApiOptions.tsx

@@ -33,6 +33,7 @@ import {
 	unboundDefaultModelInfo,
 	requestyDefaultModelId,
 	requestyDefaultModelInfo,
+	THINKING_BUDGET,
 } from "../../../../src/shared/api"
 import { ExtensionMessage } from "../../../../src/shared/ExtensionMessage"

@@ -1270,12 +1271,20 @@ const ApiOptions = ({
 					</>
 				)}

-			{selectedProvider === "anthropic" && selectedModelId === "claude-3-7-sonnet-20250219" && (
+			{selectedModelInfo && selectedModelInfo.thinking && (
 				<div className="flex flex-col gap-2 mt-2">
 					<Checkbox
 						checked={!!anthropicThinkingBudget}
 						onChange={(checked) =>
-							setApiConfigurationField("anthropicThinking", checked ? 16_384 : undefined)
+							setApiConfigurationField(
+								"anthropicThinking",
+								checked
+									? Math.min(
+											THINKING_BUDGET.default,
+											selectedModelInfo.maxTokens ?? THINKING_BUDGET.default,
+										)
+									: undefined,
+							)
 						}>
 						Thinking?
 					</Checkbox>
@@ -1286,13 +1295,13 @@ const ApiOptions = ({
 							</div>
 							<div className="flex items-center gap-2">
 								<Slider
-									min={1024}
-									max={anthropicModels["claude-3-7-sonnet-20250219"].maxTokens - 1}
-									step={1024}
+									min={THINKING_BUDGET.min}
+									max={(selectedModelInfo.maxTokens ?? THINKING_BUDGET.default) - 1}
+									step={THINKING_BUDGET.step}
 									value={[anthropicThinkingBudget]}
 									onValueChange={(value) => setApiConfigurationField("anthropicThinking", value[0])}
 								/>
-								<div className="w-10">{anthropicThinkingBudget}</div>
+								<div className="w-12">{anthropicThinkingBudget}</div>
 							</div>
 						</>
 					)}
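
With these changes the checkbox and slider derive their values from THINKING_BUDGET and the selected model's maxTokens instead of hard-coding the claude-3-7-sonnet-20250219 limits, so any model flagged thinking: true gets the controls. The default budget on enabling the checkbox reduces to this (a standalone restatement of the expression in the diff):

```typescript
const THINKING_BUDGET = { step: 1024, min: 1024, default: 8 * 1024 }

// Default budget when "Thinking?" is checked: 8k tokens, capped at the
// model's max output tokens when that is smaller.
function defaultThinkingBudget(modelMaxTokens?: number): number {
	return Math.min(THINKING_BUDGET.default, modelMaxTokens ?? THINKING_BUDGET.default)
}

// defaultThinkingBudget(16_384) === 8_192
// defaultThinkingBudget(4_096) === 4_096
```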

+ 2 - 2
webview-ui/src/components/ui/slider.tsx

@@ -11,8 +11,8 @@ const Slider = React.forwardRef<
 		ref={ref}
 		className={cn("relative flex w-full touch-none select-none items-center", className)}
 		{...props}>
-		<SliderPrimitive.Track className="relative h-1 w-full grow overflow-hidden bg-primary/20">
-			<SliderPrimitive.Range className="absolute h-full bg-primary" />
+		<SliderPrimitive.Track className="relative w-full h-[8px] grow overflow-hidden bg-vscode-button-secondaryBackground border border-[#767676] dark:border-[#858585] rounded-sm">
+			<SliderPrimitive.Range className="absolute h-full bg-vscode-button-background" />
 		</SliderPrimitive.Track>
 		<SliderPrimitive.Thumb className="block h-3 w-3 rounded-full border border-primary/50 bg-primary shadow transition-colors cursor-pointer focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50" />
 	</SliderPrimitive.Root>