ソースを参照

Store the organization id in credentials (#5002)

* Store the organization id in credentials

* Better organization logic

* Fix tests

* Update cloud settings defaults

* Fix organization_id handling in Clerk API calls

Address review feedback by properly handling 3 cases for organization_id:
1. Have an org id: send organization_id=THE_ORG_ID
2. Have a personal account: send organization_id= (empty string)
3. Don't know if you have an org id (old credentials): don't send organization_id param at all

Changes:
- Updated clerkCreateSessionToken() to check credentials.organizationId !== undefined
- Updated fetchUserInfo() to handle all 3 cases consistently
- Added fallback logic for old credentials without organization context
- Improved logging for better debugging of organization context

* DRY up organization loading code in AuthService

Extract common organization membership processing logic into reusable helper methods:

- findOrganizationMembership(): Find specific org membership by ID
- findPrimaryOrganizationMembership(): Get first/primary org membership
- setUserOrganizationInfo(): Set organization info on user object

This eliminates duplication between the two clerkGetOrganizationMemberships()
call sites that were doing very similar organization data processing.
Matt Rubens 6 ヶ月 前
コミット
abaa3d8179

+ 98 - 15
packages/cloud/src/AuthService.ts

@@ -21,6 +21,7 @@ export interface AuthServiceEvents {
 const authCredentialsSchema = z.object({
 	clientToken: z.string().min(1, "Client token cannot be empty"),
 	sessionId: z.string().min(1, "Session ID cannot be empty"),
+	organizationId: z.string().nullable().optional(),
 })
 
 type AuthCredentials = z.infer<typeof authCredentialsSchema>
@@ -220,7 +221,16 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 
 		try {
 			const parsedJson = JSON.parse(credentialsJson)
-			return authCredentialsSchema.parse(parsedJson)
+			const credentials = authCredentialsSchema.parse(parsedJson)
+
+			// Migration: If no organizationId but we have userInfo, add it
+			if (credentials.organizationId === undefined && this.userInfo?.organizationId) {
+				credentials.organizationId = this.userInfo.organizationId
+				await this.storeCredentials(credentials)
+				this.log("[auth] Migrated credentials with organizationId")
+			}
+
+			return credentials
 		} catch (error) {
 			if (error instanceof z.ZodError) {
 				this.log("[auth] Invalid credentials format:", error.errors)
@@ -269,8 +279,13 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 	 *
 	 * @param code The authorization code from the callback
 	 * @param state The state parameter from the callback
+	 * @param organizationId The organization ID from the callback (null for personal accounts)
 	 */
-	public async handleCallback(code: string | null, state: string | null): Promise<void> {
+	public async handleCallback(
+		code: string | null,
+		state: string | null,
+		organizationId?: string | null,
+	): Promise<void> {
 		if (!code || !state) {
 			vscode.window.showInformationMessage("Invalid Roo Code Cloud sign in url")
 			return
@@ -287,6 +302,9 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 
 			const credentials = await this.clerkSignIn(code)
 
+			// Set organizationId (null for personal accounts)
+			credentials.organizationId = organizationId || null
+
 			await this.storeCredentials(credentials)
 
 			vscode.window.showInformationMessage("Successfully authenticated with Roo Code Cloud")
@@ -417,6 +435,15 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 		return this.userInfo
 	}
 
+	/**
+	 * Get the stored organization ID from credentials
+	 *
+	 * @returns The stored organization ID, null for personal accounts or if no credentials exist
+	 */
+	public getStoredOrganizationId(): string | null {
+		return this.credentials?.organizationId || null
+	}
+
 	private async clerkSignIn(ticket: string): Promise<AuthCredentials> {
 		const formData = new URLSearchParams()
 		formData.append("strategy", "ticket")
@@ -454,6 +481,17 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 		const formData = new URLSearchParams()
 		formData.append("_is_native", "1")
 
+		// Handle 3 cases for organization_id:
+		// 1. Have an org id: organization_id=THE_ORG_ID
+		// 2. Have a personal account: organization_id= (empty string)
+		// 3. Don't know if you have an org id (old style credentials): don't send organization_id param at all
+		const organizationId = this.getStoredOrganizationId()
+		if (this.credentials?.organizationId !== undefined) {
+			// We have organization context info (either org id or personal account)
+			formData.append("organization_id", organizationId || "")
+		}
+		// If organizationId is undefined, don't send the param at all (old credentials)
+
 		const response = await fetch(`${getClerkBaseUrl()}/v1/client/sessions/${this.credentials!.sessionId}/tokens`, {
 			method: "POST",
 			headers: {
@@ -505,29 +543,74 @@ export class AuthService extends EventEmitter<AuthServiceEvents> {
 
 		userInfo.picture = userData.image_url
 
-		// Fetch organization memberships separately
+		// Fetch organization info if user is in organization context
 		try {
-			const orgMemberships = await this.clerkGetOrganizationMemberships()
-			if (orgMemberships && orgMemberships.length > 0) {
-				// Get the first (or active) organization membership
-				const primaryOrgMembership = orgMemberships[0]
-				const organization = primaryOrgMembership?.organization
-
-				if (organization) {
-					userInfo.organizationId = organization.id
-					userInfo.organizationName = organization.name
-					userInfo.organizationRole = primaryOrgMembership.role
-					userInfo.organizationImageUrl = organization.image_url
+			const storedOrgId = this.getStoredOrganizationId()
+
+			if (this.credentials?.organizationId !== undefined) {
+				// We have organization context info
+				if (storedOrgId !== null) {
+					// User is in organization context - fetch user's memberships and filter
+					const orgMemberships = await this.clerkGetOrganizationMemberships()
+					const userMembership = this.findOrganizationMembership(orgMemberships, storedOrgId)
+
+					if (userMembership) {
+						this.setUserOrganizationInfo(userInfo, userMembership)
+						this.log("[auth] User in organization context:", {
+							id: userMembership.organization.id,
+							name: userMembership.organization.name,
+							role: userMembership.role,
+						})
+					} else {
+						this.log("[auth] Warning: User not found in stored organization:", storedOrgId)
+					}
+				} else {
+					this.log("[auth] User in personal account context - not setting organization info")
+				}
+			} else {
+				// Old credentials without organization context - fetch organization info to determine context
+				const orgMemberships = await this.clerkGetOrganizationMemberships()
+				const primaryOrgMembership = this.findPrimaryOrganizationMembership(orgMemberships)
+
+				if (primaryOrgMembership) {
+					this.setUserOrganizationInfo(userInfo, primaryOrgMembership)
+					this.log("[auth] Legacy credentials: Found organization membership:", {
+						id: primaryOrgMembership.organization.id,
+						name: primaryOrgMembership.organization.name,
+						role: primaryOrgMembership.role,
+					})
+				} else {
+					this.log("[auth] Legacy credentials: No organization memberships found")
 				}
 			}
 		} catch (error) {
-			this.log("[auth] Failed to fetch organization memberships:", error)
+			this.log("[auth] Failed to fetch organization info:", error)
 			// Don't throw - organization info is optional
 		}
 
 		return userInfo
 	}
 
+	private findOrganizationMembership(
+		memberships: CloudOrganizationMembership[],
+		organizationId: string,
+	): CloudOrganizationMembership | undefined {
+		return memberships?.find((membership) => membership.organization.id === organizationId)
+	}
+
+	private findPrimaryOrganizationMembership(
+		memberships: CloudOrganizationMembership[],
+	): CloudOrganizationMembership | undefined {
+		return memberships && memberships.length > 0 ? memberships[0] : undefined
+	}
+
+	private setUserOrganizationInfo(userInfo: CloudUserInfo, membership: CloudOrganizationMembership): void {
+		userInfo.organizationId = membership.organization.id
+		userInfo.organizationName = membership.organization.name
+		userInfo.organizationRole = membership.role
+		userInfo.organizationImageUrl = membership.organization.image_url
+	}
+
 	private async clerkGetOrganizationMemberships(): Promise<CloudOrganizationMembership[]> {
 		const response = await fetch(`${getClerkBaseUrl()}/v1/me/organization_memberships`, {
 			headers: {

+ 16 - 2
packages/cloud/src/CloudService.ts

@@ -118,14 +118,28 @@ export class CloudService {
 		return userInfo?.organizationRole || null
 	}
 
+	public hasStoredOrganizationId(): boolean {
+		this.ensureInitialized()
+		return this.authService!.getStoredOrganizationId() !== null
+	}
+
+	public getStoredOrganizationId(): string | null {
+		this.ensureInitialized()
+		return this.authService!.getStoredOrganizationId()
+	}
+
 	public getAuthState(): string {
 		this.ensureInitialized()
 		return this.authService!.getState()
 	}
 
-	public async handleAuthCallback(code: string | null, state: string | null): Promise<void> {
+	public async handleAuthCallback(
+		code: string | null,
+		state: string | null,
+		organizationId?: string | null,
+	): Promise<void> {
 		this.ensureInitialized()
-		return this.authService!.handleCallback(code, state)
+		return this.authService!.handleCallback(code, state, organizationId)
 	}
 
 	// SettingsService

+ 52 - 6
packages/cloud/src/__tests__/AuthService.spec.ts

@@ -328,7 +328,7 @@ describe("AuthService", () => {
 
 			expect(mockContext.secrets.store).toHaveBeenCalledWith(
 				"clerk-auth-credentials",
-				JSON.stringify({ clientToken: "Bearer token-123", sessionId: "session-123" }),
+				JSON.stringify({ clientToken: "Bearer token-123", sessionId: "session-123", organizationId: null }),
 			)
 			expect(mockShowInfo).toHaveBeenCalledWith("Successfully authenticated with Roo Code Cloud")
 		})
@@ -633,9 +633,55 @@ describe("AuthService", () => {
 			expect(authService.getUserInfo()).toBeNull()
 		})
 
-		it("should parse user info correctly", async () => {
-			// Set up with credentials
-			const credentials = { clientToken: "test-token", sessionId: "test-session" }
+		it("should parse user info correctly for personal accounts", async () => {
+			// Set up with credentials for personal account (no organizationId)
+			const credentials = { clientToken: "test-token", sessionId: "test-session", organizationId: null }
+			mockContext.secrets.get.mockResolvedValue(JSON.stringify(credentials))
+			await authService.initialize()
+
+			// Clear previous mock calls
+			mockFetch.mockClear()
+
+			// Mock successful responses
+			mockFetch
+				.mockResolvedValueOnce({
+					ok: true,
+					json: () => Promise.resolve({ jwt: "jwt-token" }),
+				})
+				.mockResolvedValueOnce({
+					ok: true,
+					json: () =>
+						Promise.resolve({
+							response: {
+								first_name: "Jane",
+								last_name: "Smith",
+								image_url: "https://example.com/jane.jpg",
+								primary_email_address_id: "email-2",
+								email_addresses: [
+									{ id: "email-1", email_address: "[email protected]" },
+									{ id: "email-2", email_address: "[email protected]" },
+								],
+							},
+						}),
+				})
+
+			const timerCallback = vi.mocked(RefreshTimer).mock.calls[0][0].callback
+			await timerCallback()
+
+			// Wait for async operations to complete
+			await new Promise((resolve) => setTimeout(resolve, 0))
+
+			const userInfo = authService.getUserInfo()
+			expect(userInfo).toEqual({
+				name: "Jane Smith",
+				email: "[email protected]",
+				picture: "https://example.com/jane.jpg",
+			})
+		})
+
+		it("should parse user info correctly for organization accounts", async () => {
+			// Set up with credentials for organization account
+			const credentials = { clientToken: "test-token", sessionId: "test-session", organizationId: "org_1" }
 			mockContext.secrets.get.mockResolvedValue(JSON.stringify(credentials))
 			await authService.initialize()
 
@@ -699,8 +745,8 @@ describe("AuthService", () => {
 		})
 
 		it("should handle missing user info fields", async () => {
-			// Set up with credentials
-			const credentials = { clientToken: "test-token", sessionId: "test-session" }
+			// Set up with credentials for personal account (no organizationId)
+			const credentials = { clientToken: "test-token", sessionId: "test-session", organizationId: null }
 			mockContext.secrets.get.mockResolvedValue(JSON.stringify(credentials))
 			await authService.initialize()
 

+ 37 - 1
packages/cloud/src/__tests__/CloudService.test.ts

@@ -40,6 +40,7 @@ describe("CloudService", () => {
 		getState: ReturnType<typeof vi.fn>
 		getSessionToken: ReturnType<typeof vi.fn>
 		handleCallback: ReturnType<typeof vi.fn>
+		getStoredOrganizationId: ReturnType<typeof vi.fn>
 		on: ReturnType<typeof vi.fn>
 		off: ReturnType<typeof vi.fn>
 		once: ReturnType<typeof vi.fn>
@@ -88,6 +89,7 @@ describe("CloudService", () => {
 			getState: vi.fn().mockReturnValue("logged-out"),
 			getSessionToken: vi.fn(),
 			handleCallback: vi.fn(),
+			getStoredOrganizationId: vi.fn().mockReturnValue(null),
 			on: vi.fn(),
 			off: vi.fn(),
 			once: vi.fn(),
@@ -255,7 +257,41 @@ describe("CloudService", () => {
 
 		it("should delegate handleAuthCallback to AuthService", async () => {
 			await cloudService.handleAuthCallback("code", "state")
-			expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state")
+			expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state", undefined)
+		})
+
+		it("should delegate handleAuthCallback with organizationId to AuthService", async () => {
+			await cloudService.handleAuthCallback("code", "state", "org_123")
+			expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state", "org_123")
+		})
+
+		it("should return stored organization ID from AuthService", () => {
+			mockAuthService.getStoredOrganizationId.mockReturnValue("org_456")
+
+			const result = cloudService.getStoredOrganizationId()
+			expect(mockAuthService.getStoredOrganizationId).toHaveBeenCalled()
+			expect(result).toBe("org_456")
+		})
+
+		it("should return null when no stored organization ID available", () => {
+			mockAuthService.getStoredOrganizationId.mockReturnValue(null)
+
+			const result = cloudService.getStoredOrganizationId()
+			expect(result).toBe(null)
+		})
+
+		it("should return true when stored organization ID exists", () => {
+			mockAuthService.getStoredOrganizationId.mockReturnValue("org_789")
+
+			const result = cloudService.hasStoredOrganizationId()
+			expect(result).toBe(true)
+		})
+
+		it("should return false when no stored organization ID exists", () => {
+			mockAuthService.getStoredOrganizationId.mockReturnValue(null)
+
+			const result = cloudService.hasStoredOrganizationId()
+			expect(result).toBe(false)
 		})
 	})
 

+ 1 - 0
packages/types/src/cloud.ts

@@ -125,6 +125,7 @@ export const ORGANIZATION_ALLOW_ALL: OrganizationAllowList = {
 export const ORGANIZATION_DEFAULT: OrganizationSettings = {
 	version: 0,
 	cloudSettings: {
+		recordTaskMessages: true,
 		enableTaskSharing: true,
 		taskShareExpirationDays: 30,
 	},

+ 7 - 1
src/activate/handleUri.ts

@@ -38,7 +38,13 @@ export const handleUri = async (uri: vscode.Uri) => {
 		case "/auth/clerk/callback": {
 			const code = query.get("code")
 			const state = query.get("state")
-			await CloudService.instance.handleAuthCallback(code, state)
+			const organizationId = query.get("organizationId")
+
+			await CloudService.instance.handleAuthCallback(
+				code,
+				state,
+				organizationId === "null" ? null : organizationId,
+			)
 			break
 		}
 		default:

+ 22 - 4
src/services/mdm/MdmService.ts

@@ -6,6 +6,7 @@ import { z } from "zod"
 
 import { CloudService, getClerkBaseUrl, PRODUCTION_CLERK_BASE_URL } from "@roo-code/cloud"
 import { Package } from "../../shared/package"
+import { t } from "../../i18n"
 
 // MDM Configuration Schema
 const mdmConfigSchema = z.object({
@@ -89,7 +90,7 @@ export class MdmService {
 		if (!CloudService.hasInstance() || !CloudService.instance.hasOrIsAcquiringActiveSession()) {
 			return {
 				compliant: false,
-				reason: "Your organization requires Roo Code Cloud authentication. Please sign in to continue.",
+				reason: t("mdm.errors.cloud_auth_required"),
 			}
 		}
 
@@ -97,18 +98,35 @@ export class MdmService {
 		const requiredOrgId = this.getRequiredOrganizationId()
 		if (requiredOrgId) {
 			try {
-				const currentOrgId = CloudService.instance.getOrganizationId()
+				// First try to get from active session
+				let currentOrgId = CloudService.instance.getOrganizationId()
+
+				// If no active session, check stored credentials
+				if (!currentOrgId) {
+					const storedOrgId = CloudService.instance.getStoredOrganizationId()
+
+					// null means personal account, which is not compliant for org requirements
+					if (storedOrgId === null || storedOrgId !== requiredOrgId) {
+						return {
+							compliant: false,
+							reason: t("mdm.errors.organization_mismatch"),
+						}
+					}
+
+					currentOrgId = storedOrgId
+				}
+
 				if (currentOrgId !== requiredOrgId) {
 					return {
 						compliant: false,
-						reason: "You must be authenticated with your organization's Roo Code Cloud account.",
+						reason: t("mdm.errors.organization_mismatch"),
 					}
 				}
 			} catch (error) {
 				this.log("[MDM] Error checking organization ID:", error)
 				return {
 					compliant: false,
-					reason: "Unable to verify organization authentication.",
+					reason: t("mdm.errors.verification_failed"),
 				}
 			}
 		}

+ 17 - 2
src/services/mdm/__tests__/MdmService.spec.ts

@@ -43,6 +43,19 @@ vi.mock("../../../shared/package", () => ({
 	},
 }))
 
+vi.mock("../../../i18n", () => ({
+	t: vi.fn((key: string) => {
+		const translations: Record<string, string> = {
+			"mdm.errors.cloud_auth_required":
+				"Your organization requires Roo Code Cloud authentication. Please sign in to continue.",
+			"mdm.errors.organization_mismatch":
+				"You must be authenticated with your organization's Roo Code Cloud account.",
+			"mdm.errors.verification_failed": "Unable to verify organization authentication.",
+		}
+		return translations[key] || key
+	}),
+}))
+
 import * as fs from "fs"
 import * as os from "os"
 import * as vscode from "vscode"
@@ -265,7 +278,7 @@ describe("MdmService", () => {
 
 			expect(compliance.compliant).toBe(false)
 			if (!compliance.compliant) {
-				expect(compliance.reason).toContain("requires Roo Code Cloud authentication")
+				expect(compliance.reason).toContain("Your organization requires Roo Code Cloud authentication")
 			}
 		})
 
@@ -287,7 +300,9 @@ describe("MdmService", () => {
 
 			expect(compliance.compliant).toBe(false)
 			if (!compliance.compliant) {
-				expect(compliance.reason).toContain("organization's Roo Code Cloud account")
+				expect(compliance.reason).toContain(
+					"You must be authenticated with your organization's Roo Code Cloud account",
+				)
 			}
 		})