
Fix o1-pro on OpenRouter (#3914)

Chris Estreich · 7 months ago
commit 2cfb96692d

+ 5 - 0
.changeset/slow-spies-walk.md

@@ -0,0 +1,5 @@
+---
+"roo-cline": patch
+---
+
+Fix o1-pro on OpenRouter

+ 7 - 4
src/api/providers/__tests__/ollama.test.ts → src/api/providers/__tests__/ollama.spec.ts

@@ -1,14 +1,17 @@
+// npx vitest run api/providers/__tests__/ollama.spec.ts
+
+import { vitest } from "vitest"
 import { Anthropic } from "@anthropic-ai/sdk"
 
 import { OllamaHandler } from "../ollama"
 import { ApiHandlerOptions } from "../../../shared/api"
 
-// Mock OpenAI client
-const mockCreate = jest.fn()
-jest.mock("openai", () => {
+const mockCreate = vitest.fn()
+
+vitest.mock("openai", () => {
 	return {
 		__esModule: true,
-		default: jest.fn().mockImplementation(() => ({
+		default: vitest.fn().mockImplementation(() => ({
 			chat: {
 				completions: {
 					create: mockCreate.mockImplementation(async (options) => {

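This file and the two other migrated suites below follow the same jest-to-vitest conversion: `jest.fn()`/`jest.mock()` become `vitest.fn()`/`vitest.mock()`, and the run comment points at the renamed `.spec.ts` path. A minimal, self-contained sketch of the resulting mock shape, mirroring the pattern in the diff (the response payload here is an assumed placeholder, not from the commit):

```typescript
import { vitest } from "vitest"

const mockCreate = vitest.fn()

// The factory replaces the real "openai" module for everything this file imports.
vitest.mock("openai", () => ({
	__esModule: true,
	default: vitest.fn().mockImplementation(() => ({
		chat: {
			completions: {
				// Tests assert on mockCreate.mock.calls to inspect the outgoing request.
				create: mockCreate.mockResolvedValue({
					choices: [{ message: { content: "ok" } }], // placeholder payload
				}),
			},
		},
	})),
}))
```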
+ 74 - 4
src/api/providers/__tests__/openai-native.test.ts → src/api/providers/__tests__/openai-native.spec.ts

@@ -1,17 +1,18 @@
-// npx jest src/api/providers/__tests__/openai-native.test.ts
+// npx vitest run api/providers/__tests__/openai-native.spec.ts
 
+import { vitest } from "vitest"
 import { Anthropic } from "@anthropic-ai/sdk"
 
 import { OpenAiNativeHandler } from "../openai-native"
 import { ApiHandlerOptions } from "../../../shared/api"
 
 // Mock OpenAI client
-const mockCreate = jest.fn()
+const mockCreate = vitest.fn()
 
-jest.mock("openai", () => {
+vitest.mock("openai", () => {
 	return {
 		__esModule: true,
-		default: jest.fn().mockImplementation(() => ({
+		default: vitest.fn().mockImplementation(() => ({
 			chat: {
 				completions: {
 					create: mockCreate.mockImplementation(async (options) => {
@@ -372,6 +373,75 @@ describe("OpenAiNativeHandler", () => {
 		})
 	})
 
+	describe("temperature parameter handling", () => {
+		it("should include temperature for models that support it", async () => {
+			// Test with gpt-4.1 which supports temperature
+			handler = new OpenAiNativeHandler({
+				apiModelId: "gpt-4.1",
+				openAiNativeApiKey: "test-api-key",
+			})
+
+			await handler.completePrompt("Test prompt")
+			expect(mockCreate).toHaveBeenCalledWith({
+				model: "gpt-4.1",
+				messages: [{ role: "user", content: "Test prompt" }],
+				temperature: 0,
+			})
+		})
+
+		it("should strip temperature for o1 family models", async () => {
+			const o1Models = ["o1", "o1-preview", "o1-mini"]
+
+			for (const modelId of o1Models) {
+				handler = new OpenAiNativeHandler({
+					apiModelId: modelId,
+					openAiNativeApiKey: "test-api-key",
+				})
+
+				mockCreate.mockClear()
+				await handler.completePrompt("Test prompt")
+
+				const callArgs = mockCreate.mock.calls[0][0]
+				// Temperature should be undefined for o1 models
+				expect(callArgs.temperature).toBeUndefined()
+				expect(callArgs.model).toBe(modelId)
+			}
+		})
+
+		it("should strip temperature for o3-mini model", async () => {
+			handler = new OpenAiNativeHandler({
+				apiModelId: "o3-mini",
+				openAiNativeApiKey: "test-api-key",
+			})
+
+			await handler.completePrompt("Test prompt")
+
+			const callArgs = mockCreate.mock.calls[0][0]
+			// Temperature should be undefined for o3-mini models
+			expect(callArgs.temperature).toBeUndefined()
+			expect(callArgs.model).toBe("o3-mini")
+			expect(callArgs.reasoning_effort).toBe("medium")
+		})
+
+		it("should strip temperature in streaming mode for unsupported models", async () => {
+			handler = new OpenAiNativeHandler({
+				apiModelId: "o1",
+				openAiNativeApiKey: "test-api-key",
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			// Consume the stream
+			for await (const _chunk of stream) {
+				// Just consume the stream
+			}
+
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs).not.toHaveProperty("temperature")
+			expect(callArgs.model).toBe("o1")
+			expect(callArgs.stream).toBe(true)
+		})
+	})
+
 	describe("getModel", () => {
 		it("should return model info", () => {
 			const modelInfo = handler.getModel()

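As an aside, vitest also supports table-driven tests, so the manual `for` loop over the o1 family above could equivalently be written with `it.each`. A sketch only, not part of this change:

```typescript
it.each(["o1", "o1-preview", "o1-mini"])("should strip temperature for %s", async (modelId) => {
	handler = new OpenAiNativeHandler({
		apiModelId: modelId,
		openAiNativeApiKey: "test-api-key",
	})

	mockCreate.mockClear()
	await handler.completePrompt("Test prompt")

	// Inspect the request actually sent to the mocked client.
	const callArgs = mockCreate.mock.calls[0][0]
	expect(callArgs.temperature).toBeUndefined()
	expect(callArgs.model).toBe(modelId)
})
```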
+ 10 - 6
src/api/providers/__tests__/openai-usage-tracking.test.ts → src/api/providers/__tests__/openai-usage-tracking.spec.ts

@@ -1,13 +1,17 @@
-import { OpenAiHandler } from "../openai"
-import { ApiHandlerOptions } from "../../../shared/api"
+// npx vitest run api/providers/__tests__/openai-usage-tracking.spec.ts
+
+import { vitest } from "vitest"
 import { Anthropic } from "@anthropic-ai/sdk"
 
-// Mock OpenAI client with multiple chunks that contain usage data
-const mockCreate = jest.fn()
-jest.mock("openai", () => {
+import { ApiHandlerOptions } from "../../../shared/api"
+import { OpenAiHandler } from "../openai"
+
+const mockCreate = vitest.fn()
+
+vitest.mock("openai", () => {
 	return {
 		__esModule: true,
-		default: jest.fn().mockImplementation(() => ({
+		default: vitest.fn().mockImplementation(() => ({
 			chat: {
 				completions: {
 					create: mockCreate.mockImplementation(async (options) => {

+ 10 - 9
src/api/providers/__tests__/openai.test.ts → src/api/providers/__tests__/openai.spec.ts

@@ -1,15 +1,18 @@
-// npx jest src/api/providers/__tests__/openai.test.ts
+// npx vitest run api/providers/__tests__/openai.spec.ts
 
+import { vitest, vi } from "vitest"
 import { OpenAiHandler } from "../openai"
 import { ApiHandlerOptions } from "../../../shared/api"
 import { Anthropic } from "@anthropic-ai/sdk"
+import OpenAI from "openai"
 
-// Mock OpenAI client
-const mockCreate = jest.fn()
-jest.mock("openai", () => {
+const mockCreate = vitest.fn()
+
+vitest.mock("openai", () => {
+	const mockConstructor = vitest.fn()
 	return {
 		__esModule: true,
-		default: jest.fn().mockImplementation(() => ({
+		default: mockConstructor.mockImplementation(() => ({
 			chat: {
 				completions: {
 					create: mockCreate.mockImplementation(async (options) => {
@@ -94,10 +97,8 @@ describe("OpenAiHandler", () => {
 		})
 
 		it("should set default headers correctly", () => {
-			// Get the mock constructor from the jest mock system
-			const openAiMock = jest.requireMock("openai").default
-
-			expect(openAiMock).toHaveBeenCalledWith({
+			// Check that the OpenAI constructor was called with correct parameters
+			expect(vi.mocked(OpenAI)).toHaveBeenCalledWith({
 				baseURL: expect.any(String),
 				apiKey: expect.any(String),
 				defaultHeaders: {

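`vi.mocked(OpenAI)` is vitest's typed accessor for a mocked import: it is a type-level cast that returns the same function with mock typings attached, replacing the old `jest.requireMock("openai").default` lookup. A sketch of the pattern in isolation (the asserted values are placeholders; `"openai"` must already be mocked as above):

```typescript
import OpenAI from "openai"
import { vi, it, expect } from "vitest"

it("asserts on constructor arguments via vi.mocked", () => {
	// vi.mocked only adds mock typings; the mocking itself happens in vitest.mock.
	const MockedOpenAI = vi.mocked(OpenAI)

	expect(MockedOpenAI).toHaveBeenCalledWith(
		expect.objectContaining({ apiKey: expect.any(String) }),
	)
})
```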
+ 2 - 8
src/api/providers/openai-native.ts

@@ -165,7 +165,7 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 
 		const info: ModelInfo = openAiNativeModels[id]
 
-		const { temperature, ...params } = getModelParams({
+		const params = getModelParams({
 			format: "openai",
 			modelId: id,
 			model: info,
@@ -175,13 +175,7 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 
 		// The o3 models are named like "o3-mini-[reasoning-effort]", which are
 		// not valid model ids, so we need to strip the suffix.
-		// Also note that temperature is not supported for o1 and o3-mini.
-		return {
-			id: id.startsWith("o3-mini") ? "o3-mini" : id,
-			info,
-			...params,
-			temperature: id.startsWith("o1") || id.startsWith("o3-mini") ? undefined : temperature,
-		}
+		return { id: id.startsWith("o3-mini") ? "o3-mini" : id, info, ...params }
 	}
 
 	async completePrompt(prompt: string): Promise<string> {

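The handler change is a pure simplification: the temperature special-casing moves into `getModelParams` (next file), so `getModel` keeps only the id normalization for the reasoning-effort suffixes. A tiny illustration of that normalization (the example id is assumed, following the "o3-mini-[reasoning-effort]" naming the comment describes):

```typescript
// "o3-mini-high" is an internal id; the API only accepts "o3-mini".
const id = "o3-mini-high"
const apiId = id.startsWith("o3-mini") ? "o3-mini" : id // => "o3-mini"
```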
+ 16 - 1
src/api/transform/model-params.ts

@@ -25,7 +25,7 @@ type GetModelParamsOptions<T extends "openai" | "anthropic" | "openrouter"> = {
 
 type BaseModelParams = {
 	maxTokens: number | undefined
-	temperature: number
+	temperature: number | undefined
 	reasoningEffort: "low" | "medium" | "high" | undefined
 	reasoningBudget: number | undefined
 }
@@ -114,12 +114,27 @@ export function getModelParams({
 			reasoning: getAnthropicReasoning({ model, reasoningBudget, reasoningEffort, settings }),
 		}
 	} else if (format === "openai") {
+		// Special case for o1 and o3-mini, which don't support temperature.
+		// TODO: Add a `supportsTemperature` field to the model info.
+		if (modelId.startsWith("o1") || modelId.startsWith("o3-mini")) {
+			params.temperature = undefined
+		}
+
 		return {
 			format,
 			...params,
 			reasoning: getOpenAiReasoning({ model, reasoningBudget, reasoningEffort, settings }),
 		}
 	} else {
+		// Special case for o1-pro, which doesn't support temperature.
+		// Note that OpenRouter's `supported_parameters` field includes
+		// `temperature`, which is probably a bug.
+		// TODO: Add a `supportsTemperature` field to the model info and populate
+		// it appropriately in the OpenRouter fetcher.
+		if (modelId === "openai/o1-pro") {
+			params.temperature = undefined
+		}
+
 		return {
 			format,
 			...params,

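Both TODOs point at the same future refactor: a `supportsTemperature` flag on the model info instead of id-prefix and exact-match checks. A hypothetical sketch of that shape (the field and helper below are illustrative only, not part of the codebase):

```typescript
// Hypothetical: ModelInfo gains a supportsTemperature flag.
type ModelInfoSketch = { supportsTemperature?: boolean }

function resolveTemperature(
	info: ModelInfoSketch,
	requested: number | undefined,
): number | undefined {
	// undefined means "omit the parameter from the request entirely".
	return info.supportsTemperature === false ? undefined : requested
}
```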
+ 4 - 1
src/vitest.config.ts

@@ -1,5 +1,8 @@
 import { defineConfig } from "vitest/config"
 
 export default defineConfig({
-	test: { include: ["**/__tests__/**/*.spec.ts"] },
+	test: {
+		include: ["**/__tests__/**/*.spec.ts"],
+		globals: true,
+	},
 })
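`globals: true` makes `describe`, `it`, `expect`, and the other suite functions available without imports, matching jest's default behavior so the migrated suites keep working unchanged; for TypeScript this is typically paired with `"types": ["vitest/globals"]` in tsconfig. A sketch:

```typescript
// No vitest import needed once globals are enabled in vitest.config.ts.
describe("globals enabled", () => {
	it("exposes expect ambiently", () => {
		expect(2 + 2).toBe(4)
	})
})
```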