浏览代码

Fix Requesty extended thinking (#4051)

Chris Estreich 7 月之前
父节点
当前提交
f37e6f6fce

+ 2 - 6
src/api/providers/__tests__/deepseek.test.ts

@@ -140,12 +140,8 @@ describe("DeepSeekHandler", () => {
 
 		it("should set includeMaxTokens to true", () => {
 			// Create a new handler and verify OpenAI client was called with includeMaxTokens
-			new DeepSeekHandler(mockOptions)
-			expect(OpenAI).toHaveBeenCalledWith(
-				expect.objectContaining({
-					apiKey: mockOptions.deepSeekApiKey,
-				}),
-			)
+			const _handler = new DeepSeekHandler(mockOptions)
+			expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: mockOptions.deepSeekApiKey }))
 		})
 	})
 

+ 6 - 4
src/api/providers/__tests__/requesty.test.ts

@@ -7,7 +7,9 @@ import { RequestyHandler } from "../requesty"
 import { ApiHandlerOptions } from "../../../shared/api"
 
 jest.mock("openai")
+
 jest.mock("delay", () => jest.fn(() => Promise.resolve()))
+
 jest.mock("../fetchers/modelCache", () => ({
 	getModels: jest.fn().mockImplementation(() => {
 		return Promise.resolve({
@@ -150,7 +152,7 @@ describe("RequestyHandler", () => {
 			// Verify OpenAI client was called with correct parameters
 			expect(mockCreate).toHaveBeenCalledWith(
 				expect.objectContaining({
-					max_tokens: undefined,
+					max_tokens: 8192,
 					messages: [
 						{
 							role: "system",
@@ -164,7 +166,7 @@ describe("RequestyHandler", () => {
 					model: "coding/claude-4-sonnet",
 					stream: true,
 					stream_options: { include_usage: true },
-					temperature: undefined,
+					temperature: 0,
 				}),
 			)
 		})
@@ -198,9 +200,9 @@ describe("RequestyHandler", () => {
 
 			expect(mockCreate).toHaveBeenCalledWith({
 				model: mockOptions.requestyModelId,
-				max_tokens: undefined,
+				max_tokens: 8192,
 				messages: [{ role: "system", content: "test prompt" }],
-				temperature: undefined,
+				temperature: 0,
 			})
 		})
 

+ 31 - 37
src/api/providers/mistral.ts

@@ -18,60 +18,50 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand
 
 	constructor(options: ApiHandlerOptions) {
 		super()
+
 		if (!options.mistralApiKey) {
 			throw new Error("Mistral API key is required")
 		}
 
-		// Set default model ID if not provided
-		this.options = {
-			...options,
-			apiModelId: options.apiModelId || mistralDefaultModelId,
-		}
+		// Set default model ID if not provided.
+		const apiModelId = options.apiModelId || mistralDefaultModelId
+		this.options = { ...options, apiModelId }
 
-		const baseUrl = this.getBaseUrl()
-		console.debug(`[Roo Code] MistralHandler using baseUrl: ${baseUrl}`)
 		this.client = new Mistral({
-			serverURL: baseUrl,
+			serverURL: apiModelId.startsWith("codestral-")
+				? this.options.mistralCodestralUrl || "https://codestral.mistral.ai"
+				: "https://api.mistral.ai",
 			apiKey: this.options.mistralApiKey,
 		})
 	}
 
-	private getBaseUrl(): string {
-		const modelId = this.options.apiModelId ?? mistralDefaultModelId
-		console.debug(`[Roo Code] MistralHandler using modelId: ${modelId}`)
-		if (modelId?.startsWith("codestral-")) {
-			return this.options.mistralCodestralUrl || "https://codestral.mistral.ai"
-		}
-		return "https://api.mistral.ai"
-	}
-
 	override async *createMessage(
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const { id: model } = this.getModel()
+		const { id: model, maxTokens, temperature } = this.getModel()
 
 		const response = await this.client.chat.stream({
-			model: this.options.apiModelId || mistralDefaultModelId,
+			model,
 			messages: [{ role: "system", content: systemPrompt }, ...convertToMistralMessages(messages)],
-			maxTokens: this.options.includeMaxTokens ? this.getModel().info.maxTokens : undefined,
-			temperature: this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE,
+			maxTokens,
+			temperature,
 		})
 
 		for await (const chunk of response) {
 			const delta = chunk.data.choices[0]?.delta
+
 			if (delta?.content) {
 				let content: string = ""
+
 				if (typeof delta.content === "string") {
 					content = delta.content
 				} else if (Array.isArray(delta.content)) {
 					content = delta.content.map((c) => (c.type === "text" ? c.text : "")).join("")
 				}
-				yield {
-					type: "text",
-					text: content,
-				}
+
+				yield { type: "text", text: content }
 			}
 
 			if (chunk.data.usage) {
@@ -84,35 +74,39 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand
 		}
 	}
 
-	override getModel(): { id: MistralModelId; info: ModelInfo } {
-		const modelId = this.options.apiModelId
-		if (modelId && modelId in mistralModels) {
-			const id = modelId as MistralModelId
-			return { id, info: mistralModels[id] }
-		}
-		return {
-			id: mistralDefaultModelId,
-			info: mistralModels[mistralDefaultModelId],
-		}
+	override getModel() {
+		const id = this.options.apiModelId ?? mistralDefaultModelId
+		const info = mistralModels[id as MistralModelId] ?? mistralModels[mistralDefaultModelId]
+
+		// @TODO: Move this to the `getModelParams` function.
+		const maxTokens = this.options.includeMaxTokens ? info.maxTokens : undefined
+		const temperature = this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE
+
+		return { id, info, maxTokens, temperature }
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
+			const { id: model, temperature } = this.getModel()
+
 			const response = await this.client.chat.complete({
-				model: this.options.apiModelId || mistralDefaultModelId,
+				model,
 				messages: [{ role: "user", content: prompt }],
-				temperature: this.options.modelTemperature ?? MISTRAL_DEFAULT_TEMPERATURE,
+				temperature,
 			})
 
 			const content = response.choices?.[0]?.message.content
+
 			if (Array.isArray(content)) {
 				return content.map((c) => (c.type === "text" ? c.text : "")).join("")
 			}
+
 			return content || ""
 		} catch (error) {
 			if (error instanceof Error) {
 				throw new Error(`Mistral completion error: ${error.message}`)
 			}
+
 			throw error
 		}
 	}

+ 1 - 0
src/api/providers/openai.ts

@@ -154,6 +154,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				...(reasoning && reasoning),
 			}
 
+			// @TODO: Move this to the `getModelParams` function.
 			if (this.options.includeMaxTokens) {
 				requestOptions.max_tokens = modelInfo.maxTokens
 			}

+ 43 - 69
src/api/providers/requesty.ts

@@ -8,11 +8,13 @@ import { calculateApiCostOpenAI } from "../../shared/cost"
 
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+import { getModelParams } from "../transform/model-params"
+import { AnthropicReasoningParams } from "../transform/reasoning"
 
 import { DEFAULT_HEADERS } from "./constants"
 import { getModels } from "./fetchers/modelCache"
 import { BaseProvider } from "./base-provider"
-import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../"
+import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
 
 // Requesty usage includes an extra field for Anthropic use cases.
 // Safely cast the prompt token details section to the appropriate structure.
@@ -31,10 +33,7 @@ type RequestyChatCompletionParams = OpenAI.Chat.ChatCompletionCreateParams & {
 			mode?: string
 		}
 	}
-	thinking?: {
-		type: string
-		budget_tokens?: number
-	}
+	thinking?: AnthropicReasoningParams
 }
 
 export class RequestyHandler extends BaseProvider implements SingleCompletionHandler {
@@ -44,14 +43,14 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan
 
 	constructor(options: ApiHandlerOptions) {
 		super()
-		this.options = options
-
-		const apiKey = this.options.requestyApiKey ?? "not-provided"
-		const baseURL = "https://router.requesty.ai/v1"
 
-		const defaultHeaders = DEFAULT_HEADERS
+		this.options = options
 
-		this.client = new OpenAI({ baseURL, apiKey, defaultHeaders })
+		this.client = new OpenAI({
+			baseURL: "https://router.requesty.ai/v1",
+			apiKey: this.options.requestyApiKey ?? "not-provided",
+			defaultHeaders: DEFAULT_HEADERS,
+		})
 	}
 
 	public async fetchModel() {
@@ -59,10 +58,18 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan
 		return this.getModel()
 	}
 
-	override getModel(): { id: string; info: ModelInfo } {
+	override getModel() {
 		const id = this.options.requestyModelId ?? requestyDefaultModelId
 		const info = this.models[id] ?? requestyDefaultModelInfo
-		return { id, info }
+
+		const params = getModelParams({
+			format: "anthropic",
+			modelId: id,
+			model: info,
+			settings: this.options,
+		})
+
+		return { id, info, ...params }
 	}
 
 	protected processUsageMetrics(usage: any, modelInfo?: ModelInfo): ApiStreamUsageChunk {
@@ -90,70 +97,44 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const model = await this.fetchModel()
-
-		let openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
+		const {
+			id: model,
+			info,
+			maxTokens: max_tokens,
+			temperature,
+			reasoningEffort: reasoning_effort,
+			reasoning: thinking,
+		} = await this.fetchModel()
+
+		const openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [
 			{ role: "system", content: systemPrompt },
 			...convertToOpenAiMessages(messages),
 		]
 
-		let maxTokens = undefined
-		if (this.options.modelMaxTokens) {
-			maxTokens = this.options.modelMaxTokens
-		} else if (this.options.includeMaxTokens) {
-			maxTokens = model.info.maxTokens
-		}
-
-		let reasoningEffort = undefined
-		if (this.options.reasoningEffort) {
-			reasoningEffort = this.options.reasoningEffort
-		}
-
-		let thinking = undefined
-		if (this.options.modelMaxThinkingTokens) {
-			thinking = {
-				type: "enabled",
-				budget_tokens: this.options.modelMaxThinkingTokens,
-			}
-		}
-
-		const temperature = this.options.modelTemperature
-
 		const completionParams: RequestyChatCompletionParams = {
-			model: model.id,
-			max_tokens: maxTokens,
 			messages: openAiMessages,
-			temperature: temperature,
+			model,
+			max_tokens,
+			temperature,
+			...(reasoning_effort && { reasoning_effort }),
+			...(thinking && { thinking }),
 			stream: true,
 			stream_options: { include_usage: true },
-			reasoning_effort: reasoningEffort,
-			thinking: thinking,
-			requesty: {
-				trace_id: metadata?.taskId,
-				extra: {
-					mode: metadata?.mode,
-				},
-			},
+			requesty: { trace_id: metadata?.taskId, extra: { mode: metadata?.mode } },
 		}
 
 		const stream = await this.client.chat.completions.create(completionParams)
-
 		let lastUsage: any = undefined
 
 		for await (const chunk of stream) {
 			const delta = chunk.choices[0]?.delta
+
 			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
-				}
+				yield { type: "text", text: delta.content }
 			}
 
 			if (delta && "reasoning_content" in delta && delta.reasoning_content) {
-				yield {
-					type: "reasoning",
-					text: (delta.reasoning_content as string | undefined) || "",
-				}
+				yield { type: "reasoning", text: (delta.reasoning_content as string | undefined) || "" }
 			}
 
 			if (chunk.usage) {
@@ -162,25 +143,18 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan
 		}
 
 		if (lastUsage) {
-			yield this.processUsageMetrics(lastUsage, model.info)
+			yield this.processUsageMetrics(lastUsage, info)
 		}
 	}
 
 	async completePrompt(prompt: string): Promise<string> {
-		const model = await this.fetchModel()
+		const { id: model, maxTokens: max_tokens, temperature } = await this.fetchModel()
 
 		let openAiMessages: OpenAI.Chat.ChatCompletionMessageParam[] = [{ role: "system", content: prompt }]
 
-		let maxTokens = undefined
-		if (this.options.includeMaxTokens) {
-			maxTokens = model.info.maxTokens
-		}
-
-		const temperature = this.options.modelTemperature
-
 		const completionParams: RequestyChatCompletionParams = {
-			model: model.id,
-			max_tokens: maxTokens,
+			model,
+			max_tokens,
 			messages: openAiMessages,
 			temperature: temperature,
 		}