Browse Source

Add credentials auth for Google vertex

eong 1 year ago
parent
commit
01f83b4d17

+ 5 - 0
.changeset/tame-carpets-bake.md

@@ -0,0 +1,5 @@
+---
+"roo-cline": patch
+---
+
+Add credentials auth for Google vertex

+ 50 - 9
src/api/providers/vertex.ts

@@ -11,6 +11,7 @@ import { BaseProvider } from "./base-provider"
 
 import { ANTHROPIC_DEFAULT_MAX_TOKENS } from "./constants"
 import { getModelParams, SingleCompletionHandler } from "../"
+import { GoogleAuth } from "google-auth-library"
 
 // Types for Vertex SDK
 
@@ -120,16 +121,56 @@ export class VertexHandler extends BaseProvider implements SingleCompletionHandl
 			throw new Error(`Unknown model ID: ${this.options.apiModelId}`)
 		}
 
-		this.anthropicClient = new AnthropicVertex({
-			projectId: this.options.vertexProjectId ?? "not-provided",
-			// https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions
-			region: this.options.vertexRegion ?? "us-east5",
-		})
+		if (this.options.vertexJsonCredentials) {
+			this.anthropicClient = new AnthropicVertex({
+				projectId: this.options.vertexProjectId ?? "not-provided",
+				// https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions
+				region: this.options.vertexRegion ?? "us-east5",
+				googleAuth: new GoogleAuth({
+					scopes: ["https://www.googleapis.com/auth/cloud-platform"],
+					credentials: JSON.parse(this.options.vertexJsonCredentials),
+				}),
+			})
+		} else if (this.options.vertexKeyFile) {
+			this.anthropicClient = new AnthropicVertex({
+				projectId: this.options.vertexProjectId ?? "not-provided",
+				// https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions
+				region: this.options.vertexRegion ?? "us-east5",
+				googleAuth: new GoogleAuth({
+					scopes: ["https://www.googleapis.com/auth/cloud-platform"],
+					keyFile: this.options.vertexKeyFile,
+				}),
+			})
+		} else {
+			this.anthropicClient = new AnthropicVertex({
+				projectId: this.options.vertexProjectId ?? "not-provided",
+				// https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions
+				region: this.options.vertexRegion ?? "us-east5",
+			})
+		}
 
