Kaynağa Gözat

fix: prevent model cache from persisting empty API responses (#9623)

Daniel 2 ay önce
ebeveyn
işleme
f3889195d7

+ 2 - 0
packages/types/src/telemetry.ts

@@ -72,6 +72,7 @@ export enum TelemetryEventName {
 	CONSECUTIVE_MISTAKE_ERROR = "Consecutive Mistake Error",
 	CODE_INDEX_ERROR = "Code Index Error",
 	TELEMETRY_SETTINGS_CHANGED = "Telemetry Settings Changed",
+	MODEL_CACHE_EMPTY_RESPONSE = "Model Cache Empty Response",
 }
 
 /**
@@ -196,6 +197,7 @@ export const rooCodeTelemetryEventSchema = z.discriminatedUnion("type", [
 			TelemetryEventName.SHELL_INTEGRATION_ERROR,
 			TelemetryEventName.CONSECUTIVE_MISTAKE_ERROR,
 			TelemetryEventName.CODE_INDEX_ERROR,
+			TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE,
 			TelemetryEventName.CONTEXT_CONDENSED,
 			TelemetryEventName.SLIDING_WINDOW_TRUNCATION,
 			TelemetryEventName.TAB_SHOWN,

+ 193 - 0
src/api/providers/fetchers/__tests__/modelCache.spec.ts

@@ -1,5 +1,14 @@
 // Mocks must come first, before imports
 
+// Mock TelemetryService
+vi.mock("@roo-code/telemetry", () => ({
+	TelemetryService: {
+		instance: {
+			captureEvent: vi.fn(),
+		},
+	},
+}))
+
 // Mock NodeCache to allow controlling cache behavior
 vi.mock("node-cache", () => {
 	const mockGet = vi.fn().mockReturnValue(undefined)
@@ -301,3 +310,187 @@ describe("getModelsFromCache disk fallback", () => {
 		consoleErrorSpy.mockRestore()
 	})
 })
+
+describe("empty cache protection", () => {
+	let mockCache: any
+	let mockGet: Mock
+	let mockSet: Mock
+
+	beforeEach(() => {
+		vi.clearAllMocks()
+		// Get the mock cache instance
+		const MockedNodeCache = vi.mocked(NodeCache)
+		mockCache = new MockedNodeCache()
+		mockGet = mockCache.get
+		mockSet = mockCache.set
+		// Reset memory cache to always miss by default
+		mockGet.mockReturnValue(undefined)
+	})
+
+	describe("getModels", () => {
+		it("does not cache empty API responses", async () => {
+			// API returns empty object (simulating failure)
+			mockGetOpenRouterModels.mockResolvedValue({})
+
+			const result = await getModels({ provider: "openrouter" })
+
+			// Should return empty but NOT cache it
+			expect(result).toEqual({})
+			expect(mockSet).not.toHaveBeenCalled()
+		})
+
+		it("caches non-empty API responses", async () => {
+			const mockModels = {
+				"openrouter/model": {
+					maxTokens: 8192,
+					contextWindow: 128000,
+					supportsPromptCache: false,
+					description: "OpenRouter model",
+				},
+			}
+			mockGetOpenRouterModels.mockResolvedValue(mockModels)
+
+			const result = await getModels({ provider: "openrouter" })
+
+			expect(result).toEqual(mockModels)
+			expect(mockSet).toHaveBeenCalledWith("openrouter", mockModels)
+		})
+	})
+
+	describe("refreshModels", () => {
+		it("keeps existing cache when API returns empty response", async () => {
+			const existingModels = {
+				"openrouter/existing-model": {
+					maxTokens: 8192,
+					contextWindow: 128000,
+					supportsPromptCache: false,
+					description: "Existing cached model",
+				},
+			}
+
+			// Memory cache has existing data
+			mockGet.mockReturnValue(existingModels)
+			// API returns empty (failure)
+			mockGetOpenRouterModels.mockResolvedValue({})
+
+			const { refreshModels } = await import("../modelCache")
+			const result = await refreshModels({ provider: "openrouter" })
+
+			// Should return existing cache, not empty
+			expect(result).toEqual(existingModels)
+			// Should NOT update cache with empty data
+			expect(mockSet).not.toHaveBeenCalled()
+		})
+
+		it("updates cache when API returns valid non-empty response", async () => {
+			const existingModels = {
+				"openrouter/old-model": {
+					maxTokens: 4096,
+					contextWindow: 64000,
+					supportsPromptCache: false,
+					description: "Old model",
+				},
+			}
+			const newModels = {
+				"openrouter/new-model": {
+					maxTokens: 8192,
+					contextWindow: 128000,
+					supportsPromptCache: true,
+					description: "New model",
+				},
+			}
+
+			mockGet.mockReturnValue(existingModels)
+			mockGetOpenRouterModels.mockResolvedValue(newModels)
+
+			const { refreshModels } = await import("../modelCache")
+			const result = await refreshModels({ provider: "openrouter" })
+
+			// Should return new models
+			expect(result).toEqual(newModels)
+			// Should update cache with new data
+			expect(mockSet).toHaveBeenCalledWith("openrouter", newModels)
+		})
+
+		it("returns existing cache on API error", async () => {
+			const existingModels = {
+				"openrouter/cached-model": {
+					maxTokens: 8192,
+					contextWindow: 128000,
+					supportsPromptCache: false,
+					description: "Cached model",
+				},
+			}
+
+			mockGet.mockReturnValue(existingModels)
+			mockGetOpenRouterModels.mockRejectedValue(new Error("API error"))
+
+			const { refreshModels } = await import("../modelCache")
+			const result = await refreshModels({ provider: "openrouter" })
+
+			// Should return existing cache on error
+			expect(result).toEqual(existingModels)
+		})
+
+		it("returns empty object when API errors and no cache exists", async () => {
+			mockGet.mockReturnValue(undefined)
+			mockGetOpenRouterModels.mockRejectedValue(new Error("API error"))
+
+			const { refreshModels } = await import("../modelCache")
+			const result = await refreshModels({ provider: "openrouter" })
+
+			// Should return empty when no cache and API fails
+			expect(result).toEqual({})
+		})
+
+		it("does not cache empty response when no existing cache", async () => {
+			// Both memory and disk cache are empty (initial state)
+			mockGet.mockReturnValue(undefined)
+			// API returns empty (failure/rate limit)
+			mockGetOpenRouterModels.mockResolvedValue({})
+
+			const { refreshModels } = await import("../modelCache")
+			const result = await refreshModels({ provider: "openrouter" })
+
+			// Should return empty but NOT cache it
+			expect(result).toEqual({})
+			expect(mockSet).not.toHaveBeenCalled()
+		})
+
+		it("reuses in-flight request for concurrent calls to same provider", async () => {
+			const mockModels = {
+				"openrouter/model": {
+					maxTokens: 8192,
+					contextWindow: 128000,
+					supportsPromptCache: false,
+					description: "OpenRouter model",
+				},
+			}
+
+			// Create a delayed response to simulate API latency
+			let resolvePromise: (value: typeof mockModels) => void
+			const delayedPromise = new Promise<typeof mockModels>((resolve) => {
+				resolvePromise = resolve
+			})
+			mockGetOpenRouterModels.mockReturnValue(delayedPromise)
+			mockGet.mockReturnValue(undefined)
+
+			const { refreshModels } = await import("../modelCache")
+
+			// Start two concurrent refresh calls
+			const promise1 = refreshModels({ provider: "openrouter" })
+			const promise2 = refreshModels({ provider: "openrouter" })
+
+			// API should only be called once (second call reuses in-flight request)
+			expect(mockGetOpenRouterModels).toHaveBeenCalledTimes(1)
+
+			// Resolve the API call
+			resolvePromise!(mockModels)
+
+			// Both promises should resolve to the same result
+			const [result1, result2] = await Promise.all([promise1, promise2])
+			expect(result1).toEqual(mockModels)
+			expect(result2).toEqual(mockModels)
+		})
+	})
+})

+ 80 - 30
src/api/providers/fetchers/modelCache.ts

@@ -6,7 +6,8 @@ import NodeCache from "node-cache"
 import { z } from "zod"
 
 import type { ProviderName } from "@roo-code/types"
-import { modelInfoSchema } from "@roo-code/types"
+import { modelInfoSchema, TelemetryEventName } from "@roo-code/types"
+import { TelemetryService } from "@roo-code/telemetry"
 
 import { safeWriteJson } from "../../../utils/safeWriteJson"
 
@@ -35,6 +36,10 @@ const memoryCache = new NodeCache({ stdTTL: 5 * 60, checkperiod: 5 * 60 })
 // Zod schema for validating ModelRecord structure from disk cache
 const modelRecordSchema = z.record(z.string(), modelInfoSchema)
 
+// Track in-flight refresh requests to prevent concurrent API calls for the same provider
+// This prevents race conditions where multiple calls might overwrite each other's results
+const inFlightRefresh = new Map<RouterName, Promise<ModelRecord>>()
+
 async function writeModels(router: RouterName, data: ModelRecord) {
 	const filename = `${router}_models.json`
 	const cacheDir = await getCacheDirectoryPath(ContextProxy.instance.globalStorageUri.fsPath)
@@ -139,20 +144,25 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
 
 
 	try {
 	try {
 		models = await fetchModelsFromProvider(options)
 		models = await fetchModelsFromProvider(options)
-
-		// Cache the fetched models (even if empty, to signify a successful fetch with no models).
-		memoryCache.set(provider, models)
-
-		await writeModels(provider, models).catch((err) =>
-			console.error(`[MODEL_CACHE] Error writing ${provider} models to file cache:`, err),
-		)
-
-		try {
-			models = await readModels(provider)
-		} catch (error) {
-			console.error(`[getModels] error reading ${provider} models from file cache`, error)
+		const modelCount = Object.keys(models).length
+
+		// Only cache non-empty results to prevent persisting failed API responses
+		// Empty results could indicate API failure rather than "no models exist"
+		if (modelCount > 0) {
+			memoryCache.set(provider, models)
+
+			await writeModels(provider, models).catch((err) =>
+				console.error(`[MODEL_CACHE] Error writing ${provider} models to file cache:`, err),
+			)
+		} else {
+			TelemetryService.instance.captureEvent(TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE, {
+				provider,
+				context: "getModels",
+				hasExistingCache: false,
+			})
 		}
 		}
-		return models || {}
+
+		return models
 	} catch (error) {
 	} catch (error) {
 		// Log the error and re-throw it so the caller can handle it (e.g., show a UI message).
 		// Log the error and re-throw it so the caller can handle it (e.g., show a UI message).
 		console.error(`[getModels] Failed to fetch models in modelCache for ${provider}:`, error)
 		console.error(`[getModels] Failed to fetch models in modelCache for ${provider}:`, error)
@@ -164,31 +174,71 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord>
 /**
  * Force-refresh models from API, bypassing cache.
  * Uses atomic writes so cache remains available during refresh.
+ * This function also prevents concurrent API calls for the same provider using
+ * in-flight request tracking to avoid race conditions.
  *
  * @param options - Provider options for fetching models
- * @returns Fresh models from API
+ * @returns Fresh models from API, or existing cache if refresh yields worse data
  */
 export const refreshModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
 	const { provider } = options
 
-	try {
-		// Force fresh API fetch - skip getModelsFromCache() check
-		const models = await fetchModelsFromProvider(options)
+	// Check if there's already an in-flight refresh for this provider
+	// This prevents race conditions where multiple concurrent refreshes might
+	// overwrite each other's results
+	const existingRequest = inFlightRefresh.get(provider)
+	if (existingRequest) {
+		return existingRequest
+	}
 
-		// Update memory cache first
-		memoryCache.set(provider, models)
+	// Create the refresh promise and track it
+	const refreshPromise = (async (): Promise<ModelRecord> => {
+		try {
+			// Force fresh API fetch - skip getModelsFromCache() check
+			const models = await fetchModelsFromProvider(options)
+			const modelCount = Object.keys(models).length
+
+			// Get existing cached data for comparison
+			const existingCache = getModelsFromCache(provider)
+			const existingCount = existingCache ? Object.keys(existingCache).length : 0
+
+			if (modelCount === 0) {
+				TelemetryService.instance.captureEvent(TelemetryEventName.MODEL_CACHE_EMPTY_RESPONSE, {
+					provider,
+					context: "refreshModels",
+					hasExistingCache: existingCount > 0,
+					existingCacheSize: existingCount,
+				})
+				if (existingCount > 0) {
+					return existingCache!
+				} else {
+					return {}
+				}
+			}
 
-		// Atomically write to disk (safeWriteJson handles atomic writes)
-		await writeModels(provider, models).catch((err) =>
-			console.error(`[refreshModels] Error writing ${provider} models to disk:`, err),
-		)
+			// Update memory cache first
+			memoryCache.set(provider, models)
 
-		return models
-	} catch (error) {
-		console.debug(`[refreshModels] Failed to refresh ${provider}:`, error)
-		// On error, return existing cache if available (graceful degradation)
-		return getModelsFromCache(provider) || {}
-	}
+			// Atomically write to disk (safeWriteJson handles atomic writes)
+			await writeModels(provider, models).catch((err) =>
+				console.error(`[refreshModels] Error writing ${provider} models to disk:`, err),
+			)
+
+			return models
+		} catch (error) {
+			// Log the error for debugging, then return existing cache if available (graceful degradation)
+			console.error(`[refreshModels] Failed to refresh ${provider} models:`, error)
+			return getModelsFromCache(provider) || {}
+		} finally {
+			// Always clean up the in-flight tracking
+			inFlightRefresh.delete(provider)
+		}
+	})()
+
+	// Track the in-flight request
+	inFlightRefresh.set(provider, refreshPromise)
+
+	return refreshPromise
 }
 
 /**