
Merge pull request #1307 from RooVetGit/cte/dry-get-model

DRY up getModel
Chris Estreich 10 months ago
parent
commit
a2d441c5a0

+ 257 - 0
src/api/__tests__/index.test.ts

@@ -0,0 +1,257 @@
+// npx jest src/api/__tests__/index.test.ts
+
+import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta/messages/index.mjs"
+
+import { getModelParams } from "../index"
+import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "../providers/constants"
+
+describe("getModelParams", () => {
+	it("should return default values when no custom values are provided", () => {
+		const options = {}
+		const model = {
+			id: "test-model",
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+			defaultMaxTokens: 1000,
+			defaultTemperature: 0.5,
+		})
+
+		expect(result).toEqual({
+			maxTokens: 1000,
+			thinking: undefined,
+			temperature: 0.5,
+		})
+	})
+
+	it("should use custom temperature from options when provided", () => {
+		const options = { modelTemperature: 0.7 }
+		const model = {
+			id: "test-model",
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+			defaultMaxTokens: 1000,
+			defaultTemperature: 0.5,
+		})
+
+		expect(result).toEqual({
+			maxTokens: 1000,
+			thinking: undefined,
+			temperature: 0.7,
+		})
+	})
+
+	it("should use model maxTokens when available", () => {
+		const options = {}
+		const model = {
+			id: "test-model",
+			maxTokens: 2000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+			defaultMaxTokens: 1000,
+		})
+
+		expect(result).toEqual({
+			maxTokens: 2000,
+			thinking: undefined,
+			temperature: 0,
+		})
+	})
+
+	it("should handle thinking models correctly", () => {
+		const options = {}
+		const model = {
+			id: "test-model",
+			thinking: true,
+			maxTokens: 2000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: 1600, // 80% of 2000
+		}
+
+		expect(result).toEqual({
+			maxTokens: 2000,
+			thinking: expectedThinking,
+			temperature: 1.0, // Thinking models require temperature 1.0.
+		})
+	})
+
+	it("should honor customMaxTokens for thinking models", () => {
+		const options = { modelMaxTokens: 3000 }
+		const model = {
+			id: "test-model",
+			thinking: true,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+			defaultMaxTokens: 2000,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: 2400, // 80% of 3000
+		}
+
+		expect(result).toEqual({
+			maxTokens: 3000,
+			thinking: expectedThinking,
+			temperature: 1.0,
+		})
+	})
+
+	it("should honor customMaxThinkingTokens for thinking models", () => {
+		const options = { modelMaxThinkingTokens: 1500 }
+		const model = {
+			id: "test-model",
+			thinking: true,
+			maxTokens: 4000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: 1500, // Using the custom value
+		}
+
+		expect(result).toEqual({
+			maxTokens: 4000,
+			thinking: expectedThinking,
+			temperature: 1.0,
+		})
+	})
+
+	it("should not honor customMaxThinkingTokens for non-thinking models", () => {
+		const options = { modelMaxThinkingTokens: 1500 }
+		const model = {
+			id: "test-model",
+			maxTokens: 4000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+			// Note: model.thinking is not set (so it's falsey).
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		expect(result).toEqual({
+			maxTokens: 4000,
+			thinking: undefined, // Should remain undefined despite customMaxThinkingTokens being set.
+			temperature: 0, // Using default temperature.
+		})
+	})
+
+	it("should clamp thinking budget to at least 1024 tokens", () => {
+		const options = { modelMaxThinkingTokens: 500 }
+		const model = {
+			id: "test-model",
+			thinking: true,
+			maxTokens: 2000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: 1024, // Minimum is 1024
+		}
+
+		expect(result).toEqual({
+			maxTokens: 2000,
+			thinking: expectedThinking,
+			temperature: 1.0,
+		})
+	})
+
+	it("should clamp thinking budget to at most 80% of max tokens", () => {
+		const options = { modelMaxThinkingTokens: 5000 }
+		const model = {
+			id: "test-model",
+			thinking: true,
+			maxTokens: 4000,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: 3200, // 80% of 4000
+		}
+
+		expect(result).toEqual({
+			maxTokens: 4000,
+			thinking: expectedThinking,
+			temperature: 1.0,
+		})
+	})
+
+	it("should use ANTHROPIC_DEFAULT_MAX_TOKENS when no maxTokens is provided for thinking models", () => {
+		const options = {}
+		const model = {
+			id: "test-model",
+			thinking: true,
+			contextWindow: 16000,
+			supportsPromptCache: true,
+		}
+
+		const result = getModelParams({
+			options,
+			model,
+		})
+
+		const expectedThinking: BetaThinkingConfigParam = {
+			type: "enabled",
+			budget_tokens: Math.floor(ANTHROPIC_DEFAULT_MAX_TOKENS * 0.8),
+		}
+
+		expect(result).toEqual({
+			maxTokens: undefined,
+			thinking: expectedThinking,
+			temperature: 1.0,
+		})
+	})
+})
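
A compact restatement of the thinking-budget rule the cases above exercise, written as a sketch of the same arithmetic (the numbers are taken directly from those tests, not additional behavior):

// budget_tokens = max(min(customMaxThinkingTokens ?? floor(0.8 * maxTokens), floor(0.8 * maxTokens)), 1024)
//
// maxTokens 2000, no custom value -> floor(0.8 * 2000) = 1600
// maxTokens 2000, custom 500      -> raised to the 1024 floor
// maxTokens 4000, custom 1500     -> 1500 (already inside the allowed range)
// maxTokens 4000, custom 5000     -> capped at floor(0.8 * 4000) = 3200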

+ 42 - 1
src/api/index.ts

@@ -1,6 +1,9 @@
 import { Anthropic } from "@anthropic-ai/sdk"
+import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta/messages/index.mjs"
+
+import { ApiConfiguration, ModelInfo, ApiHandlerOptions } from "../shared/api"
+import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./providers/constants"
 import { GlamaHandler } from "./providers/glama"
-import { ApiConfiguration, ModelInfo } from "../shared/api"
 import { AnthropicHandler } from "./providers/anthropic"
 import { AwsBedrockHandler } from "./providers/bedrock"
 import { OpenRouterHandler } from "./providers/openrouter"
@@ -63,3 +66,41 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new AnthropicHandler(options)
 	}
 }
+
+export function getModelParams({
+	options,
+	model,
+	defaultMaxTokens,
+	defaultTemperature = 0,
+}: {
+	options: ApiHandlerOptions
+	model: ModelInfo
+	defaultMaxTokens?: number
+	defaultTemperature?: number
+}) {
+	const {
+		modelMaxTokens: customMaxTokens,
+		modelMaxThinkingTokens: customMaxThinkingTokens,
+		modelTemperature: customTemperature,
+	} = options
+
+	let maxTokens = model.maxTokens ?? defaultMaxTokens
+	let thinking: BetaThinkingConfigParam | undefined = undefined
+	let temperature = customTemperature ?? defaultTemperature
+
+	if (model.thinking) {
+		// Only honor `customMaxTokens` for thinking models.
+		maxTokens = customMaxTokens ?? maxTokens
+
+		// Clamp the thinking budget to be at most 80% of max tokens and at
+		// least 1024 tokens.
+		const maxBudgetTokens = Math.floor((maxTokens || ANTHROPIC_DEFAULT_MAX_TOKENS) * 0.8)
+		const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
+		thinking = { type: "enabled", budget_tokens: budgetTokens }
+
+		// Anthropic "Thinking" models require a temperature of 1.0.
+		temperature = 1.0
+	}
+
+	return { maxTokens, thinking, temperature }
+}
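
For orientation, a minimal usage sketch of the new helper. The model object is hypothetical but mirrors the shape used in the test file above, and the expected values follow the customMaxThinkingTokens test case:

import { getModelParams } from "../index" // relative path as used in the test file

const model = {
	id: "example-model", // hypothetical entry, not a real model id
	thinking: true,
	maxTokens: 4000,
	contextWindow: 16000,
	supportsPromptCache: true,
}

const { maxTokens, thinking, temperature } = getModelParams({
	options: { modelMaxThinkingTokens: 1500 },
	model,
})

// maxTokens === 4000
// thinking === { type: "enabled", budget_tokens: 1500 }
// temperature === 1.0 (thinking models are pinned to 1.0)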

+ 9 - 30
src/api/providers/anthropic.ts

@@ -1,7 +1,6 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
 import { CacheControlEphemeral } from "@anthropic-ai/sdk/resources"
-import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"
 import {
 	anthropicDefaultModelId,
 	AnthropicModelId,
@@ -9,8 +8,9 @@ import {
 	ApiHandlerOptions,
 	ModelInfo,
 } from "../../shared/api"
-import { ApiHandler, SingleCompletionHandler } from "../index"
 import { ApiStream } from "../transform/stream"
+import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./constants"
+import { ApiHandler, SingleCompletionHandler, getModelParams } from "../index"
 
 export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
@@ -51,7 +51,7 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 				stream = await this.client.messages.create(
 					{
 						model: modelId,
-						max_tokens: maxTokens,
+						max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
 						temperature,
 						thinking,
 						// Setting cache breakpoint for system prompt so new tasks can reuse it.
@@ -99,7 +99,7 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 			default: {
 				stream = (await this.client.messages.create({
 					model: modelId,
-					max_tokens: maxTokens,
+					max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
 					temperature,
 					system: [{ text: systemPrompt, type: "text" }],
 					messages,
@@ -180,13 +180,6 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 
 	getModel() {
 		const modelId = this.options.apiModelId
-
-		const {
-			modelMaxTokens: customMaxTokens,
-			modelMaxThinkingTokens: customMaxThinkingTokens,
-			modelTemperature: customTemperature,
-		} = this.options
-
 		let id = modelId && modelId in anthropicModels ? (modelId as AnthropicModelId) : anthropicDefaultModelId
 		const info: ModelInfo = anthropicModels[id]
 
@@ -197,25 +190,11 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 			id = "claude-3-7-sonnet-20250219"
 		}
 
-		let maxTokens = info.maxTokens ?? 8192
-		let thinking: BetaThinkingConfigParam | undefined = undefined
-		let temperature = customTemperature ?? 0
-
-		if (info.thinking) {
-			// Only honor `customMaxTokens` for thinking models.
-			maxTokens = customMaxTokens ?? maxTokens
-
-			// Clamp the thinking budget to be at most 80% of max tokens and at
-			// least 1024 tokens.
-			const maxBudgetTokens = Math.floor(maxTokens * 0.8)
-			const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
-			thinking = { type: "enabled", budget_tokens: budgetTokens }
-
-			// Anthropic "Thinking" models require a temperature of 1.0.
-			temperature = 1.0
+		return {
+			id,
+			info,
+			...getModelParams({ options: this.options, model: info, defaultMaxTokens: ANTHROPIC_DEFAULT_MAX_TOKENS }),
 		}
-
-		return { id, info, maxTokens, thinking, temperature }
 	}
 
 	async completePrompt(prompt: string) {
@@ -223,7 +202,7 @@ export class AnthropicHandler implements ApiHandler, SingleCompletionHandler {
 
 		const message = await this.client.messages.create({
 			model: modelId,
-			max_tokens: maxTokens,
+			max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
 			thinking,
 			temperature,
 			messages: [{ role: "user", content: prompt }],
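
The maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS fallbacks above exist because the shared helper types maxTokens as possibly undefined; that only happens when neither the model entry nor defaultMaxTokens supplies a value, as in the last test case. A minimal sketch of the pattern (hypothetical model entry, import paths as in this provider module):

import { getModelParams } from "../index"
import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./constants"

// Thinking model with no maxTokens entry: getModelParams returns maxTokens as undefined,
// so the request builders fall back to the provider default before calling the API.
const { maxTokens } = getModelParams({
	options: {},
	model: { id: "example-model", thinking: true, contextWindow: 16000, supportsPromptCache: true },
})

const max_tokens = maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS // 8192, per src/api/providers/constants.ts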

+ 3 - 0
src/api/providers/constants.ts

@@ -0,0 +1,3 @@
+export const ANTHROPIC_DEFAULT_MAX_TOKENS = 8192
+
+export const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6

+ 3 - 7
src/api/providers/ollama.ts

@@ -7,11 +7,9 @@ import { ApiHandlerOptions, ModelInfo, openAiModelInfoSaneDefaults } from "../..
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { convertToR1Format } from "../transform/r1-format"
 import { ApiStream } from "../transform/stream"
-import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./openai"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
 import { XmlMatcher } from "../../utils/xml-matcher"
 
-const OLLAMA_DEFAULT_TEMPERATURE = 0
-
 export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
 	private client: OpenAI
@@ -35,7 +33,7 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
 		const stream = await this.client.chat.completions.create({
 			model: this.getModel().id,
 			messages: openAiMessages,
-			temperature: this.options.modelTemperature ?? OLLAMA_DEFAULT_TEMPERATURE,
+			temperature: this.options.modelTemperature ?? 0,
 			stream: true,
 		})
 		const matcher = new XmlMatcher(
@@ -76,9 +74,7 @@ export class OllamaHandler implements ApiHandler, SingleCompletionHandler {
 				messages: useR1Format
 					? convertToR1Format([{ role: "user", content: prompt }])
 					: [{ role: "user", content: prompt }],
-				temperature:
-					this.options.modelTemperature ??
-					(useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : OLLAMA_DEFAULT_TEMPERATURE),
+				temperature: this.options.modelTemperature ?? (useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
 				stream: false,
 			})
 			return response.choices[0]?.message.content || ""

+ 2 - 6
src/api/providers/openai.ts

@@ -13,14 +13,12 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
 import { convertToR1Format } from "../transform/r1-format"
 import { convertToSimpleMessages } from "../transform/simple-format"
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
 
 export interface OpenAiHandlerOptions extends ApiHandlerOptions {
 	defaultHeaders?: Record<string, string>
 }
 
-export const DEEP_SEEK_DEFAULT_TEMPERATURE = 0.6
-const OPENAI_DEFAULT_TEMPERATURE = 0
-
 export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
 	protected options: OpenAiHandlerOptions
 	private client: OpenAI
@@ -78,9 +76,7 @@ export class OpenAiHandler implements ApiHandler, SingleCompletionHandler {
 
 			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 				model: modelId,
-				temperature:
-					this.options.modelTemperature ??
-					(deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : OPENAI_DEFAULT_TEMPERATURE),
+				temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
 				messages: convertedMessages,
 				stream: true as const,
 				stream_options: { include_usage: true },

+ 10 - 36
src/api/providers/openrouter.ts

@@ -9,10 +9,8 @@ import { parseApiPrice } from "../../utils/cost"
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStreamChunk, ApiStreamUsageChunk } from "../transform/stream"
 import { convertToR1Format } from "../transform/r1-format"
-import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./openai"
-import { ApiHandler, SingleCompletionHandler } from ".."
-
-const OPENROUTER_DEFAULT_TEMPERATURE = 0
+import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"
+import { ApiHandler, getModelParams, SingleCompletionHandler } from ".."
 
 // Add custom interface for OpenRouter params.
 type OpenRouterChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
@@ -200,40 +198,16 @@ export class OpenRouterHandler implements ApiHandler, SingleCompletionHandler {
 		let id = modelId ?? openRouterDefaultModelId
 		const info = modelInfo ?? openRouterDefaultModelInfo
 
-		const {
-			modelMaxTokens: customMaxTokens,
-			modelMaxThinkingTokens: customMaxThinkingTokens,
-			modelTemperature: customTemperature,
-		} = this.options
-
-		let maxTokens = info.maxTokens
-		let thinking: BetaThinkingConfigParam | undefined = undefined
-		let temperature = customTemperature ?? OPENROUTER_DEFAULT_TEMPERATURE
-		let topP: number | undefined = undefined
-
-		// Handle models based on deepseek-r1
-		if (id.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning") {
-			// Recommended temperature for DeepSeek reasoning models.
-			temperature = customTemperature ?? DEEP_SEEK_DEFAULT_TEMPERATURE
-			// Some provider support topP and 0.95 is value that Deepseek used in their benchmarks.
-			topP = 0.95
-		}
-
-		if (info.thinking) {
-			// Only honor `customMaxTokens` for thinking models.
-			maxTokens = customMaxTokens ?? maxTokens
+		const isDeepSeekR1 = id.startsWith("deepseek/deepseek-r1") || modelId === "perplexity/sonar-reasoning"
+		const defaultTemperature = isDeepSeekR1 ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0
+		const topP = isDeepSeekR1 ? 0.95 : undefined
 
-			// Clamp the thinking budget to be at most 80% of max tokens and at
-			// least 1024 tokens.
-			const maxBudgetTokens = Math.floor((maxTokens || 8192) * 0.8)
-			const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
-			thinking = { type: "enabled", budget_tokens: budgetTokens }
-
-			// Anthropic "Thinking" models require a temperature of 1.0.
-			temperature = 1.0
+		return {
+			id,
+			info,
+			...getModelParams({ options: this.options, model: info, defaultTemperature }),
+			topP,
 		}
-
-		return { id, info, maxTokens, thinking, temperature, topP }
 	}
 
 	async completePrompt(prompt: string) {
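
For reference, a minimal sketch of what the refactored OpenRouter getModel produces for a DeepSeek R1 route (the ModelInfo values are hypothetical; only the defaults differ from other routes):

import { getModelParams } from ".."
import { DEEP_SEEK_DEFAULT_TEMPERATURE } from "./constants"

// Hypothetical info for a deepseek/deepseek-r1 route.
const info = { maxTokens: 8192, contextWindow: 64000, supportsPromptCache: false }

const result = {
	id: "deepseek/deepseek-r1",
	info,
	...getModelParams({ options: {}, model: info, defaultTemperature: DEEP_SEEK_DEFAULT_TEMPERATURE }),
	topP: 0.95, // the value DeepSeek used in their benchmarks
}

// result.temperature === 0.6, result.maxTokens === 8192, result.thinking === undefined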

+ 15 - 41
src/api/providers/vertex.ts

@@ -1,12 +1,13 @@
 import { Anthropic } from "@anthropic-ai/sdk"
 import { AnthropicVertex } from "@anthropic-ai/vertex-sdk"
 import { Stream as AnthropicStream } from "@anthropic-ai/sdk/streaming"
-import { ApiHandler, SingleCompletionHandler } from "../"
-import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"
+import { VertexAI } from "@google-cloud/vertexai"
+
 import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
-import { VertexAI } from "@google-cloud/vertexai"
 import { convertAnthropicMessageToVertexGemini } from "../transform/vertex-gemini-format"
+import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./constants"
+import { ApiHandler, getModelParams, SingleCompletionHandler } from "../"
 
 // Types for Vertex SDK
 
@@ -344,21 +345,8 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}
 
-	getModel(): {
-		id: VertexModelId
-		info: ModelInfo
-		temperature: number
-		maxTokens: number
-		thinking?: BetaThinkingConfigParam
-	} {
+	getModel() {
 		const modelId = this.options.apiModelId
-
-		const {
-			modelMaxTokens: customMaxTokens,
-			modelMaxThinkingTokens: customMaxThinkingTokens,
-			modelTemperature: customTemperature,
-		} = this.options
-
 		let id = modelId && modelId in vertexModels ? (modelId as VertexModelId) : vertexDefaultModelId
 		const info: ModelInfo = vertexModels[id]
 
@@ -368,25 +356,11 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 			id = id.replace(":thinking", "") as VertexModelId
 		}
 
-		let maxTokens = info.maxTokens || 8192
-		let thinking: BetaThinkingConfigParam | undefined = undefined
-		let temperature = customTemperature ?? 0
-
-		if (info.thinking) {
-			// Only honor `customMaxTokens` for thinking models.
-			maxTokens = customMaxTokens ?? maxTokens
-
-			// Clamp the thinking budget to be at most 80% of max tokens and at
-			// least 1024 tokens.
-			const maxBudgetTokens = Math.floor(maxTokens * 0.8)
-			const budgetTokens = Math.max(Math.min(customMaxThinkingTokens ?? maxBudgetTokens, maxBudgetTokens), 1024)
-			thinking = { type: "enabled", budget_tokens: budgetTokens }
-
-			// Anthropic "Thinking" models require a temperature of 1.0.
-			temperature = 1.0
+		return {
+			id,
+			info,
+			...getModelParams({ options: this.options, model: info, defaultMaxTokens: ANTHROPIC_DEFAULT_MAX_TOKENS }),
 		}
-
-		return { id, info, maxTokens, thinking, temperature }
 	}
 
 	private async completePromptGemini(prompt: string) {
@@ -423,9 +397,9 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 			let { id, info, temperature, maxTokens, thinking } = this.getModel()
 			const useCache = info.supportsPromptCache
 
-			const params = {
+			const params: Anthropic.Messages.MessageCreateParamsNonStreaming = {
 				model: id,
-				max_tokens: maxTokens,
+				max_tokens: maxTokens ?? ANTHROPIC_DEFAULT_MAX_TOKENS,
 				temperature,
 				thinking,
 				system: "", // No system prompt needed for single completions
@@ -446,19 +420,19 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 				stream: false,
 			}
 
-			const response = (await this.anthropicClient.messages.create(
-				params as Anthropic.Messages.MessageCreateParamsNonStreaming,
-			)) as unknown as VertexMessageResponse
-
+			const response = (await this.anthropicClient.messages.create(params)) as unknown as VertexMessageResponse
 			const content = response.content[0]
+
 			if (content.type === "text") {
 				return content.text
 			}
+
 			return ""
 		} catch (error) {
 			if (error instanceof Error) {
 				throw new Error(`Vertex completion error: ${error.message}`)
 			}
+
 			throw error
 		}
 	}