-		this.geminiClient = new VertexAI({
-			project: this.options.vertexProjectId ?? "not-provided",
-			location: this.options.vertexRegion ?? "us-east5",
-		})
+		if (this.options.vertexJsonCredentials) {
+			this.geminiClient = new VertexAI({
+				project: this.options.vertexProjectId ?? "not-provided",
+				location: this.options.vertexRegion ?? "us-east5",
+				googleAuthOptions: {
+					credentials: JSON.parse(this.options.vertexJsonCredentials),
+				},
+			})
+		} else if (this.options.vertexKeyFile) {
+			this.geminiClient = new VertexAI({
+				project: this.options.vertexProjectId ?? "not-provided",
+				location: this.options.vertexRegion ?? "us-east5",
+				googleAuthOptions: {
+					keyFile: this.options.vertexKeyFile,
+				},
+			})
+		} else {
+			this.geminiClient = new VertexAI({
+				project: this.options.vertexProjectId ?? "not-provided",
+				location: this.options.vertexRegion ?? "us-east5",
+			})
+		}
 	}
 
 	private formatMessageForCache(message: Anthropic.Messages.MessageParam, shouldCache: boolean): VertexMessage {

+ 10 - 0
src/core/webview/ClineProvider.ts

@@ -1659,6 +1659,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			awsUseCrossRegionInference,
 			awsProfile,
 			awsUseProfile,
+			vertexKeyFile,
+			vertexJsonCredentials,
 			vertexProjectId,
 			vertexRegion,
 			openAiBaseUrl,
@@ -1710,6 +1712,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.updateGlobalState("awsUseCrossRegionInference", awsUseCrossRegionInference),
 			this.updateGlobalState("awsProfile", awsProfile),
 			this.updateGlobalState("awsUseProfile", awsUseProfile),
+			this.updateGlobalState("vertexKeyFile", vertexKeyFile),
+			this.updateGlobalState("vertexJsonCredentials", vertexJsonCredentials),
 			this.updateGlobalState("vertexProjectId", vertexProjectId),
 			this.updateGlobalState("vertexRegion", vertexRegion),
 			this.updateGlobalState("openAiBaseUrl", openAiBaseUrl),
@@ -2160,6 +2164,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			awsUseCrossRegionInference,
 			awsProfile,
 			awsUseProfile,
+			vertexKeyFile,
+			vertexJsonCredentials,
 			vertexProjectId,
 			vertexRegion,
 			openAiBaseUrl,
@@ -2248,6 +2254,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 			this.getGlobalState("awsUseCrossRegionInference") as Promise<boolean | undefined>,
 			this.getGlobalState("awsProfile") as Promise<string | undefined>,
 			this.getGlobalState("awsUseProfile") as Promise<boolean | undefined>,
+			this.getGlobalState("vertexKeyFile") as Promise<string | undefined>,
+			this.getGlobalState("vertexJsonCredentials") as Promise<string | undefined>,
 			this.getGlobalState("vertexProjectId") as Promise<string | undefined>,
 			this.getGlobalState("vertexRegion") as Promise<string | undefined>,
 			this.getGlobalState("openAiBaseUrl") as Promise<string | undefined>,
@@ -2353,6 +2361,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 				awsUseCrossRegionInference,
 				awsProfile,
 				awsUseProfile,
+				vertexKeyFile,
+				vertexJsonCredentials,
 				vertexProjectId,
 				vertexRegion,
 				openAiBaseUrl,

+ 2 - 0
src/shared/api.ts

@@ -40,6 +40,8 @@ export interface ApiHandlerOptions {
 	awsUseProfile?: boolean
 	vertexProjectId?: string
 	vertexRegion?: string
+	vertexKeyFile?: string
+	vertexJsonCredentials?: string
 	openAiBaseUrl?: string
 	openAiApiKey?: string
 	openAiModelId?: string

+ 2 - 0
src/shared/globalState.ts

@@ -22,6 +22,8 @@ export type GlobalStateKey =
 	| "awsUseCrossRegionInference"
 	| "awsProfile"
 	| "awsUseProfile"
+	| "vertexKeyFile"
+	| "vertexJsonCredentials"
 	| "vertexProjectId"
 	| "vertexRegion"
 	| "lastShownAnnouncementId"

+ 19 - 0
webview-ui/src/components/settings/ApiOptions.tsx

@@ -604,6 +604,20 @@ const ApiOptions = ({
 
 			{selectedProvider === "vertex" && (
 				<div style={{ display: "flex", flexDirection: "column", gap: 5 }}>
+					<VSCodeTextField
+						value={apiConfiguration?.vertexJsonCredentials || ""}
+						style={{ width: "100%" }}
+						onInput={handleInputChange("vertexJsonCredentials")}
+						placeholder="Enter Credentials JSON...">
+						<span className="font-medium">Google Cloud Credentials</span>
+					</VSCodeTextField>
+					<VSCodeTextField
+						value={apiConfiguration?.vertexKeyFile || ""}
+						style={{ width: "100%" }}
+						onInput={handleInputChange("vertexKeyFile")}
+						placeholder="Enter Key File Path...">
+						<span className="font-medium">Google Cloud Key File Path</span>
+					</VSCodeTextField>
 					<VSCodeTextField
 						value={apiConfiguration?.vertexProjectId || ""}
 						style={{ width: "100%" }}
@@ -649,6 +663,11 @@ const ApiOptions = ({
 							style={{ display: "inline", fontSize: "inherit" }}>
 							{"2) install the Google Cloud CLI › configure Application Default Credentials."}
 						</VSCodeLink>
+						<VSCodeLink
+							href="https://developers.google.com/workspace/guides/create-credentials?hl=en#service-account"
+							style={{ display: "inline", fontSize: "inherit" }}>
+							{"3) or create a service account with credentials."}
+						</VSCodeLink>
 					</p>
 				</div>
 			)}