Browse Source

Add gemini support

Saoud Rizwan 1 year ago
parent
commit
fbb7620fa1

+ 12 - 2
package-lock.json

@@ -1,17 +1,18 @@
 {
   "name": "claude-dev",
-  "version": "1.5.34",
+  "version": "1.6.4",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "claude-dev",
-      "version": "1.5.34",
+      "version": "1.6.4",
       "license": "MIT",
       "dependencies": {
         "@anthropic-ai/bedrock-sdk": "^0.10.2",
         "@anthropic-ai/sdk": "^0.26.0",
         "@anthropic-ai/vertex-sdk": "^0.4.1",
+        "@google/generative-ai": "^0.18.0",
         "@types/clone-deep": "^4.0.4",
         "@types/pdf-parse": "^1.1.4",
         "@vscode/codicons": "^0.0.36",
@@ -2635,6 +2636,15 @@
         "node": "^12.22.0 || ^14.17.0 || >=16.0.0"
       }
     },
+    "node_modules/@google/generative-ai": {
+      "version": "0.18.0",
+      "resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.18.0.tgz",
+      "integrity": "sha512-AhaIWSpk2tuhYHrBhUqC0xrWWznmYEja1/TRDIb+5kruBU5kUzMlFsXCQNO9PzyTZ4clUJ3CX/Rvy+Xm9x+w3g==",
+      "license": "Apache-2.0",
+      "engines": {
+        "node": ">=18.0.0"
+      }
+    },
     "node_modules/@humanwhocodes/config-array": {
       "version": "0.11.14",
       "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.14.tgz",

+ 1 - 0
package.json

@@ -151,6 +151,7 @@
     "@anthropic-ai/bedrock-sdk": "^0.10.2",
     "@anthropic-ai/sdk": "^0.26.0",
     "@anthropic-ai/vertex-sdk": "^0.4.1",
+    "@google/generative-ai": "^0.18.0",
     "@types/clone-deep": "^4.0.4",
     "@types/pdf-parse": "^1.1.4",
     "@vscode/codicons": "^0.0.36",

+ 57 - 0
src/api/gemini.ts

@@ -0,0 +1,57 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { FunctionCallingMode, GoogleGenerativeAI } from "@google/generative-ai"
+import { ApiHandler, ApiHandlerMessageResponse } from "."
+import { ApiHandlerOptions, geminiDefaultModelId, GeminiModelId, geminiModels, ModelInfo } from "../shared/api"
+import {
+	convertAnthropicMessageToGemini,
+	convertAnthropicToolToGemini,
+	convertGeminiResponseToAnthropic,
+} from "../utils/gemini-format"
+
+export class GeminiHandler implements ApiHandler {
+	private options: ApiHandlerOptions
+	private client: GoogleGenerativeAI
+
+	constructor(options: ApiHandlerOptions) {
+		if (!options.geminiApiKey) {
+			throw new Error("API key is required for Google Gemini")
+		}
+		this.options = options
+		this.client = new GoogleGenerativeAI(options.geminiApiKey)
+	}
+
+	async createMessage(
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[],
+		tools: Anthropic.Messages.Tool[]
+	): Promise<ApiHandlerMessageResponse> {
+		const model = this.client.getGenerativeModel({
+			model: this.getModel().id,
+			systemInstruction: systemPrompt,
+			tools: [{ functionDeclarations: tools.map(convertAnthropicToolToGemini) }],
+			toolConfig: {
+				functionCallingConfig: {
+					mode: FunctionCallingMode.AUTO,
+				},
+			},
+		})
+		const result = await model.generateContent({
+			contents: messages.map(convertAnthropicMessageToGemini),
+			generationConfig: {
+				maxOutputTokens: this.getModel().info.maxTokens,
+			},
+		})
+		const message = convertGeminiResponseToAnthropic(result.response)
+
+		return { message }
+	}
+
+	getModel(): { id: GeminiModelId; info: ModelInfo } {
+		const modelId = this.options.apiModelId
+		if (modelId && modelId in geminiModels) {
+			const id = modelId as GeminiModelId
+			return { id, info: geminiModels[id] }
+		}
+		return { id: geminiDefaultModelId, info: geminiModels[geminiDefaultModelId] }
+	}
+}

+ 3 - 0
src/api/index.ts

@@ -6,6 +6,7 @@ import { OpenRouterHandler } from "./openrouter"
 import { VertexHandler } from "./vertex"
 import { OpenAiHandler } from "./openai"
 import { OllamaHandler } from "./ollama"
+import { GeminiHandler } from "./gemini"
 
 export interface ApiHandlerMessageResponse {
 	message: Anthropic.Messages.Message
@@ -37,6 +38,8 @@ export function buildApiHandler(configuration: ApiConfiguration): ApiHandler {
 			return new OpenAiHandler(options)
 		case "ollama":
 			return new OllamaHandler(options)
+		case "gemini":
+			return new GeminiHandler(options)
 		default:
 			return new AnthropicHandler(options)
 	}

+ 14 - 1
src/providers/ClaudeDevProvider.ts

@@ -18,7 +18,14 @@ https://github.com/microsoft/vscode-webview-ui-toolkit-samples/blob/main/default
 https://github.com/KumarVariable/vscode-extension-sidebar-html/blob/master/src/customSidebarViewProvider.ts
 */
 
-type SecretKey = "apiKey" | "openRouterApiKey" | "awsAccessKey" | "awsSecretKey" | "awsSessionToken" | "openAiApiKey"
+type SecretKey =
+	| "apiKey"
+	| "openRouterApiKey"
+	| "awsAccessKey"
+	| "awsSecretKey"
+	| "awsSessionToken"
+	| "openAiApiKey"
+	| "geminiApiKey"
 type GlobalStateKey =
 	| "apiProvider"
 	| "apiModelId"
@@ -329,6 +336,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 								ollamaModelId,
 								ollamaBaseUrl,
 								anthropicBaseUrl,
+								geminiApiKey,
 							} = message.apiConfiguration
 							await this.updateGlobalState("apiProvider", apiProvider)
 							await this.updateGlobalState("apiModelId", apiModelId)
@@ -346,6 +354,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 							await this.updateGlobalState("ollamaModelId", ollamaModelId)
 							await this.updateGlobalState("ollamaBaseUrl", ollamaBaseUrl)
 							await this.updateGlobalState("anthropicBaseUrl", anthropicBaseUrl)
+							await this.storeSecret("geminiApiKey", geminiApiKey)
 							this.claudeDev?.updateApi(message.apiConfiguration)
 						}
 						await this.postStateToWebview()
