Преглед изворни кода

fix: trim whitespace from OpenAI base URL to fix model detection (#6560)

Co-authored-by: Roo Code <[email protected]>
roomote[bot] пре 4 месеци
родитељ
комит
1237eb825b
2 измењених фајлова са 154 додато и 3 уклоњено
  1. 149 1
      src/api/providers/__tests__/openai.spec.ts
  2. 5 2
      src/api/providers/openai.ts

+ 149 - 1
src/api/providers/__tests__/openai.spec.ts

@@ -1,11 +1,12 @@
 // npx vitest run api/providers/__tests__/openai.spec.ts
 
-import { OpenAiHandler } from "../openai"
+import { OpenAiHandler, getOpenAiModels } from "../openai"
 import { ApiHandlerOptions } from "../../../shared/api"
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
 import { openAiModelInfoSaneDefaults } from "@roo-code/types"
 import { Package } from "../../../shared/package"
+import axios from "axios"
 
 const mockCreate = vitest.fn()
 
@@ -68,6 +69,13 @@ vitest.mock("openai", () => {
 	}
 })
 
+// Mock axios for getOpenAiModels tests
+vitest.mock("axios", () => ({
+	default: {
+		get: vitest.fn(),
+	},
+}))
+
 describe("OpenAiHandler", () => {
 	let handler: OpenAiHandler
 	let mockOptions: ApiHandlerOptions
@@ -776,3 +784,143 @@ describe("OpenAiHandler", () => {
 		})
 	})
 })
+
+describe("getOpenAiModels", () => {
+	beforeEach(() => {
+		vi.mocked(axios.get).mockClear()
+	})
+
+	it("should return empty array when baseUrl is not provided", async () => {
+		const result = await getOpenAiModels(undefined, "test-key")
+		expect(result).toEqual([])
+		expect(axios.get).not.toHaveBeenCalled()
+	})
+
+	it("should return empty array when baseUrl is empty string", async () => {
+		const result = await getOpenAiModels("", "test-key")
+		expect(result).toEqual([])
+		expect(axios.get).not.toHaveBeenCalled()
+	})
+
+	it("should trim whitespace from baseUrl", async () => {
+		const mockResponse = {
+			data: {
+				data: [{ id: "gpt-4" }, { id: "gpt-3.5-turbo" }],
+			},
+		}
+		vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)
+
+		const result = await getOpenAiModels("  https://api.openai.com/v1  ", "test-key")
+
+		expect(axios.get).toHaveBeenCalledWith("https://api.openai.com/v1/models", expect.any(Object))
+		expect(result).toEqual(["gpt-4", "gpt-3.5-turbo"])
+	})
+
+	it("should handle baseUrl with trailing spaces", async () => {
+		const mockResponse = {
+			data: {
+				data: [{ id: "model-1" }, { id: "model-2" }],
+			},
+		}
+		vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)
+
+		const result = await getOpenAiModels("https://api.example.com/v1 ", "test-key")
+
+		expect(axios.get).toHaveBeenCalledWith("https://api.example.com/v1/models", expect.any(Object))
+		expect(result).toEqual(["model-1", "model-2"])
+	})
+
+	it("should handle baseUrl with leading spaces", async () => {
+		const mockResponse = {
+			data: {
+				data: [{ id: "model-1" }],
+			},
+		}
+		vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)
+
+		const result = await getOpenAiModels(" https://api.example.com/v1", "test-key")
+
+		expect(axios.get).toHaveBeenCalledWith("https://api.example.com/v1/models", expect.any(Object))
+		expect(result).toEqual(["model-1"])
+	})
+
+	it("should return empty array for invalid URL after trimming", async () => {
+		const result = await getOpenAiModels("   not-a-valid-url   ", "test-key")
+		expect(result).toEqual([])
+		expect(axios.get).not.toHaveBeenCalled()
+	})
+
+	it("should include authorization header when apiKey is provided", async () => {
+		const mockResponse = {
+			data: {
+				data: [{ id: "model-1" }],
+			},
+		}
+		vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)
+
+		await getOpenAiModels("https://api.example.com/v1", "test-api-key")
+
+		expect(axios.get).toHaveBeenCalledWith(
+			"https://api.example.com/v1/models",
+			expect.objectContaining({
+				headers: expect.objectContaining({
+					Authorization: "Bearer test-api-key",
+				}),
+			}),
+		)
+	})
+
+	it("should include custom headers when provided", async () => {
+		const mockResponse = {
+			data: {
+				data: [{ id: "model-1" }],
+			},
+		}
+		vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)
+
+		const customHeaders = {
+			"X-Custom-Header": "custom-value",
+		}
+
+		await getOpenAiModels("https://api.example.com/v1", "test-key", customHeaders)
+
+		expect(axios.get).toHaveBeenCalledWith(
+			"https://api.example.com/v1/models",
+			expect.objectContaining({
+				headers: expect.objectContaining({
+					"X-Custom-Header": "custom-value",
+					Authorization: "Bearer test-key",
+				}),
+			}),
+		)
+	})
+
+	it("should handle API errors gracefully", async () => {
+		vi.mocked(axios.get).mockRejectedValueOnce(new Error("Network error"))
+
+		const result = await getOpenAiModels("https://api.example.com/v1", "test-key")
+
+		expect(result).toEqual([])
+	})
+
+	it("should handle malformed response data", async () => {
+		vi.mocked(axios.get).mockResolvedValueOnce({ data: null })
+
+		const result = await getOpenAiModels("https://api.example.com/v1", "test-key")
+
+		expect(result).toEqual([])
+	})
+
+	it("should deduplicate model IDs", async () => {
+		const mockResponse = {
+			data: {
+				data: [{ id: "gpt-4" }, { id: "gpt-4" }, { id: "gpt-3.5-turbo" }, { id: "gpt-4" }],
+			},
+		}
+		vi.mocked(axios.get).mockResolvedValueOnce(mockResponse)
+
+		const result = await getOpenAiModels("https://api.example.com/v1", "test-key")
+
+		expect(result).toEqual(["gpt-4", "gpt-3.5-turbo"])
+	})
+})

+ 5 - 2
src/api/providers/openai.ts

@@ -416,7 +416,10 @@ export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiH
 			return []
 		}
 
-		if (!URL.canParse(baseUrl)) {
+		// Trim whitespace from baseUrl to handle cases where users accidentally include spaces
+		const trimmedBaseUrl = baseUrl.trim()
+
+		if (!URL.canParse(trimmedBaseUrl)) {
 			return []
 		}
 
@@ -434,7 +437,7 @@ export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiH
 			config["headers"] = headers
 		}
 
-		const response = await axios.get(`${baseUrl}/models`, config)
+		const response = await axios.get(`${trimmedBaseUrl}/models`, config)
 		const modelsArray = response.data?.data?.map((model: any) => model.id) || []
 		return [...new Set<string>(modelsArray)]
 	} catch (error) {