Parcourir la source

Simplify CloudService callbacks (#4097)

- Remove AuthService callbacks, add user-info event
- Store userInfo on AuthService and fetch it when sessions go active
- Simplify CloudService callbacks to just stateChanged
- Add cloudUserInfo to state and remove authenticatedUser message
John Richmond il y a 7 mois
Parent
commit
d4c00e58ed

+ 20 - 22
packages/cloud/src/AuthService.ts

@@ -6,13 +6,13 @@ import * as vscode from "vscode"
 
 import type { CloudUserInfo } from "@roo-code/types"
 
-import { CloudServiceCallbacks } from "./types"
 import { getClerkBaseUrl, getRooCodeApiUrl } from "./Config"
 import { RefreshTimer } from "./RefreshTimer"
 
 export interface AuthServiceEvents {
 	"active-session": [data: { previousState: AuthState }]
 	"logged-out": [data: { previousState: AuthState }]
+	"user-info": [data: { userInfo: CloudUserInfo }]
 }
 
 const CLIENT_TOKEN_KEY = "clerk-client-token"
@@ -23,19 +23,18 @@ type AuthState = "initializing" | "logged-out" | "active-session" | "inactive-se
 
 export class AuthService extends EventEmitter<AuthServiceEvents> {
 	private context: vscode.ExtensionContext
-	private userChanged: CloudServiceCallbacks["userChanged"]
 	private timer: RefreshTimer
 	private state: AuthState = "initializing"
 
 	private clientToken: string | null = null
 	private sessionToken: string | null = null
 	private sessionId: string | null = null
+	private userInfo: CloudUserInfo | null = null
 
-	constructor(context: vscode.ExtensionContext, userChanged: CloudServiceCallbacks["userChanged"]) {
+	constructor(context: vscode.ExtensionContext) {
 		super()
 
 		this.context = context
-		this.userChanged = userChanged
 
 		this.timer = new RefreshTimer({
 			callback: async () => {
@@ -140,9 +139,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 			this.emit("active-session", { previousState })
 			this.timer.start()
 
-			if (this.userChanged) {
-				this.getUserInfo().then(this.userChanged)
-			}
+			this.fetchUserInfo()
 
 			vscode.window.showInformationMessage("Successfully authenticated with Roo Code Cloud")
 			console.log("[auth] Successfully authenticated with Roo Code Cloud")
@@ -174,6 +171,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 			this.clientToken = null
 			this.sessionToken = null
 			this.sessionId = null
+			this.userInfo = null
 			const previousState = this.state
 			this.state = "logged-out"
 			this.emit("logged-out", { previousState })
@@ -182,9 +180,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 				await this.clerkLogout(oldClientToken, oldSessionId)
 			}
 
-			if (this.userChanged) {
-				this.getUserInfo().then(this.userChanged)
-			}
+			this.fetchUserInfo()
 
 			vscode.window.showInformationMessage("Logged out from Roo Code Cloud")
 			console.log("[auth] Logged out from Roo Code Cloud")
@@ -224,7 +220,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 	 *
 	 * This method refreshes the session token using the client token.
 	 */
-	private async refreshSession() {
+	private async refreshSession(): Promise<void> {
 		if (!this.sessionId || !this.clientToken) {
 			console.log("[auth] Cannot refresh session: missing session ID or token")
 			this.state = "inactive-session"
@@ -237,11 +233,17 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 
 		if (previousState !== "active-session") {
 			this.emit("active-session", { previousState })
+			this.fetchUserInfo()
+		}
+	}
 
-			if (this.userChanged) {
-				this.getUserInfo().then(this.userChanged)
-			}
+	private async fetchUserInfo(): Promise<void> {
+		if (!this.clientToken) {
+			return
 		}
+
+		this.userInfo = await this.clerkMe()
+		this.emit("user-info", { userInfo: this.userInfo })
 	}
 
 	/**
@@ -249,12 +251,8 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 	 *
 	 * @returns User information from ID token claims or null if no ID token available
 	 */
-	public async getUserInfo(): Promise<CloudUserInfo | undefined> {
-		if (!this.clientToken) {
-			return undefined
-		}
-
-		return await this.clerkMe()
+	public getUserInfo(): CloudUserInfo | null {
+		return this.userInfo
 	}
 
 	private async clerkSignIn(
@@ -383,12 +381,12 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 		return this._instance
 	}
 
-	static async createInstance(context: vscode.ExtensionContext, userChanged: CloudServiceCallbacks["userChanged"]) {
+	static async createInstance(context: vscode.ExtensionContext) {
 		if (this._instance) {
 			throw new Error("AuthService instance already created")
 		}
 
-		this._instance = new AuthService(context, userChanged)
+		this._instance = new AuthService(context)
 		await this._instance.initialize()
 		return this._instance
 	}

+ 16 - 5
packages/cloud/src/CloudService.ts

@@ -13,6 +13,7 @@ export class CloudService {
 
 	private context: vscode.ExtensionContext
 	private callbacks: CloudServiceCallbacks
+	private authListener: () => void
 	private authService: AuthService | null = null
 	private settingsService: SettingsService | null = null
 	private telemetryClient: TelemetryClient | null = null
@@ -21,6 +22,9 @@ export class CloudService {
 	private constructor(context: vscode.ExtensionContext, callbacks: CloudServiceCallbacks) {
 		this.context = context
 		this.callbacks = callbacks
+		this.authListener = () => {
+			this.callbacks.stateChanged?.()
+		}
 	}
 
 	public async initialize(): Promise<void> {
@@ -29,12 +33,14 @@ export class CloudService {
 		}
 
 		try {
-			this.authService = await AuthService.createInstance(this.context, (userInfo) => {
-				this.callbacks.userChanged?.(userInfo)
-			})
+			this.authService = await AuthService.createInstance(this.context)
+
+			this.authService.on("active-session", this.authListener)
+			this.authService.on("logged-out", this.authListener)
+			this.authService.on("user-info", this.authListener)
 
 			this.settingsService = await SettingsService.createInstance(this.context, () =>
-				this.callbacks.settingsChanged?.(),
+				this.callbacks.stateChanged?.(),
 			)
 
 			this.telemetryClient = new TelemetryClient(this.authService)
@@ -74,7 +80,7 @@ export class CloudService {
 		return this.authService!.hasActiveSession()
 	}
 
-	public async getUserInfo(): Promise<CloudUserInfo | undefined> {
+	public getUserInfo(): CloudUserInfo | null {
 		this.ensureInitialized()
 		return this.authService!.getUserInfo()
 	}
@@ -106,6 +112,11 @@ export class CloudService {
 	// Lifecycle
 
 	public dispose(): void {
+		if (this.authService) {
+			this.authService.off("active-session", this.authListener)
+			this.authService.off("logged-out", this.authListener)
+			this.authService.off("user-info", this.authListener)
+		}
 		if (this.settingsService) {
 			this.settingsService.dispose()
 		}

+ 6 - 3
packages/cloud/src/__tests__/CloudService.test.ts

@@ -128,11 +128,14 @@ describe("CloudService", () => {
 
 	describe("createInstance", () => {
 		it("should create and initialize CloudService instance", async () => {
-			const callbacks = { userChanged: vi.fn(), settingsChanged: vi.fn() }
+			const callbacks = {
+				stateChanged: vi.fn(),
+			}
+
 			const cloudService = await CloudService.createInstance(mockContext, callbacks)
 
 			expect(cloudService).toBeInstanceOf(CloudService)
-			expect(AuthService.createInstance).toHaveBeenCalledWith(mockContext, expect.any(Function))
+			expect(AuthService.createInstance).toHaveBeenCalledWith(mockContext)
 			expect(SettingsService.createInstance).toHaveBeenCalledWith(mockContext, expect.any(Function))
 		})
 
@@ -150,7 +153,7 @@ describe("CloudService", () => {
 		let callbacks: CloudServiceCallbacks
 
 		beforeEach(async () => {
-			callbacks = { userChanged: vi.fn(), settingsChanged: vi.fn() }
+			callbacks = { stateChanged: vi.fn() }
 			cloudService = await CloudService.createInstance(mockContext, callbacks)
 		})
 

+ 1 - 4
packages/cloud/src/types.ts

@@ -1,6 +1,3 @@
-import { CloudUserInfo } from "@roo-code/types"
-
 export interface CloudServiceCallbacks {
-	userChanged?: (userInfo: CloudUserInfo | undefined) => void
-	settingsChanged?: () => void
+	stateChanged?: () => void
 }

+ 14 - 0
src/core/webview/ClineProvider.ts

@@ -23,6 +23,7 @@ import {
 	type TerminalActionPromptType,
 	type HistoryItem,
 	ORGANIZATION_ALLOW_ALL,
+	CloudUserInfo,
 } from "@roo-code/types"
 import { TelemetryService } from "@roo-code/telemetry"
 import { CloudService } from "@roo-code/cloud"
@@ -1296,6 +1297,7 @@ export class ClineProvider
 			maxReadFileLine,
 			terminalCompressProgressBar,
 			historyPreviewCollapsed,
+			cloudUserInfo,
 			organizationAllowList,
 			condensingApiConfigId,
 			customCondensingPrompt,
@@ -1391,6 +1393,7 @@ export class ClineProvider
 			terminalCompressProgressBar: terminalCompressProgressBar ?? true,
 			hasSystemPromptOverride,
 			historyPreviewCollapsed: historyPreviewCollapsed ?? false,
+			cloudUserInfo,
 			organizationAllowList,
 			condensingApiConfigId,
 			customCondensingPrompt,
@@ -1436,6 +1439,16 @@ export class ClineProvider
 			)
 		}
 
+		let cloudUserInfo: CloudUserInfo | null = null
+
+		try {
+			cloudUserInfo = CloudService.instance.getUserInfo()
+		} catch (error) {
+			console.error(
+				`[getState] failed to get cloud user info: ${error instanceof Error ? error.message : String(error)}`,
+			)
+		}
+
 		// Return the same structure as before
 		return {
 			apiConfiguration: providerSettings,
@@ -1504,6 +1517,7 @@ export class ClineProvider
 			showRooIgnoredFiles: stateValues.showRooIgnoredFiles ?? true,
 			maxReadFileLine: stateValues.maxReadFileLine ?? -1,
 			historyPreviewCollapsed: stateValues.historyPreviewCollapsed ?? false,
+			cloudUserInfo,
 			organizationAllowList,
 			// Explicitly add condensing settings
 			condensingApiConfigId: stateValues.condensingApiConfigId,

+ 1 - 0
src/core/webview/__tests__/ClineProvider.test.ts

@@ -423,6 +423,7 @@ describe("ClineProvider", () => {
 			showRooIgnoredFiles: true,
 			renderContext: "sidebar",
 			maxReadFileLine: 500,
+			cloudUserInfo: null,
 			organizationAllowList: ORGANIZATION_ALLOW_ALL,
 			autoCondenseContext: true,
 			autoCondenseContextPercent: 100,

+ 0 - 10
src/core/webview/webviewMessageHandler.ts

@@ -106,16 +106,6 @@ export const webviewMessageHandler = async (provider: ClineProvider, message: We
 					await Promise.all([
 						await updateGlobalState("listApiConfigMeta", listApiConfig),
 						await provider.postMessageToWebview({ type: "listApiConfig", listApiConfig }),
-						async () => {
-							try {
-								if (CloudService.instance.hasActiveSession()) {
-									const userInfo = await CloudService.instance.getUserInfo()
-									provider.postMessageToWebview({ type: "authenticatedUser", userInfo })
-								}
-							} catch (error) {
-								provider.log(`AuthService#getUserInfo failed: ${error}`)
-							}
-						},
 					])
 				})
 				.catch((error) =>

+ 1 - 3
src/extension.ts

@@ -70,9 +70,7 @@ export async function activate(context: vscode.ExtensionContext) {
 
 	// Initialize Roo Code Cloud service.
 	await CloudService.createInstance(context, {
-		userChanged: (userInfo) =>
-			ClineProvider.getVisibleInstance()?.postMessageToWebview({ type: "authenticatedUser", userInfo }),
-		settingsChanged: () => ClineProvider.getVisibleInstance()?.postStateToWebview(),
+		stateChanged: () => ClineProvider.getVisibleInstance()?.postStateToWebview(),
 	})
 
 	// Initialize i18n for internationalization support

+ 1 - 0
src/shared/ExtensionMessage.ts

@@ -217,6 +217,7 @@ export type ExtensionState = Pick<
 	settingsImportedAt?: number
 	historyPreviewCollapsed?: boolean
 
+	cloudUserInfo: CloudUserInfo | null
 	organizationAllowList: OrganizationAllowList
 
 	autoCondenseContext: boolean

+ 10 - 9
webview-ui/src/App.tsx

@@ -2,7 +2,6 @@ import { useCallback, useEffect, useRef, useState } from "react"
 import { useEvent } from "react-use"
 import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
 
-import type { CloudUserInfo } from "@roo-code/types"
 import { ExtensionMessage } from "@roo/ExtensionMessage"
 
 import TranslationProvider from "./i18n/TranslationContext"
@@ -30,12 +29,18 @@ const tabsByMessageAction: Partial<Record<NonNullable<ExtensionMessage["action"]
 }
 
 const App = () => {
-	const { didHydrateState, showWelcome, shouldShowAnnouncement, telemetrySetting, telemetryKey, machineId } =
-		useExtensionState()
+	const {
+		didHydrateState,
+		showWelcome,
+		shouldShowAnnouncement,
+		telemetrySetting,
+		telemetryKey,
+		machineId,
+		cloudUserInfo,
+	} = useExtensionState()
 
 	const [showAnnouncement, setShowAnnouncement] = useState(false)
 	const [tab, setTab] = useState<Tab>("chat")
-	const [userInfo, setUserInfo] = useState<CloudUserInfo | null>(null)
 
 	const [humanRelayDialogState, setHumanRelayDialogState] = useState<{
 		isOpen: boolean
@@ -84,10 +89,6 @@ const App = () => {
 			if (message.type === "acceptInput") {
 				chatViewRef.current?.acceptInput()
 			}
-
-			if (message.type === "authenticatedUser") {
-				setUserInfo(message.userInfo || null)
-			}
 		},
 		[switchTab],
 	)
@@ -126,7 +127,7 @@ const App = () => {
 			{tab === "settings" && (
 				<SettingsView ref={settingsRef} onDone={() => setTab("chat")} targetSection={currentSection} />
 			)}
-			{tab === "account" && <AccountView userInfo={userInfo} onDone={() => switchTab("chat")} />}
+			{tab === "account" && <AccountView userInfo={cloudUserInfo} onDone={() => switchTab("chat")} />}
 			<ChatView
 				ref={chatViewRef}
 				isHidden={tab !== "chat"}

+ 1 - 0
webview-ui/src/context/ExtensionStateContext.tsx

@@ -196,6 +196,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
 		terminalZdotdir: false, // Default ZDOTDIR handling setting
 		terminalCompressProgressBar: true, // Default to compress progress bar output
 		historyPreviewCollapsed: false, // Initialize the new state (default to expanded)
+		cloudUserInfo: null,
 		organizationAllowList: ORGANIZATION_ALLOW_ALL,
 		autoCondenseContext: true,
 		autoCondenseContextPercent: 100,

+ 1 - 0
webview-ui/src/context/__tests__/ExtensionStateContext.test.tsx

@@ -203,6 +203,7 @@ describe("mergeExtensionState", () => {
 			showRooIgnoredFiles: true,
 			renderContext: "sidebar",
 			maxReadFileLine: 500,
+			cloudUserInfo: null,
 			organizationAllowList: { allowAll: true, providers: {} },
 			autoCondenseContext: true,
 			autoCondenseContextPercent: 100,