
Merge pull request #1305 from RooVetGit/cte/max-tokens-fix

Custom max tokens fix for non-thinking models
Chris Estreich · 10 months ago
commit 15a4c4c68c

.changeset/fuzzy-donkeys-whisper.md (+5 -0)

@@ -0,0 +1,5 @@
+---
+"roo-cline": patch
+---
+
+Don't honor custom max tokens for non thinking models

src/api/providers/__tests__/anthropic.test.ts (+28 -0)

@@ -194,5 +194,33 @@ describe("AnthropicHandler", () => {
 			expect(model.info.supportsImages).toBe(true)
 			expect(model.info.supportsPromptCache).toBe(true)
 		})
+
+		it("honors custom maxTokens for thinking models", () => {
+			const handler = new AnthropicHandler({
+				apiKey: "test-api-key",
+				apiModelId: "claude-3-7-sonnet-20250219:thinking",
+				modelMaxTokens: 32_768,
+				modelMaxThinkingTokens: 16_384,
+			})
+
+			const result = handler.getModel()
+			expect(result.maxTokens).toBe(32_768)
+			expect(result.thinking).toEqual({ type: "enabled", budget_tokens: 16_384 })
+			expect(result.temperature).toBe(1.0)
+		})
+
+		it("does not honor custom maxTokens for non-thinking models", () => {
+			const handler = new AnthropicHandler({
+				apiKey: "test-api-key",
+				apiModelId: "claude-3-7-sonnet-20250219",
+				modelMaxTokens: 32_768,
+				modelMaxThinkingTokens: 16_384,
+			})
+
+			const result = handler.getModel()
+			expect(result.maxTokens).toBe(16_384)
+			expect(result.thinking).toBeUndefined()
+			expect(result.temperature).toBe(0)
+		})
 	})
 })
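
A quick check of the expected values: the clamp used by the providers caps the thinking budget at 80% of max tokens and floors it at 1024, and the non-thinking expectation implies a built-in model default of 16_384 max tokens. A worked version of the arithmetic, for reference only:

// Thinking variant: the custom max of 32_768 is honored, and the custom budget
// of 16_384 fits under the 80% cap, so it passes through unclamped.
Math.floor(32_768 * 0.8)                  // 26_214
Math.max(Math.min(16_384, 26_214), 1024)  // 16_384 -> budget_tokens

// Non-thinking variant: the custom 32_768 is ignored, maxTokens falls back to the
// model's built-in default (16_384 per the expectation above), and `thinking` stays undefined.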

src/api/providers/__tests__/openrouter.test.ts (+55 -25)

@@ -1,29 +1,30 @@
 // npx jest src/api/providers/__tests__/openrouter.test.ts
 
-import { OpenRouterHandler } from "../openrouter"
-import { ApiHandlerOptions, ModelInfo } from "../../../shared/api"
-import OpenAI from "openai"
 import axios from "axios"
 import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
+
+import { OpenRouterHandler } from "../openrouter"
+import { ApiHandlerOptions, ModelInfo } from "../../../shared/api"
 
 // Mock dependencies
 jest.mock("openai")
 jest.mock("axios")
 jest.mock("delay", () => jest.fn(() => Promise.resolve()))
 
+const mockOpenRouterModelInfo: ModelInfo = {
+	maxTokens: 1000,
+	contextWindow: 2000,
+	supportsPromptCache: true,
+	inputPrice: 0.01,
+	outputPrice: 0.02,
+}
+
 describe("OpenRouterHandler", () => {
 	const mockOptions: ApiHandlerOptions = {
 		openRouterApiKey: "test-key",
 		openRouterModelId: "test-model",
-		openRouterModelInfo: {
-			name: "Test Model",
-			description: "Test Description",
-			maxTokens: 1000,
-			contextWindow: 2000,
-			supportsPromptCache: true,
-			inputPrice: 0.01,
-			outputPrice: 0.02,
-		} as ModelInfo,
+		openRouterModelInfo: mockOpenRouterModelInfo,
 	}
 
 	beforeEach(() => {
@@ -50,6 +51,10 @@ describe("OpenRouterHandler", () => {
 		expect(result).toEqual({
 			id: mockOptions.openRouterModelId,
 			info: mockOptions.openRouterModelInfo,
+			maxTokens: 1000,
+			temperature: 0,
+			thinking: undefined,
+			topP: undefined,
 		})
 	})
 
@@ -61,6 +66,38 @@ describe("OpenRouterHandler", () => {
 		expect(result.info.supportsPromptCache).toBe(true)
 	})
 
+	test("getModel honors custom maxTokens for thinking models", () => {
+		const handler = new OpenRouterHandler({
+			openRouterApiKey: "test-key",
+			openRouterModelId: "test-model",
+			openRouterModelInfo: {
+				...mockOpenRouterModelInfo,
+				maxTokens: 64_000,
+				thinking: true,
+			},
+			modelMaxTokens: 32_768,
+			modelMaxThinkingTokens: 16_384,
+		})
+
+		const result = handler.getModel()
+		expect(result.maxTokens).toBe(32_768)
+		expect(result.thinking).toEqual({ type: "enabled", budget_tokens: 16_384 })
+		expect(result.temperature).toBe(1.0)
+	})
+
+	test("getModel does not honor custom maxTokens for non-thinking models", () => {
+		const handler = new OpenRouterHandler({
+			...mockOptions,
+			modelMaxTokens: 32_768,
+			modelMaxThinkingTokens: 16_384,
+		})
+
+		const result = handler.getModel()
+		expect(result.maxTokens).toBe(1000)
+		expect(result.thinking).toBeUndefined()
+		expect(result.temperature).toBe(0)
+	})
+
 	test("createMessage generates correct stream chunks", async () => {
 		const handler = new OpenRouterHandler(mockOptions)
 		const mockStream = {
@@ -242,15 +279,7 @@ describe("OpenRouterHandler", () => {
 
 	test("completePrompt returns correct response", async () => {
 		const handler = new OpenRouterHandler(mockOptions)
-		const mockResponse = {
-			choices: [
-				{
-					message: {
-						content: "test completion",
-					},
-				},
-			],
-		}
+		const mockResponse = { choices: [{ message: { content: "test completion" } }] }
 
 		const mockCreate = jest.fn().mockResolvedValue(mockResponse)
 		;(OpenAI as jest.MockedClass<typeof OpenAI>).prototype.chat = {
@@ -260,10 +289,13 @@ describe("OpenRouterHandler", () => {
 		const result = await handler.completePrompt("test prompt")
 
 		expect(result).toBe("test completion")
+
 		expect(mockCreate).toHaveBeenCalledWith({
 			model: mockOptions.openRouterModelId,
-			messages: [{ role: "user", content: "test prompt" }],
+			max_tokens: 1000,
+			thinking: undefined,
 			temperature: 0,
+			messages: [{ role: "user", content: "test prompt" }],
 			stream: false,
 		})
 	})
@@ -292,8 +324,6 @@ describe("OpenRouterHandler", () => {
 			completions: { create: mockCreate },
 		} as any
 
-		await expect(handler.completePrompt("test prompt")).rejects.toThrow(
-			"OpenRouter completion error: Unexpected error",
-		)
+		await expect(handler.completePrompt("test prompt")).rejects.toThrow("Unexpected error")
 	})
 })

src/api/providers/__tests__/vertex.test.ts (+28 -0)

@@ -890,6 +890,34 @@ describe("VertexHandler", () => {
 			expect(modelInfo.info.maxTokens).toBe(8192)
 			expect(modelInfo.info.contextWindow).toBe(1048576)
 		})
+
+		it("honors custom maxTokens for thinking models", () => {
+			const handler = new VertexHandler({
+				apiKey: "test-api-key",
+				apiModelId: "claude-3-7-sonnet@20250219:thinking",
+				modelMaxTokens: 32_768,
+				modelMaxThinkingTokens: 16_384,
+			})
+
+			const result = handler.getModel()
+			expect(result.maxTokens).toBe(32_768)
+			expect(result.thinking).toEqual({ type: "enabled", budget_tokens: 16_384 })
+			expect(result.temperature).toBe(1.0)
+		})
+
+		it("does not honor custom maxTokens for non-thinking models", () => {
+			const handler = new VertexHandler({
+				apiKey: "test-api-key",
+				apiModelId: "claude-3-7-sonnet@20250219",
+				modelMaxTokens: 32_768,
+				modelMaxThinkingTokens: 16_384,
+			})
+
+			const result = handler.getModel()
+			expect(result.maxTokens).toBe(16_384)
+			expect(result.thinking).toBeUndefined()
+			expect(result.temperature).toBe(0)
+		})
 	})
 
 	describe("thinking model configuration", () => {

src/api/providers/anthropic.ts (+30 -35)

@@ -12,8 +12,6 @@ import {
 import { ApiHandler, SingleCompletionHandler } from "../index"
 import { ApiStream } from "../transform/stream"
 
-const ANTHROPIC_DEFAULT_TEMPERATURE = 0
-
 export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: Anthropic
@@ -30,7 +28,7 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		let stream: AnthropicStream<Anthropic.Messages.RawMessageStreamEvent>
 		const cacheControl: CacheControlEphemeral = { type: "ephemeral" }
-		let { id: modelId, temperature, maxTokens, thinking } = this.getModel()
+		let { id: modelId, maxTokens, thinking, temperature } = this.getModel()
 
 		switch (modelId) {
 			case "claude-3-7-sonnet-20250219":
@@ -182,55 +180,52 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 
 	getModel() {
 		const modelId = this.options.apiModelId
-		let temperature = this.options.modelTemperature ?? ANTHROPIC_DEFAULT_TEMPERATURE
-		let thinking: BetaThinkingConfigParam | undefined = undefined
 
-		if (modelId && modelId in anthropicModels) {
-			let id = modelId as AnthropicModelId
-			const info: ModelInfo = anthropicModels[id]
+		const {
+			modelMaxTokens: customMaxTokens,
+			modelMaxThinkingTokens: customMaxThinkingTokens,
+			modelTemperature: customTemperature,
+		} = this.options
 
-			// The `:thinking` variant is a virtual identifier for the
-			// `claude-3-7-sonnet-20250219` model with a thinking budget.
-			// We can handle this more elegantly in the future.
-			if (id === "claude-3-7-sonnet-20250219:thinking") {
-				id = "claude-3-7-sonnet-20250219"
-			}
+		let id = modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId
+		const info: ModelInfo = anthropicModels[id]
 
-			const maxTokens = this.options.modelMaxTokens || info.maxTokens || 8192
+		// The `:thinking` variant is a virtual identifier for the
+		// `claude-3-7-sonnet-20250219` model with a thinking budget.
+		// We can handle this more elegantly in the future.
+		if (id === "claude-3-7-sonnet-20250219:thinking") {
+			id = "claude-3-7-sonnet-20250219"
+		}
 
-			if (info.thinking) {
-				// Anthropic "Thinking" models require a temperature of 1.0.
-				temperature = 1.0
+		let maxTokens = info.maxTokens ?? 8192
+		let thinking: BetaThinkingConfigParam | undefined = undefined
+		let temperature = customTemperature ?? 0
 
-				// Clamp the thinking budget to be at most 80% of max tokens and at
-				// least 1024 tokens.
-				const maxBudgetTokens = Math.floor(maxTokens * 0.8)
-				const budgetTokens = Math.max(
-					Math.min(this.options.modelMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens),
-					1024,
-				)
+		if (info.thinking) {
+			// Only honor `customMaxTokens` for thinking models.
+			maxTokens = customMaxTokens ?? maxTokens
 
-				thinking = { type: "enabled", budget_tokens: budgetTokens }
-			}
+			// Clamp the thinking budget to be at most 80% of max tokens and at
+			// least 1024 tokens.
+			const maxBudgetTokens = Math.floor(maxTokens * 0.8)
+			const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
+			thinking = { type: "enabled", budget_tokens: budgetTokens }
 
-			return { id, info, temperature, maxTokens, thinking }
+			// Anthropic "Thinking" models require a temperature of 1.0.
+			temperature = 1.0
 		}
 
-		const id = anthropicDefaultModelId
-		const info: ModelInfo = anthropicModels[id]
-		const maxTokens = this.options.modelMaxTokens || info.maxTokens || 8192
-
-		return { id, info, temperature, maxTokens, thinking }
+		return { id, info, maxTokens, thinking, temperature }
 	}
 
 	async completePrompt(prompt: string) {
-		let { id: modelId, temperature, maxTokens, thinking } = this.getModel()
+		let { id: modelId, maxTokens, thinking, temperature } = this.getModel()
 
 		const message = await this.client.messages.create({
 			model: modelId,
 			max_tokens: maxTokens,
-			temperature,
 			thinking,
+			temperature,
 			messages: [{ role: "user", content: prompt }],
 			stream: false,
 		})
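
All three providers now inline the same budget rule inside getModel(). A minimal standalone sketch of that rule, using a hypothetical helper name (clampThinkingBudget is not part of this change):

// Clamp the thinking budget to at most 80% of max tokens and at least 1024 tokens.
function clampThinkingBudget(maxTokens: number, customMaxThinkingTokens?: number) {
	const maxBudgetTokens = Math.floor(maxTokens * 0.8)
	const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
	return { type: "enabled" as const, budget_tokens: budgetTokens }
}

// Consistent with the test values above:
clampThinkingBudget(32_768, 16_384) // => { type: "enabled", budget_tokens: 16_384 }

The same thinking branch is also where a custom modelMaxTokens is honored and temperature is forced to 1.0, which is why the thinking tests expect 1.0 while the non-thinking tests expect the default of 0.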

src/api/providers/openrouter.ts (+64 -56)

@@ -48,13 +48,18 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): AsyncGenerator<ApiStreamChunk> {
-		// Convert Anthropic messages to OpenAI format
+		let { id: modelId, maxTokens, thinking, temperature, topP } = this.getModel()
+
+		// Convert Anthropic messages to OpenAI format.
 		let openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
 			{ role: "system", content: systemPrompt },
 			...convertToOpenAiMessages(messages),
 		]
 
-		const { id: modelId, info: modelInfo } = this.getModel()
+		// DeepSeek highly recommends using user instead of system role.
+		if (modelId.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning") {
+			openAiMessages = convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
+		}
 
 		// prompt caching: https://openrouter.ai/docs/prompt-caching
 		// this is specifically for claude models (some models may 'support prompt caching' automatically without this)
@@ -95,42 +100,12 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 				break
 		}
 
-		let defaultTemperature = OPENROUTER_DEFAULT_TEMPERATURE
-		let topP: number | undefined = undefined
-
-		// Handle models based on deepseek-r1
-		if (modelId.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning") {
-			// Recommended temperature for DeepSeek reasoning models
-			defaultTemperature = DEEP_SEEK_DEFAULT_TEMPERATURE
-			// DeepSeek highly recommends using user instead of system role
-			openAiMessages = convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
-			// Some provider support topP and 0.95 is value that Deepseek used in their benchmarks
-			topP = 0.95
-		}
-
-		const maxTokens = this.options.modelMaxTokens || modelInfo.maxTokens
-		let temperature = this.options.modelTemperature ?? defaultTemperature
-		let thinking: BetaThinkingConfigParam | undefined = undefined
-
-		if (modelInfo.thinking) {
-			// Clamp the thinking budget to be at most 80% of max tokens and at
-			// least 1024 tokens.
-			const maxBudgetTokens = Math.floor((maxTokens || 8192) * 0.8)
-			const budgetTokens = Math.max(
-				Math.min(this.options.modelMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens),
-				1024,
-			)
-
-			thinking = { type: "enabled", budget_tokens: budgetTokens }
-			temperature = 1.0
-		}
-
 		// https://openrouter.ai/docs/transforms
 		let fullResponseText = ""
 
 		const completionParams: OpenRouterChatCompletionParams = {
 			model: modelId,
-			max_tokens: modelInfo.maxTokens,
+			max_tokens: maxTokens,
 			temperature,
 			thinking, // OpenRouter is temporarily supporting this.
 			top_p: topP,
@@ -221,34 +196,67 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 	getModel() {
 		const modelId = this.options.openRouterModelId
 		const modelInfo = this.options.openRouterModelInfo
-		return modelId && modelInfo
-			? { id: modelId, info: modelInfo }
-			: { id: openRouterDefaultModelId, info: openRouterDefaultModelInfo }
+
+		let id = modelId ?? openRouterDefaultModelId
+		const info = modelInfo ?? openRouterDefaultModelInfo
+
+		const {
+			modelMaxTokens: customMaxTokens,
+			modelMaxThinkingTokens: customMaxThinkingTokens,
+			modelTemperature: customTemperature,
+		} = this.options
+
+		let maxTokens = info.maxTokens
+		let thinking: BetaThinkingConfigParam | undefined = undefined
+		let temperature = customTemperature ?? OPENROUTER_DEFAULT_TEMPERATURE
+		let topP: number | undefined = undefined
+
+		// Handle models based on deepseek-r1
+		if (id.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning") {
+			// Recommended temperature for DeepSeek reasoning models.
+			temperature = customTemperature ?? DEEP_SEEK_DEFAULT_TEMPERATURE
+			// Some provider support topP and 0.95 is value that Deepseek used in their benchmarks.
+			topP = 0.95
+		}
+
+		if (info.thinking) {
+			// Only honor `customMaxTokens` for thinking models.
+			maxTokens = customMaxTokens ?? maxTokens
+
+			// Clamp the thinking budget to be at most 80% of max tokens and at
+			// least 1024 tokens.
+			const maxBudgetTokens = Math.floor((maxTokens || 8192) * 0.8)
+			const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
+			thinking = { type: "enabled", budget_tokens: budgetTokens }
+
+			// Anthropic "Thinking" models require a temperature of 1.0.
+			temperature = 1.0
+		}
+
+		return { id, info, maxTokens, thinking, temperature, topP }
 	}
 
-	async completePrompt(prompt: string): Promise<string> {
-		try {
-			const response = await this.client.chat.completions.create({
-				model: this.getModel().id,
-				messages: [{ role: "user", content: prompt }],
-				temperature: this.options.modelTemperature ?? OPENROUTER_DEFAULT_TEMPERATURE,
-				stream: false,
-			})
-
-			if ("error" in response) {
-				const error = response.error as { message?: string; code?: number }
-				throw new Error(`OpenRouter API Error ${error?.code}: ${error?.message}`)
-			}
+	async completePrompt(prompt: string) {
+		let { id: modelId, maxTokens, thinking, temperature } = this.getModel()
 
-			const completion = response as OpenAI.Chat.ChatCompletion
-			return completion.choices[0]?.message?.content || ""
-		} catch (error) {
-			if (error instanceof Error) {
-				throw new Error(`OpenRouter completion error: ${error.message}`)
-			}
+		const completionParams: OpenRouterChatCompletionParams = {
+			model: modelId,
+			max_tokens: maxTokens,
+			thinking,
+			temperature,
+			messages: [{ role: "user", content: prompt }],
+			stream: false,
+		}
+
+		const response = await this.client.chat.completions.create(completionParams)
 
-			throw error
+		if ("error" in response) {
+			const error = response.error as { message?: string; code?: number }
+			throw new Error(`OpenRouter API Error ${error?.code}: ${error?.message}`)
 		}
+
+		const completion = response as OpenAI.Chat.ChatCompletion
+		return completion.choices[0]?.message?.content || ""
 	}
 }
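
With parameter resolution moved into getModel(), completePrompt now forwards the same values it returns. A rough sketch of the request body this implies, using the thinking-model values from the tests above (illustrative only, not part of the diff):

// Assumed inputs: openRouterModelInfo.thinking = true, modelMaxTokens = 32_768,
// modelMaxThinkingTokens = 16_384, no custom temperature.
const completionParams = {
	model: "test-model",
	max_tokens: 32_768,                                   // custom max, honored because info.thinking is set
	thinking: { type: "enabled", budget_tokens: 16_384 }, // OpenRouter is temporarily supporting this.
	temperature: 1.0,                                     // thinking models require temperature 1.0
	messages: [{ role: "user", content: "test prompt" }],
	stream: false,
}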
 

src/api/providers/vertex.ts (+32 -30)

@@ -202,6 +202,7 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 		}
 
 		const response = await result.response
+
 		yield {
 			type: "usage",
 			inputTokens: response.usageMetadata?.promptTokenCount ?? 0,
@@ -351,43 +352,44 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 		thinking?: BetaThinkingConfigParam
 	} {
 		const modelId = this.options.apiModelId
-		let temperature = this.options.modelTemperature ?? 0
-		let thinking: BetaThinkingConfigParam | undefined = undefined
 
-		if (modelId && modelId in vertexModels) {
-			const id = modelId as VertexModelId
-			const info: ModelInfo = vertexModels[id]
+		const {
+			modelMaxTokens: customMaxTokens,
+			modelMaxThinkingTokens: customMaxThinkingTokens,
+			modelTemperature: customTemperature,
+		} = this.options
 
-			// The `:thinking` variant is a virtual identifier for thinking-enabled models
-			// Similar to how it's handled in the Anthropic provider
-			let actualId = id
-			if (id.endsWith(":thinking")) {
-				actualId = id.replace(":thinking", "") as VertexModelId
-			}
+		let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId
+		const info: ModelInfo = vertexModels[id]
 
-			const maxTokens = this.options.modelMaxTokens || info.maxTokens || 8192
+		// The `:thinking` variant is a virtual identifier for thinking-enabled
+		// models (similar to how it's handled in the Anthropic provider.)
+		if (id.endsWith(":thinking")) {
+			id = id.replace(":thinking", "") as VertexModelId
+		}
 
-			if (info.thinking) {
-				temperature = 1.0 // Thinking requires temperature 1.0
-				const maxBudgetTokens = Math.floor(maxTokens * 0.8)
-				const budgetTokens = Math.max(
-					Math.min(this.options.modelMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens),
-					1024,
-				)
-				thinking = { type: "enabled", budget_tokens: budgetTokens }
-			}
+		let maxTokens = info.maxTokens || 8192
+		let thinking: BetaThinkingConfigParam | undefined = undefined
+		let temperature = customTemperature ?? 0
 
-			return { id: actualId, info, temperature, maxTokens, thinking }
-		}
+		if (info.thinking) {
+			// Only honor `customMaxTokens` for thinking models.
+			maxTokens = customMaxTokens ?? maxTokens
 
-		const id = vertexDefaultModelId
-		const info = vertexModels[id]
-		const maxTokens = this.options.modelMaxTokens || info.maxTokens || 8192
+			// Clamp the thinking budget to be at most 80% of max tokens and at
+			// least 1024 tokens.
+			const maxBudgetTokens = Math.floor(maxTokens * 0.8)
+			const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
+			thinking = { type: "enabled", budget_tokens: budgetTokens }
+
+			// Anthropic "Thinking" models require a temperature of 1.0.
+			temperature = 1.0
+		}
 
-		return { id, info, temperature, maxTokens, thinking }
+		return { id, info, maxTokens, thinking, temperature }
 	}
 
-	private async completePromptGemini(prompt: string): Promise<string> {
+	private async completePromptGemini(prompt: string) {
 		try {
 			const model = this.geminiClient.getGenerativeModel({
 				model: this.getModel().id,
@@ -416,7 +418,7 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}
 
-	private async completePromptClaude(prompt: string): Promise<string> {
+	private async completePromptClaude(prompt: string) {
 		try {
 			let { id, info, temperature, maxTokens, thinking } = this.getModel()
 			const useCache = info.supportsPromptCache
@@ -461,7 +463,7 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}
 
-	async completePrompt(prompt: string): Promise<string> {
+	async completePrompt(prompt: string) {
 		switch (this.modelType) {
 			case this.MODEL_CLAUDE: {
 				return this.completePromptClaude(prompt)