
Stop leaking other provider settings (#3357)

* Stop leaking other provider settings

* Also filter out leaked properties on export
John Richmond 8 months ago
parent commit a79d18739a

+ 18 - 3
src/core/config/ProviderSettingsManager.ts

@@ -1,11 +1,14 @@
 import { ExtensionContext } from "vscode"
 import { z, ZodError } from "zod"
 
-import { providerSettingsSchema, ApiConfigMeta } from "../../schemas"
+import { providerSettingsSchema, ApiConfigMeta, providerSettingsSchemaDiscriminated } from "../../schemas"
 import { Mode, modes } from "../../shared/modes"
 import { telemetryService } from "../../services/telemetry/TelemetryService"
 
 const providerSettingsWithIdSchema = providerSettingsSchema.extend({ id: z.string().optional() })
+const discriminatedProviderSettingsWithIdSchema = providerSettingsSchemaDiscriminated.and(
+	z.object({ id: z.string().optional() }),
+)
 
 type ProviderSettingsWithId = z.infer<typeof providerSettingsWithIdSchema>
 
@@ -250,7 +253,11 @@ export class ProviderSettingsManager {
 				const providerProfiles = await this.load()
 				// Preserve the existing ID if this is an update to an existing config.
 				const existingId = providerProfiles.apiConfigs[name]?.id
-				providerProfiles.apiConfigs[name] = { ...config, id: config.id || existingId || this.generateId() }
+				const id = config.id || existingId || this.generateId()
+
+				// Filter out settings from other providers.
+				const filteredConfig = providerSettingsSchemaDiscriminated.parse(config)
+				providerProfiles.apiConfigs[name] = { ...filteredConfig, id }
 				await this.store(providerProfiles)
 			})
 		} catch (error) {
@@ -381,7 +388,15 @@ export class ProviderSettingsManager {
 
 	public async export() {
 		try {
-			return await this.lock(async () => providerProfilesSchema.parse(await this.load()))
+			return await this.lock(async () => {
+				const profiles = providerProfilesSchema.parse(await this.load())
+				const configs = profiles.apiConfigs
+				for (const name in configs) {
+					// Avoid leaking properties from other providers.
+					configs[name] = discriminatedProviderSettingsWithIdSchema.parse(configs[name])
+				}
+				return profiles
+			})
 		} catch (error) {
 			throw new Error(`Failed to export provider profiles: ${error}`)
 		}
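
The save path relies on Zod object schemas stripping unknown keys by default: parsing a config through providerSettingsSchemaDiscriminated keeps only the fields declared for the selected apiProvider branch (plus the generic settings carried by the intersection), so keys belonging to other providers are dropped before the profile is stored. A minimal sketch of that effect, using identifiers from this commit; the key values are illustrative:

	import { providerSettingsSchemaDiscriminated } from "../../schemas"

	// A config that accidentally carries an OpenRouter key alongside Anthropic settings.
	const leaky = {
		apiProvider: "anthropic" as const,
		apiKey: "anthropic-key", // declared on the "anthropic" branch -> kept
		openRouterApiKey: "leaked-key", // belongs to the "openrouter" branch -> stripped
		modelTemperature: 0.2, // generic setting -> kept via the intersection
	}

	const filtered = providerSettingsSchemaDiscriminated.parse(leaky)
	// filtered: { apiProvider: "anthropic", apiKey: "anthropic-key", modelTemperature: 0.2 }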

+ 54 - 7
src/core/config/__tests__/ProviderSettingsManager.test.ts

@@ -247,10 +247,58 @@ describe("ProviderSettingsManager", () => {
 				},
 			}
 
-			expect(mockSecrets.store).toHaveBeenCalledWith(
-				"roo_cline_config_api_config",
-				JSON.stringify(expectedConfig, null, 2),
+			expect(mockSecrets.store.mock.calls[0][0]).toEqual("roo_cline_config_api_config")
+			expect(storedConfig).toEqual(expectedConfig)
+		})
+
+		it("should only save provider relevant settings", async () => {
+			mockSecrets.get.mockResolvedValue(
+				JSON.stringify({
+					currentApiConfigName: "default",
+					apiConfigs: {
+						default: {},
+					},
+					modeApiConfigs: {
+						code: "default",
+						architect: "default",
+						ask: "default",
+					},
+				}),
 			)
+
+			const newConfig: ProviderSettings = {
+				apiProvider: "anthropic",
+				apiKey: "test-key",
+			}
+			const newConfigWithExtra: ProviderSettings = {
+				...newConfig,
+				openRouterApiKey: "another-key",
+			}
+
+			await providerSettingsManager.saveConfig("test", newConfigWithExtra)
+
+			// Get the actual stored config to check the generated ID
+			const storedConfig = JSON.parse(mockSecrets.store.mock.lastCall[1])
+			const testConfigId = storedConfig.apiConfigs.test.id
+
+			const expectedConfig = {
+				currentApiConfigName: "default",
+				apiConfigs: {
+					default: {},
+					test: {
+						...newConfig,
+						id: testConfigId,
+					},
+				},
+				modeApiConfigs: {
+					code: "default",
+					architect: "default",
+					ask: "default",
+				},
+			}
+
+			expect(mockSecrets.store.mock.calls[0][0]).toEqual("roo_cline_config_api_config")
+			expect(storedConfig).toEqual(expectedConfig)
 		})
 
 		it("should update existing config", async () => {
@@ -291,10 +339,9 @@ describe("ProviderSettingsManager", () => {
 				},
 			}
 
-			expect(mockSecrets.store).toHaveBeenCalledWith(
-				"roo_cline_config_api_config",
-				JSON.stringify(expectedConfig, null, 2),
-			)
+			const storedConfig = JSON.parse(mockSecrets.store.mock.lastCall[1])
+			expect(mockSecrets.store.mock.lastCall[0]).toEqual("roo_cline_config_api_config")
+			expect(storedConfig).toEqual(expectedConfig)
 		})
 
 		it("should throw error if secrets storage fails", async () => {

+ 3 - 3
src/exports/roo-code.d.ts

@@ -121,14 +121,13 @@ type ProviderSettings = {
 	unboundModelId?: string | undefined
 	requestyApiKey?: string | undefined
 	requestyModelId?: string | undefined
+	fakeAi?: unknown | undefined
 	xaiApiKey?: string | undefined
 	groqApiKey?: string | undefined
 	chutesApiKey?: string | undefined
 	litellmBaseUrl?: string | undefined
 	litellmApiKey?: string | undefined
 	litellmModelId?: string | undefined
-	modelMaxTokens?: number | undefined
-	modelMaxThinkingTokens?: number | undefined
 	includeMaxTokens?: boolean | undefined
 	reasoningEffort?: ("low" | "medium" | "high") | undefined
 	promptCachingDisabled?: boolean | undefined
@@ -136,7 +135,8 @@ type ProviderSettings = {
 	fuzzyMatchThreshold?: number | undefined
 	modelTemperature?: (number | null) | undefined
 	rateLimitSeconds?: number | undefined
-	fakeAi?: unknown | undefined
+	modelMaxTokens?: number | undefined
+	modelMaxThinkingTokens?: number | undefined
 }
 
 type GlobalSettings = {

+ 3 - 3
src/exports/types.ts

@@ -122,14 +122,13 @@ type ProviderSettings = {
 	unboundModelId?: string | undefined
 	requestyApiKey?: string | undefined
 	requestyModelId?: string | undefined
+	fakeAi?: unknown | undefined
 	xaiApiKey?: string | undefined
 	groqApiKey?: string | undefined
 	chutesApiKey?: string | undefined
 	litellmBaseUrl?: string | undefined
 	litellmApiKey?: string | undefined
 	litellmModelId?: string | undefined
-	modelMaxTokens?: number | undefined
-	modelMaxThinkingTokens?: number | undefined
 	includeMaxTokens?: boolean | undefined
 	reasoningEffort?: ("low" | "medium" | "high") | undefined
 	promptCachingDisabled?: boolean | undefined
@@ -137,7 +136,8 @@ type ProviderSettings = {
 	fuzzyMatchThreshold?: number | undefined
 	modelTemperature?: (number | null) | undefined
 	rateLimitSeconds?: number | undefined
-	fakeAi?: unknown | undefined
+	modelMaxTokens?: number | undefined
+	modelMaxThinkingTokens?: number | undefined
 }
 
 export type { ProviderSettings }

+ 220 - 34
src/schemas/index.ts

@@ -345,23 +345,42 @@ type _AssertExperiments = AssertEqual<Equals<ExperimentId, Keys<Experiments>>>
  * ProviderSettings
  */
 
-export const providerSettingsSchema = z.object({
-	apiProvider: providerNamesSchema.optional(),
-	// Anthropic
+// Generic settings that apply to all providers
+const genericProviderSettingsSchema = z.object({
+	includeMaxTokens: z.boolean().optional(),
+	reasoningEffort: reasoningEffortsSchema.optional(),
+	promptCachingDisabled: z.boolean().optional(),
+	diffEnabled: z.boolean().optional(),
+	fuzzyMatchThreshold: z.number().optional(),
+	modelTemperature: z.number().nullish(),
+	rateLimitSeconds: z.number().optional(),
+	// Claude 3.7 Sonnet Thinking
+	modelMaxTokens: z.number().optional(),
+	modelMaxThinkingTokens: z.number().optional(),
+})
+
+// Provider-specific schemas
+const anthropicSchema = z.object({
 	apiModelId: z.string().optional(),
 	apiKey: z.string().optional(),
 	anthropicBaseUrl: z.string().optional(),
 	anthropicUseAuthToken: z.boolean().optional(),
-	// Glama
+})
+
+const glamaSchema = z.object({
 	glamaModelId: z.string().optional(),
 	glamaApiKey: z.string().optional(),
-	// OpenRouter
+})
+
+const openRouterSchema = z.object({
 	openRouterApiKey: z.string().optional(),
 	openRouterModelId: z.string().optional(),
 	openRouterBaseUrl: z.string().optional(),
 	openRouterSpecificProvider: z.string().optional(),
 	openRouterUseMiddleOutTransform: z.boolean().optional(),
-	// Amazon Bedrock
+})
+
+const bedrockSchema = z.object({
 	awsAccessKey: z.string().optional(),
 	awsSecretKey: z.string().optional(),
 	awsSessionToken: z.string().optional(),
@@ -371,12 +390,16 @@ export const providerSettingsSchema = z.object({
 	awsProfile: z.string().optional(),
 	awsUseProfile: z.boolean().optional(),
 	awsCustomArn: z.string().optional(),
-	// Google Vertex
+})
+
+const vertexSchema = z.object({
 	vertexKeyFile: z.string().optional(),
 	vertexJsonCredentials: z.string().optional(),
 	vertexProjectId: z.string().optional(),
 	vertexRegion: z.string().optional(),
-	// OpenAI
+})
+
+const openAiSchema = z.object({
 	openAiBaseUrl: z.string().optional(),
 	openAiApiKey: z.string().optional(),
 	openAiLegacyFormat: z.boolean().optional(),
@@ -389,10 +412,14 @@ export const providerSettingsSchema = z.object({
 	enableReasoningEffort: z.boolean().optional(),
 	openAiHostHeader: z.string().optional(), // Keep temporarily for backward compatibility during migration
 	openAiHeaders: z.record(z.string(), z.string()).optional(),
-	// Ollama
+})
+
+const ollamaSchema = z.object({
 	ollamaModelId: z.string().optional(),
 	ollamaBaseUrl: z.string().optional(),
-	// VS Code LM
+})
+
+const vsCodeLmSchema = z.object({
 	vsCodeLmModelSelector: z
 		.object({
 			vendor: z.string().optional(),
@@ -401,54 +428,213 @@ export const providerSettingsSchema = z.object({
 			id: z.string().optional(),
 		})
 		.optional(),
-	// LM Studio
+})
+
+const lmStudioSchema = z.object({
 	lmStudioModelId: z.string().optional(),
 	lmStudioBaseUrl: z.string().optional(),
 	lmStudioDraftModelId: z.string().optional(),
 	lmStudioSpeculativeDecodingEnabled: z.boolean().optional(),
-	// Gemini
+})
+
+const geminiSchema = z.object({
 	geminiApiKey: z.string().optional(),
 	googleGeminiBaseUrl: z.string().optional(),
-	// OpenAI Native
+})
+
+const openAiNativeSchema = z.object({
 	openAiNativeApiKey: z.string().optional(),
 	openAiNativeBaseUrl: z.string().optional(),
-	// Mistral
+})
+
+const mistralSchema = z.object({
 	mistralApiKey: z.string().optional(),
 	mistralCodestralUrl: z.string().optional(),
-	// DeepSeek
+})
+
+const deepSeekSchema = z.object({
 	deepSeekBaseUrl: z.string().optional(),
 	deepSeekApiKey: z.string().optional(),
-	// Unbound
+})
+
+const unboundSchema = z.object({
 	unboundApiKey: z.string().optional(),
 	unboundModelId: z.string().optional(),
-	// Requesty
+})
+
+const requestySchema = z.object({
 	requestyApiKey: z.string().optional(),
 	requestyModelId: z.string().optional(),
-	// X.AI (Grok)
+})
+
+const humanRelaySchema = z.object({})
+
+const fakeAiSchema = z.object({
+	fakeAi: z.unknown().optional(),
+})
+
+const xaiSchema = z.object({
 	xaiApiKey: z.string().optional(),
-	// Groq
+})
+
+const groqSchema = z.object({
 	groqApiKey: z.string().optional(),
-	// Chutes AI
+})
+
+const chutesSchema = z.object({
 	chutesApiKey: z.string().optional(),
-	// LiteLLM
+})
+
+const litellmSchema = z.object({
 	litellmBaseUrl: z.string().optional(),
 	litellmApiKey: z.string().optional(),
 	litellmModelId: z.string().optional(),
-	// Claude 3.7 Sonnet Thinking
-	modelMaxTokens: z.number().optional(),
-	modelMaxThinkingTokens: z.number().optional(),
-	// Generic
-	includeMaxTokens: z.boolean().optional(),
-	reasoningEffort: reasoningEffortsSchema.optional(),
-	promptCachingDisabled: z.boolean().optional(),
-	diffEnabled: z.boolean().optional(),
-	fuzzyMatchThreshold: z.number().optional(),
-	modelTemperature: z.number().nullish(),
-	rateLimitSeconds: z.number().optional(),
-	// Fake AI
-	fakeAi: z.unknown().optional(),
 })
 
+// Default schema for when apiProvider is not specified
+const defaultSchema = z.object({
+	apiProvider: z.undefined(),
+})
+
+// Create the discriminated union
+export const providerSettingsSchemaDiscriminated = z
+	.discriminatedUnion("apiProvider", [
+		anthropicSchema.merge(
+			z.object({
+				apiProvider: z.literal("anthropic"),
+			}),
+		),
+		glamaSchema.merge(
+			z.object({
+				apiProvider: z.literal("glama"),
+			}),
+		),
+		openRouterSchema.merge(
+			z.object({
+				apiProvider: z.literal("openrouter"),
+			}),
+		),
+		bedrockSchema.merge(
+			z.object({
+				apiProvider: z.literal("bedrock"),
+			}),
+		),
+		vertexSchema.merge(
+			z.object({
+				apiProvider: z.literal("vertex"),
+			}),
+		),
+		openAiSchema.merge(
+			z.object({
+				apiProvider: z.literal("openai"),
+			}),
+		),
+		ollamaSchema.merge(
+			z.object({
+				apiProvider: z.literal("ollama"),
+			}),
+		),
+		vsCodeLmSchema.merge(
+			z.object({
+				apiProvider: z.literal("vscode-lm"),
+			}),
+		),
+		lmStudioSchema.merge(
+			z.object({
+				apiProvider: z.literal("lmstudio"),
+			}),
+		),
+		geminiSchema.merge(
+			z.object({
+				apiProvider: z.literal("gemini"),
+			}),
+		),
+		openAiNativeSchema.merge(
+			z.object({
+				apiProvider: z.literal("openai-native"),
+			}),
+		),
+		mistralSchema.merge(
+			z.object({
+				apiProvider: z.literal("mistral"),
+			}),
+		),
+		deepSeekSchema.merge(
+			z.object({
+				apiProvider: z.literal("deepseek"),
+			}),
+		),
+		unboundSchema.merge(
+			z.object({
+				apiProvider: z.literal("unbound"),
+			}),
+		),
+		requestySchema.merge(
+			z.object({
+				apiProvider: z.literal("requesty"),
+			}),
+		),
+		humanRelaySchema.merge(
+			z.object({
+				apiProvider: z.literal("human-relay"),
+			}),
+		),
+		fakeAiSchema.merge(
+			z.object({
+				apiProvider: z.literal("fake-ai"),
+			}),
+		),
+		xaiSchema.merge(
+			z.object({
+				apiProvider: z.literal("xai"),
+			}),
+		),
+		groqSchema.merge(
+			z.object({
+				apiProvider: z.literal("groq"),
+			}),
+		),
+		chutesSchema.merge(
+			z.object({
+				apiProvider: z.literal("chutes"),
+			}),
+		),
+		litellmSchema.merge(
+			z.object({
+				apiProvider: z.literal("litellm"),
+			}),
+		),
+		defaultSchema,
+	])
+	.and(genericProviderSettingsSchema)
+
+export const providerSettingsSchema = z
+	.object({
+		apiProvider: providerNamesSchema.optional(),
+	})
+	.merge(anthropicSchema)
+	.merge(glamaSchema)
+	.merge(openRouterSchema)
+	.merge(bedrockSchema)
+	.merge(vertexSchema)
+	.merge(openAiSchema)
+	.merge(ollamaSchema)
+	.merge(vsCodeLmSchema)
+	.merge(lmStudioSchema)
+	.merge(geminiSchema)
+	.merge(openAiNativeSchema)
+	.merge(mistralSchema)
+	.merge(deepSeekSchema)
+	.merge(unboundSchema)
+	.merge(requestySchema)
+	.merge(humanRelaySchema)
+	.merge(fakeAiSchema)
+	.merge(xaiSchema)
+	.merge(groqSchema)
+	.merge(chutesSchema)
+	.merge(litellmSchema)
+	.merge(genericProviderSettingsSchema)
+
 export type ProviderSettings = z.infer<typeof providerSettingsSchema>
 
 type ProviderSettingsRecord = Record<Keys<ProviderSettings>, undefined>
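
On the export path, the manager intersects the discriminated schema with a small { id } object (discriminatedProviderSettingsWithIdSchema in ProviderSettingsManager.ts above), so the generated profile id survives the same filtering. A hedged sketch of the intended behavior, reusing the schemas from this diff; the profile shape and values are illustrative:

	import { z } from "zod"
	import { providerSettingsSchemaDiscriminated } from "../../schemas"

	// Mirrors discriminatedProviderSettingsWithIdSchema from ProviderSettingsManager.ts.
	const withId = providerSettingsSchemaDiscriminated.and(z.object({ id: z.string().optional() }))

	const exported = withId.parse({
		apiProvider: "openrouter" as const,
		openRouterApiKey: "or-key", // declared on the "openrouter" branch -> kept
		glamaApiKey: "stale-key", // belongs to the "glama" branch -> stripped from the export
		id: "abc123", // preserved by the extra { id } object
	})
	// exported: { apiProvider: "openrouter", openRouterApiKey: "or-key", id: "abc123" }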