Просмотр исходного кода

It's compliant if you're on your first attempt to fetch a session token (#4957)

John Richmond 6 месяцев назад
Родитель
Сommit
805fe7021d

+ 36 - 7
packages/cloud/src/AuthService.ts

@@ -11,6 +11,7 @@ import { RefreshTimer } from "./RefreshTimer"
 import { getUserAgent } from "./utils"
 
 export interface AuthServiceEvents {
+	"attempting-session": [data: { previousState: AuthState }]
 	"inactive-session": [data: { previousState: AuthState }]
 	"active-session": [data: { previousState: AuthState }]
 	"logged-out": [data: { previousState: AuthState }]
@@ -26,7 +27,7 @@ type AuthCredentials = z.infer<typeof authCredentialsSchema>
 
 const AUTH_STATE_KEY = "clerk-auth-state"
 
-type AuthState = "initializing" | "logged-out" | "active-session" | "inactive-session"
+type AuthState = "initializing" | "logged-out" | "active-session" | "attempting-session" | "inactive-session"
 
 const clerkSignInResponseSchema = z.object({
 	response: z.object({
@@ -93,6 +94,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 	private credentials: AuthCredentials | null = null
 	private sessionToken: string | null = null
 	private userInfo: CloudUserInfo | null = null
+	private isFirstRefreshAttempt: boolean = false
 
 	constructor(context: vscode.ExtensionContext, log?: (...args: unknown[]) => void) {
 		super()
@@ -129,7 +131,7 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 					this.credentials.clientToken !== credentials.clientToken ||
 					this.credentials.sessionId !== credentials.sessionId
 				) {
-					this.transitionToInactiveSession(credentials)
+					this.transitionToAttemptingSession(credentials)
 				}
 			} else {
 				if (this.state !== "logged-out") {
@@ -156,19 +158,32 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 		this.log("[auth] Transitioned to logged-out state")
 	}
 
-	private transitionToInactiveSession(credentials: AuthCredentials): void {
+	private transitionToAttemptingSession(credentials: AuthCredentials): void {
 		this.credentials = credentials
 
 		const previousState = this.state
-		this.state = "inactive-session"
+		this.state = "attempting-session"
 
 		this.sessionToken = null
 		this.userInfo = null
+		this.isFirstRefreshAttempt = true
 
-		this.emit("inactive-session", { previousState })
+		this.emit("attempting-session", { previousState })
 
 		this.timer.start()
 
+		this.log("[auth] Transitioned to attempting-session state")
+	}
+
+	private transitionToInactiveSession(): void {
+		const previousState = this.state
+		this.state = "inactive-session"
+
+		this.sessionToken = null
+		this.userInfo = null
+
+		this.emit("inactive-session", { previousState })
+
 		this.log("[auth] Transitioned to inactive-session state")
 	}
 
@@ -329,16 +344,27 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 	/**
 	 * Check if the user is authenticated
 	 *
-	 * @returns True if the user is authenticated (has an active or inactive session)
+	 * @returns True if the user is authenticated (has an active, attempting, or inactive session)
 	 */
 	public isAuthenticated(): boolean {
-		return this.state === "active-session" || this.state === "inactive-session"
+		return (
+			this.state === "active-session" || this.state === "attempting-session" || this.state === "inactive-session"
+		)
 	}
 
 	public hasActiveSession(): boolean {
 		return this.state === "active-session"
 	}
 
+	/**
+	 * Check if the user has an active session or is currently attempting to acquire one
+	 *
+	 * @returns True if the user has an active session or is attempting to get one
+	 */
+	public hasOrIsAcquiringActiveSession(): boolean {
+		return this.state === "active-session" || this.state === "attempting-session"
+	}
+
 	/**
 	 * Refresh the session
 	 *
@@ -364,6 +390,9 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 			if (error instanceof InvalidClientTokenError) {
 				this.log("[auth] Invalid/Expired client token: clearing credentials")
 				this.clearCredentials()
+			} else if (this.isFirstRefreshAttempt && this.state === "attempting-session") {
+				this.isFirstRefreshAttempt = false
+				this.transitionToInactiveSession()
 			}
 			this.log("[auth] Failed to refresh session", error)
 			throw error

+ 8 - 0
packages/cloud/src/CloudService.ts

@@ -40,6 +40,7 @@ export class CloudService {
 			this.authService = new AuthService(this.context, this.log)
 			await this.authService.initialize()
 
+			this.authService.on("attempting-session", this.authListener)
 			this.authService.on("inactive-session", this.authListener)
 			this.authService.on("active-session", this.authListener)
 			this.authService.on("logged-out", this.authListener)
@@ -89,6 +90,11 @@ export class CloudService {
 		return this.authService!.hasActiveSession()
 	}
 
+	public hasOrIsAcquiringActiveSession(): boolean {
+		this.ensureInitialized()
+		return this.authService!.hasOrIsAcquiringActiveSession()
+	}
+
 	public getUserInfo(): CloudUserInfo | null {
 		this.ensureInitialized()
 		return this.authService!.getUserInfo()
@@ -152,6 +158,8 @@ export class CloudService {
 
 	public dispose(): void {
 		if (this.authService) {
+			this.authService.off("attempting-session", this.authListener)
+			this.authService.off("inactive-session", this.authListener)
 			this.authService.off("active-session", this.authListener)
 			this.authService.off("logged-out", this.authListener)
 			this.authService.off("user-info", this.authListener)

+ 5 - 3
packages/cloud/src/SettingsService.ts

@@ -14,7 +14,6 @@ import { RefreshTimer } from "./RefreshTimer"
 const ORGANIZATION_SETTINGS_CACHE_KEY = "organization-settings"
 
 export class SettingsService {
-
 	private context: vscode.ExtensionContext
 	private authService: AuthService
 	private settings: OrganizationSettings | undefined = undefined
@@ -43,6 +42,10 @@ export class SettingsService {
 			this.removeSettings()
 		}
 
+		this.authService.on("attempting-session", () => {
+			this.timer.start()
+		})
+
 		this.authService.on("active-session", () => {
 			this.timer.start()
 		})
@@ -52,7 +55,7 @@ export class SettingsService {
 			this.removeSettings()
 		})
 
-		if (this.authService.hasActiveSession()) {
+		if (this.authService.hasOrIsAcquiringActiveSession()) {
 			this.timer.start()
 		}
 	}
@@ -120,5 +123,4 @@ export class SettingsService {
 	public dispose(): void {
 		this.timer.stop()
 	}
-
 }

+ 114 - 18
packages/cloud/src/__tests__/AuthService.spec.ts

@@ -173,17 +173,17 @@ describe("AuthService", () => {
 			expect(loggedOutSpy).toHaveBeenCalledWith({ previousState: "initializing" })
 		})
 
-		it("should transition to inactive-session when valid credentials exist", async () => {
+		it("should transition to attempting-session when valid credentials exist", async () => {
 			const credentials = { clientToken: "test-token", sessionId: "test-session" }
 			mockContext.secrets.get.mockResolvedValue(JSON.stringify(credentials))
 
-			const inactiveSessionSpy = vi.fn()
-			authService.on("inactive-session", inactiveSessionSpy)
+			const attemptingSessionSpy = vi.fn()
+			authService.on("attempting-session", attemptingSessionSpy)
 
 			await authService.initialize()
 
-			expect(authService.getState()).toBe("inactive-session")
-			expect(inactiveSessionSpy).toHaveBeenCalledWith({ previousState: "initializing" })
+			expect(authService.getState()).toBe("attempting-session")
+			expect(attemptingSessionSpy).toHaveBeenCalledWith({ previousState: "initializing" })
 			expect(mockTimer.start).toHaveBeenCalled()
 		})
 
@@ -213,13 +213,13 @@ describe("AuthService", () => {
 			const newCredentials = { clientToken: "new-token", sessionId: "new-session" }
 			mockContext.secrets.get.mockResolvedValue(JSON.stringify(newCredentials))
 
-			const inactiveSessionSpy = vi.fn()
-			authService.on("inactive-session", inactiveSessionSpy)
+			const attemptingSessionSpy = vi.fn()
+			authService.on("attempting-session", attemptingSessionSpy)
 
 			onDidChangeCallback!({ key: "clerk-auth-credentials" })
 			await new Promise((resolve) => setTimeout(resolve, 0)) // Wait for async handling
 
-			expect(inactiveSessionSpy).toHaveBeenCalled()
+			expect(attemptingSessionSpy).toHaveBeenCalled()
 		})
 	})
 
@@ -451,6 +451,26 @@ describe("AuthService", () => {
 
 			expect(authService.getSessionToken()).toBe("test-jwt")
 		})
+
+		it("should return correct values for new methods", async () => {
+			await authService.initialize()
+			expect(authService.hasOrIsAcquiringActiveSession()).toBe(false)
+
+			// Create a new service instance with credentials (attempting-session)
+			const credentials = { clientToken: "test-token", sessionId: "test-session" }
+			mockContext.secrets.get.mockResolvedValue(JSON.stringify(credentials))
+
+			const attemptingService = new AuthService(mockContext as unknown as vscode.ExtensionContext, mockLog)
+			await attemptingService.initialize()
+
+			expect(attemptingService.hasOrIsAcquiringActiveSession()).toBe(true)
+			expect(attemptingService.hasActiveSession()).toBe(false)
+
+			// Manually set state to active-session for testing
+			attemptingService["state"] = "active-session"
+			expect(attemptingService.hasOrIsAcquiringActiveSession()).toBe(true)
+			expect(attemptingService.hasActiveSession()).toBe(true)
+		})
 	})
 
 	describe("session refresh", () => {
@@ -497,7 +517,7 @@ describe("AuthService", () => {
 			expect(authService.getState()).toBe("active-session")
 			expect(authService.hasActiveSession()).toBe(true)
 			expect(authService.getSessionToken()).toBe("new-jwt-token")
-			expect(activeSessionSpy).toHaveBeenCalledWith({ previousState: "inactive-session" })
+			expect(activeSessionSpy).toHaveBeenCalledWith({ previousState: "attempting-session" })
 			expect(userInfoSpy).toHaveBeenCalledWith({
 				userInfo: {
 					name: "John Doe",
@@ -530,6 +550,82 @@ describe("AuthService", () => {
 			await expect(timerCallback()).rejects.toThrow("Network error")
 			expect(mockLog).toHaveBeenCalledWith("[auth] Failed to refresh session", expect.any(Error))
 		})
+
+		it("should transition to inactive-session on first attempt failure", async () => {
+			// Mock failed token creation response
+			mockFetch.mockResolvedValue({
+				ok: false,
+				status: 500,
+				statusText: "Internal Server Error",
+			})
+
+			const inactiveSessionSpy = vi.fn()
+			authService.on("inactive-session", inactiveSessionSpy)
+
+			// Verify we start in attempting-session state
+			expect(authService.getState()).toBe("attempting-session")
+			expect(authService["isFirstRefreshAttempt"]).toBe(true)
+
+			const timerCallback = vi.mocked(RefreshTimer).mock.calls[0][0].callback
+
+			await expect(timerCallback()).rejects.toThrow()
+
+			// Should transition to inactive-session after first failure
+			expect(authService.getState()).toBe("inactive-session")
+			expect(authService["isFirstRefreshAttempt"]).toBe(false)
+			expect(inactiveSessionSpy).toHaveBeenCalledWith({ previousState: "attempting-session" })
+		})
+
+		it("should not transition to inactive-session on subsequent failures", async () => {
+			// First, transition to inactive-session by failing the first attempt
+			mockFetch.mockResolvedValue({
+				ok: false,
+				status: 500,
+				statusText: "Internal Server Error",
+			})
+
+			const timerCallback = vi.mocked(RefreshTimer).mock.calls[0][0].callback
+			await expect(timerCallback()).rejects.toThrow()
+
+			// Verify we're now in inactive-session
+			expect(authService.getState()).toBe("inactive-session")
+			expect(authService["isFirstRefreshAttempt"]).toBe(false)
+
+			const inactiveSessionSpy = vi.fn()
+			authService.on("inactive-session", inactiveSessionSpy)
+
+			// Subsequent failure should not trigger another transition
+			await expect(timerCallback()).rejects.toThrow()
+
+			expect(authService.getState()).toBe("inactive-session")
+			expect(inactiveSessionSpy).not.toHaveBeenCalled()
+		})
+
+		it("should clear credentials on 401 during first refresh attempt (bug fix)", async () => {
+			// Mock 401 response during first refresh attempt
+			mockFetch.mockResolvedValue({
+				ok: false,
+				status: 401,
+				statusText: "Unauthorized",
+			})
+
+			const loggedOutSpy = vi.fn()
+			authService.on("logged-out", loggedOutSpy)
+
+			const timerCallback = vi.mocked(RefreshTimer).mock.calls[0][0].callback
+			await expect(timerCallback()).rejects.toThrow()
+
+			// Should clear credentials (not just transition to inactive-session)
+			expect(mockContext.secrets.delete).toHaveBeenCalledWith("clerk-auth-credentials")
+			expect(mockLog).toHaveBeenCalledWith("[auth] Invalid/Expired client token: clearing credentials")
+
+			// Simulate credentials cleared event
+			mockContext.secrets.get.mockResolvedValue(undefined)
+			await authService["handleCredentialsChange"]()
+
+			expect(authService.getState()).toBe("logged-out")
+			expect(loggedOutSpy).toHaveBeenCalledWith({ previousState: "attempting-session" })
+		})
 	})
 
 	describe("user info", () => {
@@ -654,16 +750,16 @@ describe("AuthService", () => {
 			expect(loggedOutSpy).toHaveBeenCalledWith({ previousState: "initializing" })
 		})
 
-		it("should emit inactive-session event", async () => {
+		it("should emit attempting-session event", async () => {
 			const credentials = { clientToken: "test-token", sessionId: "test-session" }
 			mockContext.secrets.get.mockResolvedValue(JSON.stringify(credentials))
 
-			const inactiveSessionSpy = vi.fn()
-			authService.on("inactive-session", inactiveSessionSpy)
+			const attemptingSessionSpy = vi.fn()
+			authService.on("attempting-session", attemptingSessionSpy)
 
 			await authService.initialize()
 
-			expect(inactiveSessionSpy).toHaveBeenCalledWith({ previousState: "initializing" })
+			expect(attemptingSessionSpy).toHaveBeenCalledWith({ previousState: "initializing" })
 		})
 
 		it("should emit active-session event", async () => {
@@ -701,7 +797,7 @@ describe("AuthService", () => {
 			// Wait for async operations to complete
 			await new Promise((resolve) => setTimeout(resolve, 0))
 
-			expect(activeSessionSpy).toHaveBeenCalledWith({ previousState: "inactive-session" })
+			expect(activeSessionSpy).toHaveBeenCalledWith({ previousState: "attempting-session" })
 		})
 
 		it("should emit user-info event", async () => {
@@ -803,7 +899,7 @@ describe("AuthService", () => {
 			expect(mockTimer.stop).toHaveBeenCalled()
 		})
 
-		it("should start timer on inactive-session transition", async () => {
+		it("should start timer on attempting-session transition", async () => {
 			const credentials = { clientToken: "test-token", sessionId: "test-session" }
 			mockContext.secrets.get.mockResolvedValue(JSON.stringify(credentials))
 
@@ -892,13 +988,13 @@ describe("AuthService", () => {
 			const newCredentials = { clientToken: "new-token", sessionId: "new-session" }
 			mockContext.secrets.get.mockResolvedValue(JSON.stringify(newCredentials))
 
-			const inactiveSessionSpy = vi.fn()
-			service.on("inactive-session", inactiveSessionSpy)
+			const attemptingSessionSpy = vi.fn()
+			service.on("attempting-session", attemptingSessionSpy)
 
 			onDidChangeCallback!({ key: `clerk-auth-credentials-${customUrl}` })
 			await new Promise((resolve) => setTimeout(resolve, 0)) // Wait for async handling
 
-			expect(inactiveSessionSpy).toHaveBeenCalled()
+			expect(attemptingSessionSpy).toHaveBeenCalled()
 		})
 
 		it("should not respond to changes on different scoped keys", async () => {

+ 2 - 2
src/services/mdm/MdmService.ts

@@ -85,8 +85,8 @@ export class MdmService {
 			return { compliant: true }
 		}
 
-		// Check if cloud service is available and authenticated
-		if (!CloudService.hasInstance() || !CloudService.instance.hasActiveSession()) {
+		// Check if cloud service is available and has active or attempting session
+		if (!CloudService.hasInstance() || !CloudService.instance.hasOrIsAcquiringActiveSession()) {
 			return {
 				compliant: false,
 				reason: "Your organization requires Roo Code Cloud authentication. Please sign in to continue.",

+ 19 - 3
src/services/mdm/__tests__/MdmService.spec.ts

@@ -16,6 +16,7 @@ vi.mock("@roo-code/cloud", () => ({
 		hasInstance: vi.fn(),
 		instance: {
 			hasActiveSession: vi.fn(),
+			hasOrIsAcquiringActiveSession: vi.fn(),
 			getOrganizationId: vi.fn(),
 		},
 	},
@@ -243,7 +244,7 @@ describe("MdmService", () => {
 			mockFs.readFileSync.mockReturnValue(JSON.stringify(mockConfig))
 
 			mockCloudService.hasInstance.mockReturnValue(true)
-			mockCloudService.instance.hasActiveSession.mockReturnValue(true)
+			mockCloudService.instance.hasOrIsAcquiringActiveSession.mockReturnValue(true)
 
 			const service = await MdmService.createInstance()
 			const compliance = service.isCompliant()
@@ -278,7 +279,7 @@ describe("MdmService", () => {
 
 			// Mock CloudService to have instance and active session but wrong org
 			mockCloudService.hasInstance.mockReturnValue(true)
-			mockCloudService.instance.hasActiveSession.mockReturnValue(true)
+			mockCloudService.instance.hasOrIsAcquiringActiveSession.mockReturnValue(true)
 			mockCloudService.instance.getOrganizationId.mockReturnValue("different-org-456")
 
 			const service = await MdmService.createInstance()
@@ -299,7 +300,7 @@ describe("MdmService", () => {
 			mockFs.readFileSync.mockReturnValue(JSON.stringify(mockConfig))
 
 			mockCloudService.hasInstance.mockReturnValue(true)
-			mockCloudService.instance.hasActiveSession.mockReturnValue(true)
+			mockCloudService.instance.hasOrIsAcquiringActiveSession.mockReturnValue(true)
 			mockCloudService.instance.getOrganizationId.mockReturnValue("correct-org-123")
 
 			const service = await MdmService.createInstance()
@@ -307,6 +308,21 @@ describe("MdmService", () => {
 
 			expect(compliance.compliant).toBe(true)
 		})
+
+		it("should be compliant when in attempting-session state", async () => {
+			const mockConfig = { requireCloudAuth: true }
+			mockFs.existsSync.mockReturnValue(true)
+			mockFs.readFileSync.mockReturnValue(JSON.stringify(mockConfig))
+
+			mockCloudService.hasInstance.mockReturnValue(true)
+			// Mock attempting session (not active, but acquiring)
+			mockCloudService.instance.hasOrIsAcquiringActiveSession.mockReturnValue(true)
+
+			const service = await MdmService.createInstance()
+			const compliance = service.isCompliant()
+
+			expect(compliance.compliant).toBe(true)
+		})
 	})
 
 	describe("cloud enablement", () => {