
Merge pull request #1195 from RooVetGit/cte/openrouter-claude-thinking

Support Claude 3.7 Sonnet "Thinking" in OpenRouter
Chris Estreich 10 months ago
Parent
Commit
06f98ca13a

+ 34 - 17
src/api/providers/openrouter.ts

@@ -52,6 +52,8 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			...convertToOpenAiMessages(messages),
 		]
 
+		const { id: modelId, info: modelInfo } = this.getModel()
+
 		// prompt caching: https://openrouter.ai/docs/prompt-caching
 		// this is specifically for claude models (some models may 'support prompt caching' automatically without this)
 		switch (true) {
@@ -95,10 +97,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 		let topP: number | undefined = undefined
 
 		// Handle models based on deepseek-r1
-		if (
-			this.getModel().id.startsWith("deepseek/deepseek-r1") ||
-			this.getModel().id === "perplexity/sonar-reasoning"
-		) {
+		if (modelId.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning") {
 			// Recommended temperature for DeepSeek reasoning models
 			defaultTemperature = DEEP_SEEK_DEFAULT_TEMPERATURE
 			// DeepSeek highly recommends using user instead of system role
@@ -107,24 +106,34 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			topP = 0.95
 		}
 
+		let temperature = this.options.modelTemperature ?? defaultTemperature
+
+		// Anthropic "Thinking" models require a temperature of 1.0.
+		if (modelInfo.thinking) {
+			temperature = 1.0
+		}
+
 		// https://openrouter.ai/docs/transforms
 		let fullResponseText = ""
-		const stream = await this.client.chat.completions.create({
-			model: this.getModel().id,
-			max_tokens: this.getModel().info.maxTokens,
-			temperature: this.options.modelTemperature ?? defaultTemperature,
+
+		const completionParams: OpenRouterChatCompletionParams = {
+			model: modelId,
+			max_tokens: modelInfo.maxTokens,
+			temperature,
 			top_p: topP,
 			messages: openAiMessages,
 			stream: true,
 			include_reasoning: true,
 			// This way, the transforms field will only be included in the parameters when openRouterUseMiddleOutTransform is true.
 			...(this.options.openRouterUseMiddleOutTransform && { transforms: ["middle-out"] }),
-		} as OpenRouterChatCompletionParams)
+		}
+
+		const stream = await this.client.chat.completions.create(completionParams)
 
 		let genId: string | undefined
 
 		for await (const chunk of stream as unknown as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>) {
-			// openrouter returns an error object instead of the openai sdk throwing an error
+			// OpenRouter returns an error object instead of the OpenAI SDK throwing an error.
 			if ("error" in chunk) {
 				const error = chunk.error as { message?: string; code?: number }
 				console.error(`OpenRouter API Error: ${error?.code} - ${error?.message}`)
@@ -136,12 +145,14 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			}
 
 			const delta = chunk.choices[0]?.delta
+
 			if ("reasoning" in delta && delta.reasoning) {
 				yield {
 					type: "reasoning",
 					text: delta.reasoning,
 				} as ApiStreamChunk
 			}
+
 			if (delta?.content) {
 				fullResponseText += delta.content
 				yield {
@@ -149,6 +160,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 					text: delta.content,
 				} as ApiStreamChunk
 			}
+
 			// if (chunk.usage) {
 			// 	yield {
 			// 		type: "usage",
@@ -158,10 +170,12 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			// }
 		}
 
-		// retry fetching generation details
+		// Retry fetching generation details.
 		let attempt = 0
+
 		while (attempt++ < 10) {
 			await delay(200) // FIXME: necessary delay to ensure generation endpoint is ready
+
 			try {
 				const response = await axios.get(`https://openrouter.ai/api/v1/generation?id=${genId}`, {
 					headers: {
@@ -171,7 +185,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 				})
 
 				const generation = response.data?.data
-				console.log("OpenRouter generation details:", response.data)
+
 				yield {
 					type: "usage",
 					// cacheWriteTokens: 0,
@@ -182,6 +196,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 					totalCost: generation?.total_cost || 0,
 					fullResponseText,
 				} as OpenRouterApiStreamUsageChunk
+
 				return
 			} catch (error) {
 				// ignore if fails
@@ -189,13 +204,13 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			}
 		}
 	}
-	getModel(): { id: string; info: ModelInfo } {
+
+	getModel() {
 		const modelId = this.options.openRouterModelId
 		const modelInfo = this.options.openRouterModelInfo
-		if (modelId && modelInfo) {
-			return { id: modelId, info: modelInfo }
-		}
-		return { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
+		return modelId && modelInfo
+			? { id: modelId, info: modelInfo }
+			: { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
@@ -218,6 +233,7 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 			if (error instanceof Error) {
 				throw new Error(`OpenRouter completion error: ${error.message}`)
 			}
+
 			throw error
 		}
 	}
@@ -239,6 +255,7 @@ export async function getOpenRouterModels() {
 				inputPrice: parseApiPrice(rawModel.pricing?.prompt),
 				outputPrice: parseApiPrice(rawModel.pricing?.completion),
 				description: rawModel.description,
+				thinking: rawModel.id === "anthropic/claude-3.7-sonnet:thinking",
 			}
 
 			// NOTE: this needs to be synced with api.ts/openrouter default model info.
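
To make the new behavior easier to follow, here is a minimal standalone sketch of the temperature rule this diff introduces. The name `selectTemperature`, the 0 fallback default, and the 0.6 DeepSeek value are illustrative assumptions; in the handler the values come from the surrounding code and the `DEEP_SEEK_DEFAULT_TEMPERATURE` constant.

```typescript
// Sketch only: mirrors the temperature logic in createMessage above.
function selectTemperature(
	modelId: string,
	modelInfo: { thinking?: boolean },
	userTemperature?: number,
): number {
	// Assumed values for illustration; the real handler imports these.
	const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6
	const DEFAULT_TEMPERATURE = 0

	const isReasoningModel =
		modelId.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning"

	// Anthropic "Thinking" models require a temperature of exactly 1.0,
	// overriding any user-configured value.
	if (modelInfo.thinking) {
		return 1.0
	}

	return userTemperature ?? (isReasoningModel ? DEEP_SEEK_DEFAULT_TEMPERATURE : DEFAULT_TEMPERATURE)
}

// selectTemperature("anthropic/claude-3.7-sonnet:thinking", { thinking: true }, 0.2) === 1.0
// selectTemperature("deepseek/deepseek-r1", {}) === 0.6
```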

+ 8 - 0
src/shared/api.ts

@@ -89,6 +89,13 @@ export interface ModelInfo {
 	cacheReadsPrice?: number
 	description?: string
 	reasoningEffort?: "low" | "medium" | "high"
+	thinking?: boolean
+}
+
+export const THINKING_BUDGET = {
+	step: 1024,
+	min: 1024,
+	default: 8 * 1024,
 }
 
 // Anthropic
@@ -106,6 +113,7 @@ export const anthropicModels = {
 		outputPrice: 15.0, // $15 per million output tokens
 		cacheWritesPrice: 3.75, // $3.75 per million tokens
 		cacheReadsPrice: 0.3, // $0.30 per million tokens
+		thinking: true,
 	},
 	"claude-3-5-sonnet-20241022": {
 		maxTokens: 8192,
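
`THINKING_BUDGET` centralizes the slider constraints that the settings UI below now reads. A hypothetical helper (not part of this PR) makes the intended semantics explicit: a budget is a multiple of `step`, at least `min`, and strictly below the model's `max_tokens`.

```typescript
import { THINKING_BUDGET } from "./api" // adjust the path to the caller's location

// Hypothetical helper: clamp a requested budget to the slider's range
// and snap it to the slider's step. Not part of this PR.
function normalizeThinkingBudget(requested: number, modelMaxTokens?: number): number {
	const max = (modelMaxTokens ?? THINKING_BUDGET.default) - 1
	const snapped = Math.round(requested / THINKING_BUDGET.step) * THINKING_BUDGET.step
	return Math.max(THINKING_BUDGET.min, Math.min(snapped, max))
}

// normalizeThinkingBudget(10_000, 16_384) === 10_240
// normalizeThinkingBudget(50_000, 16_384) === 16_383
```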

+ 15 - 6
webview-ui/src/components/settings/ApiOptions.tsx

@@ -33,6 +33,7 @@ import {
 	unboundDefaultModelInfo,
 	requestyDefaultModelId,
 	requestyDefaultModelInfo,
+	THINKING_BUDGET,
 } from "../../../../src/shared/api"
 import { ExtensionMessage } from "../../../../src/shared/ExtensionMessage"
 
@@ -1270,12 +1271,20 @@ const ApiOptions = ({
 					</>
 				)}
 
-			{selectedProvider === "anthropic" && selectedModelId === "claude-3-7-sonnet-20250219" && (
+			{selectedModelInfo && selectedModelInfo.thinking && (
 				<div className="flex flex-col gap-2 mt-2">
 					<Checkbox
 						checked={!!anthropicThinkingBudget}
 						onChange={(checked) =>
-							setApiConfigurationField("anthropicThinking", checked ? 16_384 : undefined)
+							setApiConfigurationField(
+								"anthropicThinking",
+								checked
+									? Math.min(
+											THINKING_BUDGET.default,
+											selectedModelInfo.maxTokens ?? THINKING_BUDGET.default,
+										)
+									: undefined,
+							)
 						}>
 						Thinking?
 					</Checkbox>
@@ -1286,13 +1295,13 @@ const ApiOptions = ({
 							</div>
 							<div className="flex items-center gap-2">
 								<Slider
-									min={1024}
-									max={anthropicModels["claude-3-7-sonnet-20250219"].maxTokens - 1}
-									step={1024}
+									min={THINKING_BUDGET.min}
+									max={(selectedModelInfo.maxTokens ?? THINKING_BUDGET.default) - 1}
+									step={THINKING_BUDGET.step}
 									value={[anthropicThinkingBudget]}
 									onValueChange={(value) => setApiConfigurationField("anthropicThinking", value[0])}
 								/>
-								<div className="w-10">{anthropicThinkingBudget}</div>
+								<div className="w-12">{anthropicThinkingBudget}</div>
 							</div>
 						</>
 					)}
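
When the checkbox above is ticked, the budget defaults to `THINKING_BUDGET.default` but never above the model's own output ceiling. Extracted as a hypothetical pure function (again, not in the PR), the rule is:

```typescript
import { THINKING_BUDGET } from "../../../../src/shared/api"

// Hypothetical extraction of the checkbox's defaulting rule above.
function defaultThinkingBudget(modelInfo: { maxTokens?: number }): number {
	return Math.min(THINKING_BUDGET.default, modelInfo.maxTokens ?? THINKING_BUDGET.default)
}

// defaultThinkingBudget({ maxTokens: 128_000 }) === 8192
// defaultThinkingBudget({ maxTokens: 4096 }) === 4096
```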

+ 2 - 2
webview-ui/src/components/ui/slider.tsx

@@ -11,8 +11,8 @@ const Slider = React.forwardRef<
 		ref={ref}
 		className={cn("relative flex w-full touch-none select-none items-center", className)}
 		{...props}>
-		<SliderPrimitive.Track className="relative h-1 w-full grow overflow-hidden bg-primary/20">
-			<SliderPrimitive.Range className="absolute h-full bg-primary" />
+		<SliderPrimitive.Track className="relative w-full h-[8px] grow overflow-hidden bg-vscode-button-secondaryBackground border border-[#767676] dark:border-[#858585] rounded-sm">
+			<SliderPrimitive.Range className="absolute h-full bg-vscode-button-background" />
 		</SliderPrimitive.Track>
 		<SliderPrimitive.Thumb className="block h-3 w-3 rounded-full border border-primary/50 bg-primary shadow transition-colors cursor-pointer focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring disabled:pointer-events-none disabled:opacity-50" />
 	</SliderPrimitive.Root>