@@ -667,6 +676,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 			ollamaModelId,
 			ollamaBaseUrl,
 			anthropicBaseUrl,
+			geminiApiKey,
 			lastShownAnnouncementId,
 			customInstructions,
 			alwaysAllowReadOnly,
@@ -688,6 +698,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 			this.getGlobalState("ollamaModelId") as Promise<string | undefined>,
 			this.getGlobalState("ollamaBaseUrl") as Promise<string | undefined>,
 			this.getGlobalState("anthropicBaseUrl") as Promise<string | undefined>,
+			this.getSecret("geminiApiKey") as Promise<string | undefined>,
 			this.getGlobalState("lastShownAnnouncementId") as Promise<string | undefined>,
 			this.getGlobalState("customInstructions") as Promise<string | undefined>,
 			this.getGlobalState("alwaysAllowReadOnly") as Promise<boolean | undefined>,
@@ -726,6 +737,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 				ollamaModelId,
 				ollamaBaseUrl,
 				anthropicBaseUrl,
+				geminiApiKey,
 			},
 			lastShownAnnouncementId,
 			customInstructions,
@@ -804,6 +816,7 @@ export class ClaudeDevProvider implements vscode.WebviewViewProvider {
 			"awsSecretKey",
 			"awsSessionToken",
 			"openAiApiKey",
+			"geminiApiKey",
 		]
 		for (const key of secretKeys) {
 			await this.storeSecret(key, undefined)

+ 25 - 1
src/shared/api.ts

@@ -1,4 +1,4 @@
-export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai" | "ollama"
+export type ApiProvider = "anthropic" | "openrouter" | "bedrock" | "vertex" | "openai" | "ollama" | "gemini"
 
 export interface ApiHandlerOptions {
 	apiModelId?: string
@@ -16,6 +16,7 @@ export interface ApiHandlerOptions {
 	openAiModelId?: string
 	ollamaModelId?: string
 	ollamaBaseUrl?: string
+	geminiApiKey?: string
 }
 
 export type ApiConfiguration = ApiHandlerOptions & {
@@ -305,3 +306,26 @@ export const openAiModelInfoSaneDefaults: ModelInfo = {
 	inputPrice: 0,
 	outputPrice: 0,
 }
+
+// Gemini
+// https://ai.google.dev/gemini-api/docs/models/gemini
+export type GeminiModelId = keyof typeof geminiModels
+export const geminiDefaultModelId: GeminiModelId = "gemini-1.5-flash-latest"
+export const geminiModels = {
+	"gemini-1.5-flash-latest": {
+		maxTokens: 8192,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"gemini-1.5-pro-latest": {
+		maxTokens: 8192,
+		contextWindow: 2_097_152,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+} as const satisfies Record<string, ModelInfo>

+ 137 - 0
src/utils/gemini-format.ts

@@ -0,0 +1,137 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { Content, EnhancedGenerateContentResponse, FunctionDeclaration, Part, SchemaType } from "@google/generative-ai"
+
+export function convertAnthropicContentToGemini(
+	content:
+		| string
+		| Array<
+				| Anthropic.Messages.TextBlockParam
+				| Anthropic.Messages.ImageBlockParam
+				| Anthropic.Messages.ToolUseBlockParam
+				| Anthropic.Messages.ToolResultBlockParam
+		  >
+): Part[] {
+	if (typeof content === "string") {
+		return [{ text: content }]
+	}
+	return content.map((block) => {
+		switch (block.type) {
+			case "text":
+				return { text: block.text }
+			case "image":
+				if (block.source.type !== "base64") {
+					throw new Error("Unsupported image source type")
+				}
+				return {
+					inlineData: {
+						data: block.source.data,
+						mimeType: block.source.media_type,
+					},
+				}
+			case "tool_use":
+				return {
+					functionCall: {
+						name: block.name,
+						args: block.input,
+					},
+				} as Part
+			case "tool_result":
+				return {
+					functionResponse: {
+						name: block.tool_use_id,
+						response: {
+							content: block.content,
+						},
+					},
+				}
+			default:
+				throw new Error(`Unsupported content block type: ${(block as any).type}`)
+		}
+	})
+}
+
+export function convertAnthropicMessageToGemini(message: Anthropic.Messages.MessageParam): Content {
+	return {
+		role: message.role === "assistant" ? "model" : message.role,
+		parts: convertAnthropicContentToGemini(message.content),
+	}
+}
+
+export function convertAnthropicToolToGemini(tool: Anthropic.Messages.Tool): FunctionDeclaration {
+	return {
+		name: tool.name,
+		description: tool.description || "",
+		parameters: {
+			type: SchemaType.OBJECT,
+			properties: Object.fromEntries(
+				Object.entries(tool.input_schema.properties || {}).map(([key, value]) => [
+					key,
+					{
+						type: (value as any).type.toUpperCase(),
+						description: (value as any).description || "",
+					},
+				])
+			),
+			required: (tool.input_schema.required as string[]) || [],
+		},
+	}
+}
+
+export function convertGeminiResponseToAnthropic(
+	response: EnhancedGenerateContentResponse
+): Anthropic.Messages.Message {
+	const content: Anthropic.Messages.ContentBlock[] = []
+
+	// Add the main text response
+	const text = response.text()
+	if (text) {
+		content.push({ type: "text", text })
+	}
+
+	// Add function calls as tool_use blocks
+	const functionCalls = response.functionCalls()
+	if (functionCalls) {
+		functionCalls.forEach((call, index) => {
+			content.push({
+				type: "tool_use",
+				id: `tool_${index}`,
+				name: call.name,
+				input: call.args,
+			})
+		})
+	}
+
+	// Determine stop reason
+	let stop_reason: Anthropic.Messages.Message["stop_reason"] = null
+	const finishReason = response.candidates?.[0]?.finishReason
+	if (finishReason) {
+		switch (finishReason) {
+			case "STOP":
+				stop_reason = "end_turn"
+				break
+			case "MAX_TOKENS":
+				stop_reason = "max_tokens"
+				break
+			case "SAFETY":
+			case "RECITATION":
+			case "OTHER":
+				stop_reason = "stop_sequence"
+				break
+			// Add more cases if needed
+		}
+	}
+
+	return {
+		id: `msg_${Date.now()}`, // Generate a unique ID
+		type: "message",
+		role: "assistant",
+		content,
+		model: "",
+		stop_reason,
+		stop_sequence: null, // Gemini doesn't provide this information
+		usage: {
+			input_tokens: response.usageMetadata?.promptTokenCount ?? 0,
+			output_tokens: response.usageMetadata?.candidatesTokenCount ?? 0,
+		},
+	}
+}

+ 90 - 21
webview-ui/src/components/ApiOptions.tsx

@@ -1,13 +1,14 @@
 import {
+	VSCodeCheckbox,
 	VSCodeDropdown,
 	VSCodeLink,
 	VSCodeOption,
 	VSCodeRadio,
 	VSCodeRadioGroup,
 	VSCodeTextField,
-	VSCodeCheckbox,
 } from "@vscode/webview-ui-toolkit/react"
 import { memo, useCallback, useEffect, useMemo, useState } from "react"
+import { useEvent, useInterval } from "react-use"
 import {
 	ApiConfiguration,
 	ModelInfo,
@@ -15,17 +16,18 @@ import {
 	anthropicModels,
 	bedrockDefaultModelId,
 	bedrockModels,
+	geminiDefaultModelId,
+	geminiModels,
 	openAiModelInfoSaneDefaults,
 	openRouterDefaultModelId,
 	openRouterModels,
 	vertexDefaultModelId,
 	vertexModels,
 } from "../../../src/shared/api"
-import { useExtensionState } from "../context/ExtensionStateContext"
-import VSCodeButtonLink from "./VSCodeButtonLink"
 import { ExtensionMessage } from "../../../src/shared/ExtensionMessage"
-import { useEvent, useInterval } from "react-use"
+import { useExtensionState } from "../context/ExtensionStateContext"
 import { vscode } from "../utils/vscode"
+import VSCodeButtonLink from "./VSCodeButtonLink"
 
 interface ApiOptionsProps {
 	showModelOptions: boolean
@@ -113,6 +115,7 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
 					<VSCodeOption value="openrouter">OpenRouter</VSCodeOption>
 					<VSCodeOption value="bedrock">AWS Bedrock</VSCodeOption>
 					<VSCodeOption value="vertex">GCP Vertex AI</VSCodeOption>
+					<VSCodeOption value="gemini">Google Gemini</VSCodeOption>
 					<VSCodeOption value="openai">OpenAI Compatible</VSCodeOption>
 					<VSCodeOption value="ollama">Ollama</VSCodeOption>
 				</VSCodeDropdown>
@@ -161,7 +164,9 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
 						}}>
 						This key is stored locally and only used to make API requests from this extension.
 						{!apiConfiguration?.apiKey && (
-							<VSCodeLink href="https://console.anthropic.com/" style={{ display: "inline" }}>
+							<VSCodeLink
+								href="https://console.anthropic.com/"
+								style={{ display: "inline", fontSize: "inherit" }}>
 								You can get an Anthropic API key by signing up here.
 							</VSCodeLink>
 						)}
@@ -311,20 +316,48 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
 						To use Google Cloud Vertex AI, you need to
 						<VSCodeLink
 							href="https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#before_you_begin"
-							style={{ display: "inline" }}>
+							style={{ display: "inline", fontSize: "inherit" }}>
 							{
 								"1) create a Google Cloud account › enable the Vertex AI API › enable the desired Claude models,"
 							}
 						</VSCodeLink>{" "}
 						<VSCodeLink
 							href="https://cloud.google.com/docs/authentication/provide-credentials-adc#google-idp"
-							style={{ display: "inline" }}>
+							style={{ display: "inline", fontSize: "inherit" }}>
 							{"2) install the Google Cloud CLI › configure Application Default Credentials."}
 						</VSCodeLink>
 					</p>
 				</div>
 			)}
 
+			{selectedProvider === "gemini" && (
+				<div>
+					<VSCodeTextField
+						value={apiConfiguration?.geminiApiKey || ""}
+						style={{ width: "100%" }}
+						type="password"
+						onInput={handleInputChange("geminiApiKey")}
+						placeholder="Enter API Key...">
+						<span style={{ fontWeight: 500 }}>Gemini API Key</span>
+					</VSCodeTextField>
+					<p
+						style={{
+							fontSize: "12px",
+							marginTop: 3,
+							color: "var(--vscode-descriptionForeground)",
+						}}>
+						This key is stored locally and only used to make API requests from this extension.
+						{!apiConfiguration?.geminiApiKey && (
+							<VSCodeLink
+								href="https://ai.google.dev/"
+								style={{ display: "inline", fontSize: "inherit" }}>
+								You can get a Gemini API key by signing up here.
+							</VSCodeLink>
+						)}
+					</p>
+				</div>
+			)}
+
 			{selectedProvider === "openai" && (
 				<div>
 					<VSCodeTextField
@@ -418,11 +451,13 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
 						started, see their
 						<VSCodeLink
 							href="https://github.com/ollama/ollama/blob/main/README.md"
-							style={{ display: "inline" }}>
+							style={{ display: "inline", fontSize: "inherit" }}>
 							quickstart guide.
 						</VSCodeLink>{" "}
 						You can use any model that supports{" "}
-						<VSCodeLink href="https://ollama.com/search?c=tools" style={{ display: "inline" }}>
+						<VSCodeLink
+							href="https://ollama.com/search?c=tools"
+							style={{ display: "inline", fontSize: "inherit" }}>
 							tool use.
 						</VSCodeLink>
 						<span style={{ color: "var(--vscode-errorForeground)" }}>
@@ -454,9 +489,10 @@ const ApiOptions = ({ showModelOptions, apiErrorMessage }: ApiOptionsProps) => {
 						{selectedProvider === "openrouter" && createDropdown(openRouterModels)}
 						{selectedProvider === "bedrock" && createDropdown(bedrockModels)}
 						{selectedProvider === "vertex" && createDropdown(vertexModels)}
+						{selectedProvider === "gemini" && createDropdown(geminiModels)}
 					</div>
 
-					<ModelInfoView modelInfo={selectedModelInfo} />
+					<ModelInfoView selectedModelId={selectedModelId} modelInfo={selectedModelInfo} />
 				</>
 			)}
 		</div>
@@ -476,7 +512,8 @@ export const formatPrice = (price: number) => {
 	}).format(price)
 }
 
-const ModelInfoView = ({ modelInfo }: { modelInfo: ModelInfo }) => {
+const ModelInfoView = ({ selectedModelId, modelInfo }: { selectedModelId: string; modelInfo: ModelInfo }) => {
+	const isGemini = Object.keys(geminiModels).includes(selectedModelId)
 	return (
 		<p style={{ fontSize: "12px", marginTop: "2px", color: "var(--vscode-descriptionForeground)" }}>
 			<ModelInfoSupportsItem
@@ -485,27 +522,57 @@ const ModelInfoView = ({ modelInfo }: { modelInfo: ModelInfo }) => {
 				doesNotSupportLabel="Does not support images"
 			/>
 			<br />
-			<ModelInfoSupportsItem
-				isSupported={modelInfo.supportsPromptCache}
-				supportsLabel="Supports prompt caching"
-				doesNotSupportLabel="Does not support prompt caching"
-			/>
-			<br />
+			{!isGemini && (
+				<>
+					<ModelInfoSupportsItem
+						isSupported={modelInfo.supportsPromptCache}
+						supportsLabel="Supports prompt caching"
+						doesNotSupportLabel="Does not support prompt caching"
+					/>
+					<br />
+				</>
+			)}
 			<span style={{ fontWeight: 500 }}>Max output:</span> {modelInfo?.maxTokens?.toLocaleString()} tokens
 			<br />
-			<span style={{ fontWeight: 500 }}>Input price:</span> {formatPrice(modelInfo.inputPrice)}/million tokens
-			{modelInfo.supportsPromptCache && modelInfo.cacheWritesPrice && modelInfo.cacheReadsPrice && (
+			{modelInfo.inputPrice > 0 && (
 				<>
+					<span style={{ fontWeight: 500 }}>Input price:</span> {formatPrice(modelInfo.inputPrice)}/million
+					tokens
 					<br />
+				</>
+			)}
+			{modelInfo.supportsPromptCache && modelInfo.cacheWritesPrice && modelInfo.cacheReadsPrice && (
+				<>
 					<span style={{ fontWeight: 500 }}>Cache writes price:</span>{" "}
 					{formatPrice(modelInfo.cacheWritesPrice || 0)}/million tokens
 					<br />
 					<span style={{ fontWeight: 500 }}>Cache reads price:</span>{" "}
 					{formatPrice(modelInfo.cacheReadsPrice || 0)}/million tokens
+					<br />
+				</>
+			)}
+			{modelInfo.outputPrice > 0 && (
+				<>
+					<span style={{ fontWeight: 500 }}>Output price:</span> {formatPrice(modelInfo.outputPrice)}/million
+					tokens
+				</>
+			)}
+			{isGemini && (
+				<>
+					<span
+						style={{
+							fontStyle: "italic",
+						}}>
+						* Free up to {selectedModelId === geminiDefaultModelId ? "15" : "2"} requests per minute. After
+						that, billing depends on prompt size.{" "}
+						<VSCodeLink
+							href="https://ai.google.dev/pricing"
+							style={{ display: "inline", fontSize: "inherit" }}>
+							For more info, see pricing details.
+						</VSCodeLink>
+					</span>
 				</>
 			)}
-			<br />
-			<span style={{ fontWeight: 500 }}>Output price:</span> {formatPrice(modelInfo.outputPrice)}/million tokens
 		</p>
 	)
 }
@@ -563,6 +630,8 @@ export function normalizeApiConfiguration(apiConfiguration?: ApiConfiguration) {
 			return getProviderData(bedrockModels, bedrockDefaultModelId)
 		case "vertex":
 			return getProviderData(vertexModels, vertexDefaultModelId)
+		case "gemini":
+			return getProviderData(geminiModels, geminiDefaultModelId)
 		case "openai":
 			return {
 				selectedProvider: provider,

+ 39 - 36
webview-ui/src/components/TaskHeader.tsx

@@ -1,5 +1,5 @@
 import { VSCodeButton } from "@vscode/webview-ui-toolkit/react"
-import React, { memo, useEffect, useRef, useState } from "react"
+import React, { memo, useEffect, useMemo, useRef, useState } from "react"
 import { useWindowSize } from "react-use"
 import { ClaudeMessage } from "../../../src/shared/ExtensionMessage"
 import { useExtensionState } from "../context/ExtensionStateContext"
@@ -90,6 +90,14 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
 		}
 	}, [task.text, windowWidth])
 
+	const isCostAvailable = useMemo(() => {
+		return (
+			apiConfiguration?.apiProvider !== "openai" &&
+			apiConfiguration?.apiProvider !== "ollama" &&
+			apiConfiguration?.apiProvider !== "gemini"
+		)
+	}, [apiConfiguration?.apiProvider])
+
 	return (
 		<div style={{ padding: "10px 13px 10px 13px" }}>
 			<div
@@ -140,25 +148,22 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
 							{!isTaskExpanded && <span style={{ marginLeft: 4 }}>{task.text}</span>}
 						</div>
 					</div>
-					{!isTaskExpanded &&
-						apiConfiguration?.apiProvider !== "openai" &&
-						apiConfiguration?.apiProvider !== "ollama" && (
-							<div
-								style={{
-									marginLeft: 10,
-									backgroundColor:
-										"color-mix(in srgb, var(--vscode-badge-foreground) 70%, transparent)",
-									color: "var(--vscode-badge-background)",
-									padding: "2px 4px",
-									borderRadius: "500px",
-									fontSize: "11px",
-									fontWeight: 500,
-									display: "inline-block",
-									flexShrink: 0,
-								}}>
-								${totalCost?.toFixed(4)}
-							</div>
-						)}
+					{!isTaskExpanded && isCostAvailable && (
+						<div
+							style={{
+								marginLeft: 10,
+								backgroundColor: "color-mix(in srgb, var(--vscode-badge-foreground) 70%, transparent)",
+								color: "var(--vscode-badge-background)",
+								padding: "2px 4px",
+								borderRadius: "500px",
+								fontSize: "11px",
+								fontWeight: 500,
+								display: "inline-block",
+								flexShrink: 0,
+							}}>
+							${totalCost?.toFixed(4)}
+						</div>
+					)}
 					<VSCodeButton appearance="icon" onClick={onClose} style={{ marginLeft: 6, flexShrink: 0 }}>
 						<span className="codicon codicon-close"></span>
 					</VSCodeButton>
@@ -257,8 +262,7 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
 										{tokensOut?.toLocaleString()}
 									</span>
 								</div>
-								{(apiConfiguration?.apiProvider === "openai" ||
-									apiConfiguration?.apiProvider === "ollama") && <ExportButton />}
+								{!isCostAvailable && <ExportButton />}
 							</div>
 
 							{(doesModelSupportPromptCache || cacheReads !== undefined || cacheWrites !== undefined) && (
@@ -280,21 +284,20 @@ const TaskHeader: React.FC<TaskHeaderProps> = ({
 									</span>
 								</div>
 							)}
-							{apiConfiguration?.apiProvider !== "openai" &&
-								apiConfiguration?.apiProvider !== "ollama" && (
-									<div
-										style={{
-											display: "flex",
-											justifyContent: "space-between",
-											alignItems: "center",
-										}}>
-										<div style={{ display: "flex", alignItems: "center", gap: "4px" }}>
-											<span style={{ fontWeight: "bold" }}>API Cost:</span>
-											<span>${totalCost?.toFixed(4)}</span>
-										</div>
-										<ExportButton />
+							{isCostAvailable && (
+								<div
+									style={{
+										display: "flex",
+										justifyContent: "space-between",
+										alignItems: "center",
+									}}>
+									<div style={{ display: "flex", alignItems: "center", gap: "4px" }}>
+										<span style={{ fontWeight: "bold" }}>API Cost:</span>
+										<span>${totalCost?.toFixed(4)}</span>
 									</div>
-								)}
+									<ExportButton />
+								</div>
+							)}
 						</div>
 					</>
 				)}

+ 1 - 0
webview-ui/src/context/ExtensionStateContext.tsx

@@ -40,6 +40,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode
 						config.vertexProjectId,
 						config.openAiApiKey,
 						config.ollamaModelId,
+						config.geminiApiKey,
 				  ].some((key) => key !== undefined)
 				: false
 			setShowWelcome(!hasKey)

+ 5 - 0
webview-ui/src/utils/validate.ts

@@ -23,6 +23,11 @@ export function validateApiConfiguration(apiConfiguration?: ApiConfiguration): s
 					return "You must provide a valid Google Cloud Project ID and Region."
 				}
 				break
+			case "gemini":
+				if (!apiConfiguration.geminiApiKey) {
+					return "You must provide a valid API key or choose a different provider."
+				}
+				break
 			case "openai":
 				if (
 					!apiConfiguration.openAiBaseUrl ||