Browse Source

Merge pull request #1705 from RooVetGit/cte/context-proxy-fixes

Fix API configuration profile switching
Matt Rubens 11 months ago
parent
commit
7bfdcbe64e

+ 61 - 0
src/core/__tests__/contextProxy.test.ts

@@ -250,6 +250,67 @@ describe("ContextProxy", () => {
 		})
 	})
 
+	describe("setApiConfiguration", () => {
+		it("should clear old API configuration values and set new ones", async () => {
+			// Set up initial API configuration values
+			await proxy.updateGlobalState("apiModelId", "old-model")
+			await proxy.updateGlobalState("openAiBaseUrl", "https://old-url.com")
+			await proxy.updateGlobalState("modelTemperature", 0.7)
+
+			// Spy on setValues
+			const setValuesSpy = jest.spyOn(proxy, "setValues")
+
+			// Call setApiConfiguration with new configuration
+			await proxy.setApiConfiguration({
+				apiModelId: "new-model",
+				apiProvider: "anthropic",
+				// Note: openAiBaseUrl is not included in the new config
+			})
+
+			// Verify setValues was called with the correct parameters
+			// It should include undefined for openAiBaseUrl (to clear it)
+			// and the new values for apiModelId and apiProvider
+			expect(setValuesSpy).toHaveBeenCalledWith(
+				expect.objectContaining({
+					apiModelId: "new-model",
+					apiProvider: "anthropic",
+					openAiBaseUrl: undefined,
+					modelTemperature: undefined,
+				}),
+			)
+
+			// Verify the state cache has been updated correctly
+			expect(proxy.getGlobalState("apiModelId")).toBe("new-model")
+			expect(proxy.getGlobalState("apiProvider")).toBe("anthropic")
+			expect(proxy.getGlobalState("openAiBaseUrl")).toBeUndefined()
+			expect(proxy.getGlobalState("modelTemperature")).toBeUndefined()
+		})
+
+		it("should handle empty API configuration", async () => {
+			// Set up initial API configuration values
+			await proxy.updateGlobalState("apiModelId", "old-model")
+			await proxy.updateGlobalState("openAiBaseUrl", "https://old-url.com")
+
+			// Spy on setValues
+			const setValuesSpy = jest.spyOn(proxy, "setValues")
+
+			// Call setApiConfiguration with empty configuration
+			await proxy.setApiConfiguration({})
+
+			// Verify setValues was called with undefined for all existing API config keys
+			expect(setValuesSpy).toHaveBeenCalledWith(
+				expect.objectContaining({
+					apiModelId: undefined,
+					openAiBaseUrl: undefined,
+				}),
+			)
+
+			// Verify the state cache has been cleared
+			expect(proxy.getGlobalState("apiModelId")).toBeUndefined()
+			expect(proxy.getGlobalState("openAiBaseUrl")).toBeUndefined()
+		})
+	})
+
 	describe("resetAllState", () => {
 		it("should clear all in-memory caches", async () => {
 			// Setup initial state in caches

+ 17 - 0
src/core/contextProxy.ts

@@ -11,6 +11,7 @@ import {
 	isSecretKey,
 	isGlobalStateKey,
 } from "../shared/globalState"
+import { API_CONFIG_KEYS, ApiConfiguration } from "../shared/api"
 
 export class ContextProxy {
 	private readonly originalContext: vscode.ExtensionContext
@@ -101,6 +102,7 @@ export class ContextProxy {
 			? this.originalContext.secrets.delete(key)
 			: this.originalContext.secrets.store(key, value)
 	}
+
 	/**
 	 * Set a value in either secrets or global state based on key type.
 	 * If the key is in SECRET_KEYS, it will be stored as a secret.
@@ -136,6 +138,21 @@ export class ContextProxy {
 		await Promise.all(promises)
 	}
 
+	async setApiConfiguration(apiConfiguration: ApiConfiguration) {
+		// Explicitly clear out any old API configuration values before that
+		// might not be present in the new configuration.
+		// If a value is not present in the new configuration, then it is assumed
+		// that the setting's value should be `undefined` and therefore we
+		// need to remove it from the state cache if it exists.
+		await this.setValues({
+			...API_CONFIG_KEYS.filter((key) => !!this.stateCache.get(key)).reduce(
+				(acc, key) => ({ ...acc, [key]: undefined }),
+				{} as Partial<ConfigurationValues>,
+			),
+			...apiConfiguration,
+		})
+	}
+
 	/**
 	 * Resets all global state, secrets, and in-memory caches.
 	 * This clears all data from both the in-memory caches and the VSCode storage.

+ 1 - 1
src/core/webview/ClineProvider.ts

@@ -1984,7 +1984,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			}
 		}
 
-		await this.contextProxy.setValues(apiConfiguration)
+		await this.contextProxy.setApiConfiguration(apiConfiguration)
 
 		if (this.getCurrentCline()) {
 			this.getCurrentCline()!.api = buildApiHandler(apiConfiguration)

+ 1 - 0
src/exports/api.ts

@@ -68,6 +68,7 @@ export class API extends EventEmitter<RooCodeEvents> implements RooCodeAPI {
 		await this.provider.postMessageToWebview({ type: "invoke", invoke: "secondaryButtonClick" })
 	}
 
+	// TODO: Change this to `setApiConfiguration`.
 	public async setConfiguration(values: Partial<ConfigurationValues>) {
 		await this.provider.setValues(values)
 	}

+ 3 - 1
src/shared/api.ts

@@ -85,7 +85,9 @@ export type ApiConfiguration = ApiHandlerOptions & {
 // Import GlobalStateKey type from globalState.ts
 import { GlobalStateKey } from "./globalState"
 
-// Define API configuration keys for dynamic object building
+// Define API configuration keys for dynamic object building.
+// TODO: This needs actual type safety; a type error should be thrown if
+// this is not an exhaustive list of all `GlobalStateKey` values.
 export const API_CONFIG_KEYS: GlobalStateKey[] = [
 	"apiModelId",
 	"anthropicBaseUrl",

+ 4 - 3
webview-ui/src/context/ExtensionStateContext.tsx

@@ -81,7 +81,6 @@ export const ExtensionStateContext = createContext<ExtensionStateContextType | u
 
 export const mergeExtensionState = (prevState: ExtensionState, newState: ExtensionState) => {
 	const {
-		apiConfiguration: prevApiConfiguration,
 		customModePrompts: prevCustomModePrompts,
 		customSupportPrompts: prevCustomSupportPrompts,
 		experiments: prevExperiments,
@@ -89,19 +88,21 @@ export const mergeExtensionState = (prevState: ExtensionState, newState: Extensi
 	} = prevState
 
 	const {
-		apiConfiguration: newApiConfiguration,
+		apiConfiguration,
 		customModePrompts: newCustomModePrompts,
 		customSupportPrompts: newCustomSupportPrompts,
 		experiments: newExperiments,
 		...newRest
 	} = newState
 
-	const apiConfiguration = { ...prevApiConfiguration, ...newApiConfiguration }
 	const customModePrompts = { ...prevCustomModePrompts, ...newCustomModePrompts }
 	const customSupportPrompts = { ...prevCustomSupportPrompts, ...newCustomSupportPrompts }
 	const experiments = { ...prevExperiments, ...newExperiments }
 	const rest = { ...prevRest, ...newRest }
 
+	// Note that we completely replace the previous apiConfiguration object with
+	// a new one since the state that is broadcast is the entire apiConfiguration
+	// and therefore merging is not necessary.
 	return { ...rest, apiConfiguration, customModePrompts, customSupportPrompts, experiments }
 }
 

+ 100 - 2
webview-ui/src/context/__tests__/ExtensionStateContext.test.tsx

@@ -1,4 +1,4 @@
-// npx jest webview-ui/src/context/__tests__/ExtensionStateContext.test.tsx
+// cd webview-ui && npx jest src/context/__tests__/ExtensionStateContext.test.tsx
 
 import { render, screen, act } from "@testing-library/react"
 
@@ -26,6 +26,24 @@ const TestComponent = () => {
 	)
 }
 
+// Test component for API configuration
+const ApiConfigTestComponent = () => {
+	const { apiConfiguration, setApiConfiguration } = useExtensionState()
+	return (
+		<div>
+			<div data-testid="api-configuration">{JSON.stringify(apiConfiguration)}</div>
+			<button
+				data-testid="update-api-config-button"
+				onClick={() => setApiConfiguration({ apiModelId: "new-model", apiProvider: "anthropic" })}>
+				Update API Config
+			</button>
+			<button data-testid="partial-update-button" onClick={() => setApiConfiguration({ modelTemperature: 0.7 })}>
+				Partial Update
+			</button>
+		</div>
+	)
+}
+
 describe("ExtensionStateContext", () => {
 	it("initializes with empty allowedCommands array", () => {
 		render(
@@ -96,6 +114,70 @@ describe("ExtensionStateContext", () => {
 
 		consoleSpy.mockRestore()
 	})
+
+	it("updates apiConfiguration through setApiConfiguration", () => {
+		render(
+			<ExtensionStateContextProvider>
+				<ApiConfigTestComponent />
+			</ExtensionStateContextProvider>,
+		)
+
+		const initialContent = screen.getByTestId("api-configuration").textContent!
+		expect(initialContent).toBeDefined()
+
+		act(() => {
+			screen.getByTestId("update-api-config-button").click()
+		})
+
+		const updatedContent = screen.getByTestId("api-configuration").textContent!
+		const updatedConfig = JSON.parse(updatedContent || "{}")
+
+		expect(updatedConfig).toEqual(
+			expect.objectContaining({
+				apiModelId: "new-model",
+				apiProvider: "anthropic",
+			}),
+		)
+	})
+
+	it("correctly merges partial updates to apiConfiguration", () => {
+		render(
+			<ExtensionStateContextProvider>
+				<ApiConfigTestComponent />
+			</ExtensionStateContextProvider>,
+		)
+
+		// First set the initial configuration
+		act(() => {
+			screen.getByTestId("update-api-config-button").click()
+		})
+
+		// Verify initial update
+		const initialContent = screen.getByTestId("api-configuration").textContent!
+		const initialConfig = JSON.parse(initialContent || "{}")
+		expect(initialConfig).toEqual(
+			expect.objectContaining({
+				apiModelId: "new-model",
+				apiProvider: "anthropic",
+			}),
+		)
+
+		// Now perform a partial update
+		act(() => {
+			screen.getByTestId("partial-update-button").click()
+		})
+
+		// Verify that the partial update was merged with the existing configuration
+		const updatedContent = screen.getByTestId("api-configuration").textContent!
+		const updatedConfig = JSON.parse(updatedContent || "{}")
+		expect(updatedConfig).toEqual(
+			expect.objectContaining({
+				apiModelId: "new-model", // Should retain this from previous update
+				apiProvider: "anthropic", // Should retain this from previous update
+				modelTemperature: 0.7, // Should add this from partial update
+			}),
+		)
+	})
 })
 
 describe("mergeExtensionState", () => {
@@ -125,19 +207,35 @@ describe("mergeExtensionState", () => {
 		const prevState: ExtensionState = {
 			...baseState,
 			apiConfiguration: { modelMaxTokens: 1234, modelMaxThinkingTokens: 123 },
+			experiments: {
+				experimentalDiffStrategy: true,
+				search_and_replace: true,
+				insert_content: true,
+			} as Record<ExperimentId, boolean>,
 		}
 
 		const newState: ExtensionState = {
 			...baseState,
 			apiConfiguration: { modelMaxThinkingTokens: 456, modelTemperature: 0.3 },
+			experiments: {
+				powerSteering: true,
+				multi_search_and_replace: true,
+			} as Record<ExperimentId, boolean>,
 		}
 
 		const result = mergeExtensionState(prevState, newState)
 
 		expect(result.apiConfiguration).toEqual({
-			modelMaxTokens: 1234,
 			modelMaxThinkingTokens: 456,
 			modelTemperature: 0.3,
 		})
+
+		expect(result.experiments).toEqual({
+			experimentalDiffStrategy: true,
+			search_and_replace: true,
+			insert_content: true,
+			powerSteering: true,
+			multi_search_and_replace: true,
+		})
 	})
 })