
Merge pull request #1282 from ashktn/feat/vertex-gemini

feat: Add support for Gemini models on Vertex AI
Matt Rubens, 10 months ago
Commit 8c8210ef50

+ 5 - 0
.changeset/dry-suits-shake.md

@@ -0,0 +1,5 @@
+---
+"roo-cline": minor
+---
+
+Add Gemini models on Vertex AI

+ 13 - 0
package-lock.json

@@ -12,6 +12,7 @@
 				"@anthropic-ai/sdk": "^0.37.0",
 				"@anthropic-ai/vertex-sdk": "^0.7.0",
 				"@aws-sdk/client-bedrock-runtime": "^3.706.0",
+				"@google-cloud/vertexai": "^1.9.3",
 				"@google/generative-ai": "^0.18.0",
 				"@mistralai/mistralai": "^1.3.6",
 				"@modelcontextprotocol/sdk": "^1.0.1",
@@ -3238,6 +3239,18 @@
 				"node": "^12.22.0 || ^14.17.0 || >=16.0.0"
 			}
 		},
+		"node_modules/@google-cloud/vertexai": {
+			"version": "1.9.3",
+			"resolved": "https://registry.npmjs.org/@google-cloud/vertexai/-/vertexai-1.9.3.tgz",
+			"integrity": "sha512-35o5tIEMLW3JeFJOaaMNR2e5sq+6rpnhrF97PuAxeOm0GlqVTESKhkGj7a5B5mmJSSSU3hUfIhcQCRRsw4Ipzg==",
+			"license": "Apache-2.0",
+			"dependencies": {
+				"google-auth-library": "^9.1.0"
+			},
+			"engines": {
+				"node": ">=18.0.0"
+			}
+		},
 		"node_modules/@google/generative-ai": {
 			"version": "0.18.0",
 			"resolved": "https://registry.npmjs.org/@google/generative-ai/-/generative-ai-0.18.0.tgz",

+ 1 - 0
package.json

@@ -313,6 +313,7 @@
 		"@anthropic-ai/vertex-sdk": "^0.7.0",
 		"@aws-sdk/client-bedrock-runtime": "^3.706.0",
 		"@google/generative-ai": "^0.18.0",
+		"@google-cloud/vertexai": "^1.9.3",
 		"@mistralai/mistralai": "^1.3.6",
 		"@modelcontextprotocol/sdk": "^1.0.1",
 		"@types/clone-deep": "^4.0.4",

+ 296 - 38
src/api/providers/__tests__/vertex.test.ts

@@ -6,6 +6,7 @@ import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"
 
 import { VertexHandler } from "../vertex"
 import { ApiStreamChunk } from "../../transform/stream"
+import { VertexAI } from "@google-cloud/vertexai"
 
 // Mock Vertex SDK
 jest.mock("@anthropic-ai/vertex-sdk", () => ({
@@ -49,24 +50,100 @@ jest.mock("@anthropic-ai/vertex-sdk", () => ({
 	})),
 }))
 
-describe("VertexHandler", () => {
-	let handler: VertexHandler
+// Mock Vertex Gemini SDK
+jest.mock("@google-cloud/vertexai", () => {
+	const mockGenerateContentStream = jest.fn().mockImplementation(() => {
+		return {
+			stream: {
+				async *[Symbol.asyncIterator]() {
+					yield {
+						candidates: [
+							{
+								content: {
+									parts: [{ text: "Test Gemini response" }],
+								},
+							},
+						],
+					}
+				},
+			},
+			response: {
+				usageMetadata: {
+					promptTokenCount: 5,
+					candidatesTokenCount: 10,
+				},
+			},
+		}
+	})
 
-	beforeEach(() => {
-		handler = new VertexHandler({
-			apiModelId: "claude-3-5-sonnet-v2@20241022",
-			vertexProjectId: "test-project",
-			vertexRegion: "us-central1",
-		})
+	const mockGenerateContent = jest.fn().mockResolvedValue({
+		response: {
+			candidates: [
+				{
+					content: {
+						parts: [{ text: "Test Gemini response" }],
+					},
+				},
+			],
+		},
+	})
+
+	const mockGenerativeModel = jest.fn().mockImplementation(() => {
+		return {
+			generateContentStream: mockGenerateContentStream,
+			generateContent: mockGenerateContent,
+		}
 	})
 
+	return {
+		VertexAI: jest.fn().mockImplementation(() => {
+			return {
+				getGenerativeModel: mockGenerativeModel,
+			}
+		}),
+		GenerativeModel: mockGenerativeModel,
+	}
+})
+
+describe("VertexHandler", () => {
+	let handler: VertexHandler
+
 	describe("constructor", () => {
-		it("should initialize with provided config", () => {
+		it("should initialize with provided config for Claude", () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			expect(AnthropicVertex).toHaveBeenCalledWith({
 				projectId: "test-project",
 				region: "us-central1",
 			})
 		})
+
+		it("should initialize with provided config for Gemini", () => {
+			handler = new VertexHandler({
+				apiModelId: "gemini-1.5-pro-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			expect(VertexAI).toHaveBeenCalledWith({
+				project: "test-project",
+				location: "us-central1",
+			})
+		})
+
+		it("should throw error for invalid model", () => {
+			expect(() => {
+				new VertexHandler({
+					apiModelId: "invalid-model",
+					vertexProjectId: "test-project",
+					vertexRegion: "us-central1",
+				})
+			}).toThrow("Unknown model ID: invalid-model")
+		})
 	})
 
 	describe("createMessage", () => {
@@ -83,7 +160,13 @@ describe("VertexHandler", () => {
 
 		const systemPrompt = "You are a helpful assistant"
 
-		it("should handle streaming responses correctly", async () => {
+		it("should handle streaming responses correctly for Claude", async () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const mockStream = [
 				{
 					type: "message_start",
@@ -127,7 +210,7 @@ describe("VertexHandler", () => {
 			}
 
 			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
-			;(handler["client"].messages as any).create = mockCreate
+			;(handler["anthropicClient"].messages as any).create = mockCreate
 
 			const stream = handler.createMessage(systemPrompt, mockMessages)
 			const chunks: ApiStreamChunk[] = []
@@ -187,7 +270,58 @@ describe("VertexHandler", () => {
 			})
 		})
 
-		it("should handle multiple content blocks with line breaks", async () => {
+		it("should handle streaming responses correctly for Gemini", async () => {
+			const mockGemini = require("@google-cloud/vertexai")
+			const mockGenerateContentStream = mockGemini.VertexAI().getGenerativeModel().generateContentStream
+			handler = new VertexHandler({
+				apiModelId: "gemini-1.5-pro-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const stream = handler.createMessage(systemPrompt, mockMessages)
+			const chunks: ApiStreamChunk[] = []
+
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(chunks.length).toBe(2)
+			expect(chunks[0]).toEqual({
+				type: "text",
+				text: "Test Gemini response",
+			})
+			expect(chunks[1]).toEqual({
+				type: "usage",
+				inputTokens: 5,
+				outputTokens: 10,
+			})
+
+			expect(mockGenerateContentStream).toHaveBeenCalledWith({
+				contents: [
+					{
+						role: "user",
+						parts: [{ text: "Hello" }],
+					},
+					{
+						role: "model",
+						parts: [{ text: "Hi there!" }],
+					},
+				],
+				generationConfig: {
+					maxOutputTokens: 16384,
+					temperature: 0,
+				},
+			})
+		})
+
+		it("should handle multiple content blocks with line breaks for Claude", async () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const mockStream = [
 				{
 					type: "content_block_start",
@@ -216,7 +350,7 @@ describe("VertexHandler", () => {
 			}
 
 			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
-			;(handler["client"].messages as any).create = mockCreate
+			;(handler["anthropicClient"].messages as any).create = mockCreate
 
 			const stream = handler.createMessage(systemPrompt, mockMessages)
 			const chunks: ApiStreamChunk[] = []
@@ -240,10 +374,16 @@ describe("VertexHandler", () => {
 			})
 		})
 
-		it("should handle API errors", async () => {
+		it("should handle API errors for Claude", async () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const mockError = new Error("Vertex API error")
 			const mockCreate = jest.fn().mockRejectedValue(mockError)
-			;(handler["client"].messages as any).create = mockCreate
+			;(handler["anthropicClient"].messages as any).create = mockCreate
 
 			const stream = handler.createMessage(systemPrompt, mockMessages)
 
@@ -254,7 +394,13 @@ describe("VertexHandler", () => {
 			}).rejects.toThrow("Vertex API error")
 		})
 
-		it("should handle prompt caching for supported models", async () => {
+		it("should handle prompt caching for supported models for Claude", async () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const mockStream = [
 				{
 					type: "message_start",
@@ -299,7 +445,7 @@ describe("VertexHandler", () => {
 			}
 
 			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
-			;(handler["client"].messages as any).create = mockCreate
+			;(handler["anthropicClient"].messages as any).create = mockCreate
 
 			const stream = handler.createMessage(systemPrompt, [
 				{
@@ -383,7 +529,13 @@ describe("VertexHandler", () => {
 			)
 		})
 
-		it("should handle cache-related usage metrics", async () => {
+		it("should handle cache-related usage metrics for Claude", async () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const mockStream = [
 				{
 					type: "message_start",
@@ -415,7 +567,7 @@ describe("VertexHandler", () => {
 			}
 
 			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
-			;(handler["client"].messages as any).create = mockCreate
+			;(handler["anthropicClient"].messages as any).create = mockCreate
 
 			const stream = handler.createMessage(systemPrompt, mockMessages)
 			const chunks: ApiStreamChunk[] = []
@@ -442,7 +594,13 @@ describe("VertexHandler", () => {
 
 		const systemPrompt = "You are a helpful assistant"
 
-		it("should handle thinking content blocks and deltas", async () => {
+		it("should handle thinking content blocks and deltas for Claude", async () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const mockStream = [
 				{
 					type: "message_start",
@@ -488,7 +646,7 @@ describe("VertexHandler", () => {
 			}
 
 			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
-			;(handler["client"].messages as any).create = mockCreate
+			;(handler["anthropicClient"].messages as any).create = mockCreate
 
 			const stream = handler.createMessage(systemPrompt, mockMessages)
 			const chunks: ApiStreamChunk[] = []
@@ -510,7 +668,13 @@ describe("VertexHandler", () => {
 			expect(textChunks[1].text).toBe("Here's my answer:")
 		})
 
-		it("should handle multiple thinking blocks with line breaks", async () => {
+		it("should handle multiple thinking blocks with line breaks for Claude", async () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const mockStream = [
 				{
 					type: "content_block_start",
@@ -539,7 +703,7 @@ describe("VertexHandler", () => {
 			}
 
 			const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
-			;(handler["client"].messages as any).create = mockCreate
+			;(handler["anthropicClient"].messages as any).create = mockCreate
 
 			const stream = handler.createMessage(systemPrompt, mockMessages)
 			const chunks: ApiStreamChunk[] = []
@@ -565,10 +729,16 @@ describe("VertexHandler", () => {
 	})
 
 	describe("completePrompt", () => {
-		it("should complete prompt successfully", async () => {
+		it("should complete prompt successfully for Claude", async () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const result = await handler.completePrompt("Test prompt")
 			expect(result).toBe("Test response")
-			expect(handler["client"].messages.create).toHaveBeenCalledWith({
+			expect(handler["anthropicClient"].messages.create).toHaveBeenCalledWith({
 				model: "claude-3-5-sonnet-v2@20241022",
 				max_tokens: 8192,
 				temperature: 0,
@@ -583,31 +753,109 @@ describe("VertexHandler", () => {
 			})
 		})
 
-		it("should handle API errors", async () => {
+		it("should complete prompt successfully for Gemini", async () => {
+			const mockGemini = require("@google-cloud/vertexai")
+			const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent
+
+			handler = new VertexHandler({
+				apiModelId: "gemini-1.5-pro-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("Test Gemini response")
+			expect(mockGenerateContent).toHaveBeenCalled()
+			expect(mockGenerateContent).toHaveBeenCalledWith({
+				contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
+				generationConfig: {
+					temperature: 0,
+				},
+			})
+		})
+
+		it("should handle API errors for Claude", async () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const mockError = new Error("Vertex API error")
 			const mockCreate = jest.fn().mockRejectedValue(mockError)
-			;(handler["client"].messages as any).create = mockCreate
+			;(handler["anthropicClient"].messages as any).create = mockCreate
+
+			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
+				"Vertex completion error: Vertex API error",
+			)
+		})
+
+		it("should handle API errors for Gemini", async () => {
+			const mockGemini = require("@google-cloud/vertexai")
+			const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent
+			mockGenerateContent.mockRejectedValue(new Error("Vertex API error"))
+			handler = new VertexHandler({
+				apiModelId: "gemini-1.5-pro-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
 
 			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
 				"Vertex completion error: Vertex API error",
 			)
 		})
 
-		it("should handle non-text content", async () => {
+		it("should handle non-text content for Claude", async () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const mockCreate = jest.fn().mockResolvedValue({
 				content: [{ type: "image" }],
 			})
-			;(handler["client"].messages as any).create = mockCreate
+			;(handler["anthropicClient"].messages as any).create = mockCreate
 
 			const result = await handler.completePrompt("Test prompt")
 			expect(result).toBe("")
 		})
 
-		it("should handle empty response", async () => {
+		it("should handle empty response for Claude", async () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const mockCreate = jest.fn().mockResolvedValue({
 				content: [{ type: "text", text: "" }],
 			})
-			;(handler["client"].messages as any).create = mockCreate
+			;(handler["anthropicClient"].messages as any).create = mockCreate
+
+			const result = await handler.completePrompt("Test prompt")
+			expect(result).toBe("")
+		})
+
+		it("should handle empty response for Gemini", async () => {
+			const mockGemini = require("@google-cloud/vertexai")
+			const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent
+			mockGenerateContent.mockResolvedValue({
+				response: {
+					candidates: [
+						{
+							content: {
+								parts: [{ text: "" }],
+							},
+						},
+					],
+				},
+			})
+			handler = new VertexHandler({
+				apiModelId: "gemini-1.5-pro-001",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
 
 			const result = await handler.completePrompt("Test prompt")
 			expect(result).toBe("")
@@ -615,7 +863,13 @@ describe("VertexHandler", () => {
 	})
 
 	describe("getModel", () => {
-		it("should return correct model info", () => {
+		it("should return correct model info for Claude", () => {
+			handler = new VertexHandler({
+				apiModelId: "claude-3-5-sonnet-v2@20241022",
+				vertexProjectId: "test-project",
+				vertexRegion: "us-central1",
+			})
+
 			const modelInfo = handler.getModel()
 			expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022")
 			expect(modelInfo.info).toBeDefined()
@@ -623,14 +877,18 @@ describe("VertexHandler", () => {
 			expect(modelInfo.info.contextWindow).toBe(200_000)
 		})
 
-		it("should return default model if invalid model specified", () => {
-			const invalidHandler = new VertexHandler({
-				apiModelId: "invalid-model",
+		it("should return correct model info for Gemini", () => {
+			handler = new VertexHandler({
+				apiModelId: "gemini-2.0-flash-001",
 				vertexProjectId: "test-project",
 				vertexRegion: "us-central1",
 			})
-			const modelInfo = invalidHandler.getModel()
-			expect(modelInfo.id).toBe("claude-3-7-sonnet@20250219") // Default model
+
+			const modelInfo = handler.getModel()
+			expect(modelInfo.id).toBe("gemini-2.0-flash-001")
+			expect(modelInfo.info).toBeDefined()
+			expect(modelInfo.info.maxTokens).toBe(8192)
+			expect(modelInfo.info.contextWindow).toBe(1048576)
 		})
 	})
 
@@ -724,7 +982,7 @@ describe("VertexHandler", () => {
 					},
 				}
 			})
-			;(thinkingHandler["client"].messages as any).create = mockCreate
+			;(thinkingHandler["anthropicClient"].messages as any).create = mockCreate
 
 			await thinkingHandler
 				.createMessage("You are a helpful assistant", [{ role: "user", content: "Hello" }])

+ 121 - 7
src/api/providers/vertex.ts

@@ -5,6 +5,8 @@ import { ApiHandler, SingleCompletionHandler } from "../"
 import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"
 import { ApiHandlerOptions, ModelInfo, vertexDefaultModelId, VertexModelId, vertexModels } from "../../shared/api"
 import { ApiStream } from "../transform/stream"
+import { VertexAI } from "@google-cloud/vertexai"
+import { convertAnthropicMessageToVertexGemini } from "../transform/vertex-gemini-format"
 
 // Types for Vertex SDK
 
@@ -91,19 +93,37 @@ interface VertexMessageStreamEvent {
 				thinking: string
 		  }
 }
-
 // https://docs.anthropic.com/en/api/claude-on-vertex-ai
 export class VertexHandler implements ApiHandler, SingleCompletionHandler {
+	MODEL_CLAUDE = "claude"
+	MODEL_GEMINI = "gemini"
+
 	private options: ApiHandlerOptions
-	private client: AnthropicVertex
+	private anthropicClient: AnthropicVertex
+	private geminiClient: VertexAI
+	private modelType: string
 
 	constructor(options: ApiHandlerOptions) {
 		this.options = options
-		this.client = new AnthropicVertex({
+
+		if (this.options.apiModelId?.startsWith(this.MODEL_CLAUDE)) {
+			this.modelType = this.MODEL_CLAUDE
+		} else if (this.options.apiModelId?.startsWith(this.MODEL_GEMINI)) {
+			this.modelType = this.MODEL_GEMINI
+		} else {
+			throw new Error(`Unknown model ID: ${this.options.apiModelId}`)
+		}
+
+		this.anthropicClient = new AnthropicVertex({
 			projectId: this.options.vertexProjectId ?? "not-provided",
 			// https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions
 			region: this.options.vertexRegion ?? "us-east5",
 		})
+
+		this.geminiClient = new VertexAI({
+			project: this.options.vertexProjectId ?? "not-provided",
+			location: this.options.vertexRegion ?? "us-east5",
+		})
 	}
 
 	private formatMessageForCache(message: Anthropic.Messages.MessageParam, shouldCache: boolean): VertexMessage {
@@ -154,7 +174,42 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}
 
-	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+	private async *createGeminiMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		const model = this.geminiClient.getGenerativeModel({
+			model: this.getModel().id,
+			systemInstruction: systemPrompt,
+		})
+
+		const result = await model.generateContentStream({
+			contents: messages.map(convertAnthropicMessageToVertexGemini),
+			generationConfig: {
+				maxOutputTokens: this.getModel().info.maxTokens,
+				temperature: this.options.modelTemperature ?? 0,
+			},
+		})
+
+		for await (const chunk of result.stream) {
+			if (chunk.candidates?.[0]?.content?.parts) {
+				for (const part of chunk.candidates[0].content.parts) {
+					if (part.text) {
+						yield {
+							type: "text",
+							text: part.text,
+						}
+					}
+				}
+			}
+		}
+
+		const response = await result.response
+		yield {
+			type: "usage",
+			inputTokens: response.usageMetadata?.promptTokenCount ?? 0,
+			outputTokens: response.usageMetadata?.candidatesTokenCount ?? 0,
+		}
+	}
+
+	private async *createClaudeMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const model = this.getModel()
 		let { id, info, temperature, maxTokens, thinking } = model
 		const useCache = model.info.supportsPromptCache
@@ -192,7 +247,7 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 			stream: true,
 		}
 
-		const stream = (await this.client.messages.create(
+		const stream = (await this.anthropicClient.messages.create(
 			params as Anthropic.Messages.MessageCreateParamsStreaming,
 		)) as unknown as AnthropicStream<VertexMessageStreamEvent>
 
@@ -272,6 +327,22 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 		}
 	}
 
+	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
+		switch (this.modelType) {
+			case this.MODEL_CLAUDE: {
+				yield* this.createClaudeMessage(systemPrompt, messages)
+				break
+			}
+			case this.MODEL_GEMINI: {
+				yield* this.createGeminiMessage(systemPrompt, messages)
+				break
+			}
+			default: {
+				throw new Error(`Invalid model type: ${this.modelType}`)
+			}
+		}
+	}
+
 	getModel(): {
 		id: VertexModelId
 		info: ModelInfo
@@ -316,7 +387,36 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 		return { id, info, temperature, maxTokens, thinking }
 	}
 
-	async completePrompt(prompt: string): Promise<string> {
+	private async completePromptGemini(prompt: string): Promise<string> {
+		try {
+			const model = this.geminiClient.getGenerativeModel({
+				model: this.getModel().id,
+			})
+
+			const result = await model.generateContent({
+				contents: [{ role: "user", parts: [{ text: prompt }] }],
+				generationConfig: {
+					temperature: this.options.modelTemperature ?? 0,
+				},
+			})
+
+			let text = ""
+			result.response.candidates?.forEach((candidate) => {
+				candidate.content.parts.forEach((part) => {
+					text += part.text
+				})
+			})
+
+			return text
+		} catch (error) {
+			if (error instanceof Error) {
+				throw new Error(`Vertex completion error: ${error.message}`)
+			}
+			throw error
+		}
+	}
+
+	private async completePromptClaude(prompt: string): Promise<string> {
 		try {
 			let { id, info, temperature, maxTokens, thinking } = this.getModel()
 			const useCache = info.supportsPromptCache
@@ -344,7 +444,7 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 				stream: false,
 			}
 
-			const response = (await this.client.messages.create(
+			const response = (await this.anthropicClient.messages.create(
 				params as Anthropic.Messages.MessageCreateParamsNonStreaming,
 			)) as unknown as VertexMessageResponse
 
@@ -360,4 +460,18 @@ export class VertexHandler implements ApiHandler, SingleCompletionHandler {
 			throw error
 		}
 	}
+
+	async completePrompt(prompt: string): Promise<string> {
+		switch (this.modelType) {
+			case this.MODEL_CLAUDE: {
+				return this.completePromptClaude(prompt)
+			}
+			case this.MODEL_GEMINI: {
+				return this.completePromptGemini(prompt)
+			}
+			default: {
+				throw new Error(`Invalid model type: ${this.modelType}`)
+			}
+		}
+	}
 }
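
Taken together, the handler now dispatches on the model-ID prefix. A minimal usage sketch (project and region values are placeholders):

import { VertexHandler } from "./vertex"

async function run(): Promise<void> {
	// "gemini-*" IDs route to the Vertex Gemini client, "claude-*" to
	// AnthropicVertex; any other prefix throws "Unknown model ID".
	const handler = new VertexHandler({
		apiModelId: "gemini-2.0-flash-001",
		vertexProjectId: "my-project", // placeholder
		vertexRegion: "us-central1", // placeholder
	})

	for await (const chunk of handler.createMessage("You are a helpful assistant", [
		{ role: "user", content: "Hello" },
	])) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "usage") console.log(`in ${chunk.inputTokens} / out ${chunk.outputTokens}`)
	}
}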

+ 338 - 0
src/api/transform/__tests__/vertex-gemini-format.test.ts

@@ -0,0 +1,338 @@
+// npx jest src/api/transform/__tests__/vertex-gemini-format.test.ts
+
+import { Anthropic } from "@anthropic-ai/sdk"
+
+import { convertAnthropicMessageToVertexGemini } from "../vertex-gemini-format"
+
+describe("convertAnthropicMessageToVertexGemini", () => {
+	it("should convert a simple text message", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "user",
+			content: "Hello, world!",
+		}
+
+		const result = convertAnthropicMessageToVertexGemini(anthropicMessage)
+
+		expect(result).toEqual({
+			role: "user",
+			parts: [{ text: "Hello, world!" }],
+		})
+	})
+
+	it("should convert assistant role to model role", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "assistant",
+			content: "I'm an assistant",
+		}
+
+		const result = convertAnthropicMessageToVertexGemini(anthropicMessage)
+
+		expect(result).toEqual({
+			role: "model",
+			parts: [{ text: "I'm an assistant" }],
+		})
+	})
+
+	it("should convert a message with text blocks", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "user",
+			content: [
+				{ type: "text", text: "First paragraph" },
+				{ type: "text", text: "Second paragraph" },
+			],
+		}
+
+		const result = convertAnthropicMessageToVertexGemini(anthropicMessage)
+
+		expect(result).toEqual({
+			role: "user",
+			parts: [{ text: "First paragraph" }, { text: "Second paragraph" }],
+		})
+	})
+
+	it("should convert a message with an image", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "user",
+			content: [
+				{ type: "text", text: "Check out this image:" },
+				{
+					type: "image",
+					source: {
+						type: "base64",
+						media_type: "image/jpeg",
+						data: "base64encodeddata",
+					},
+				},
+			],
+		}
+
+		const result = convertAnthropicMessageToVertexGemini(anthropicMessage)
+
+		expect(result).toEqual({
+			role: "user",
+			parts: [
+				{ text: "Check out this image:" },
+				{
+					inlineData: {
+						data: "base64encodeddata",
+						mimeType: "image/jpeg",
+					},
+				},
+			],
+		})
+	})
+
+	it("should throw an error for unsupported image source type", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "user",
+			content: [
+				{
+					type: "image",
+					source: {
+						type: "url", // Not supported
+						url: "https://example.com/image.jpg",
+					} as any,
+				},
+			],
+		}
+
+		expect(() => convertAnthropicMessageToVertexGemini(anthropicMessage)).toThrow("Unsupported image source type")
+	})
+
+	it("should convert a message with tool use", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "assistant",
+			content: [
+				{ type: "text", text: "Let me calculate that for you." },
+				{
+					type: "tool_use",
+					id: "calc-123",
+					name: "calculator",
+					input: { operation: "add", numbers: [2, 3] },
+				},
+			],
+		}
+
+		const result = convertAnthropicMessageToVertexGemini(anthropicMessage)
+
+		expect(result).toEqual({
+			role: "model",
+			parts: [
+				{ text: "Let me calculate that for you." },
+				{
+					functionCall: {
+						name: "calculator",
+						args: { operation: "add", numbers: [2, 3] },
+					},
+				},
+			],
+		})
+	})
+
+	it("should convert a message with tool result as string", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "user",
+			content: [
+				{ type: "text", text: "Here's the result:" },
+				{
+					type: "tool_result",
+					tool_use_id: "calculator-123",
+					content: "The result is 5",
+				},
+			],
+		}
+
+		const result = convertAnthropicMessageToVertexGemini(anthropicMessage)
+
+		expect(result).toEqual({
+			role: "user",
+			parts: [
+				{ text: "Here's the result:" },
+				{
+					functionResponse: {
+						name: "calculator",
+						response: {
+							name: "calculator",
+							content: "The result is 5",
+						},
+					},
+				},
+			],
+		})
+	})
+
+	it("should handle empty tool result content", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "user",
+			content: [
+				{
+					type: "tool_result",
+					tool_use_id: "calculator-123",
+					content: null as any, // Empty content
+				},
+			],
+		}
+
+		const result = convertAnthropicMessageToVertexGemini(anthropicMessage)
+
+		// Should skip the empty tool result
+		expect(result).toEqual({
+			role: "user",
+			parts: [],
+		})
+	})
+
+	it("should convert a message with tool result as array with text only", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "user",
+			content: [
+				{
+					type: "tool_result",
+					tool_use_id: "search-123",
+					content: [
+						{ type: "text", text: "First result" },
+						{ type: "text", text: "Second result" },
+					],
+				},
+			],
+		}
+
+		const result = convertAnthropicMessageToVertexGemini(anthropicMessage)
+
+		expect(result).toEqual({
+			role: "user",
+			parts: [
+				{
+					functionResponse: {
+						name: "search",
+						response: {
+							name: "search",
+							content: "First result\n\nSecond result",
+						},
+					},
+				},
+			],
+		})
+	})
+
+	it("should convert a message with tool result as array with text and images", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "user",
+			content: [
+				{
+					type: "tool_result",
+					tool_use_id: "search-123",
+					content: [
+						{ type: "text", text: "Search results:" },
+						{
+							type: "image",
+							source: {
+								type: "base64",
+								media_type: "image/png",
+								data: "image1data",
+							},
+						},
+						{
+							type: "image",
+							source: {
+								type: "base64",
+								media_type: "image/jpeg",
+								data: "image2data",
+							},
+						},
+					],
+				},
+			],
+		}
+
+		const result = convertAnthropicMessageToVertexGemini(anthropicMessage)
+
+		expect(result).toEqual({
+			role: "user",
+			parts: [
+				{
+					functionResponse: {
+						name: "search",
+						response: {
+							name: "search",
+							content: "Search results:\n\n(See next part for image)",
+						},
+					},
+				},
+				{
+					inlineData: {
+						data: "image1data",
+						mimeType: "image/png",
+					},
+				},
+				{
+					inlineData: {
+						data: "image2data",
+						mimeType: "image/jpeg",
+					},
+				},
+			],
+		})
+	})
+
+	it("should convert a message with tool result containing only images", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "user",
+			content: [
+				{
+					type: "tool_result",
+					tool_use_id: "imagesearch-123",
+					content: [
+						{
+							type: "image",
+							source: {
+								type: "base64",
+								media_type: "image/png",
+								data: "onlyimagedata",
+							},
+						},
+					],
+				},
+			],
+		}
+
+		const result = convertAnthropicMessageToVertexGemini(anthropicMessage)
+
+		expect(result).toEqual({
+			role: "user",
+			parts: [
+				{
+					functionResponse: {
+						name: "imagesearch",
+						response: {
+							name: "imagesearch",
+							content: "\n\n(See next part for image)",
+						},
+					},
+				},
+				{
+					inlineData: {
+						data: "onlyimagedata",
+						mimeType: "image/png",
+					},
+				},
+			],
+		})
+	})
+
+	it("should throw an error for unsupported content block type", () => {
+		const anthropicMessage: Anthropic.Messages.MessageParam = {
+			role: "user",
+			content: [
+				{
+					type: "unknown_type", // Unsupported type
+					data: "some data",
+				} as any,
+			],
+		}
+
+		expect(() => convertAnthropicMessageToVertexGemini(anthropicMessage)).toThrow(
+			"Unsupported content block type: unknown_type",
+		)
+	})
+})

+ 83 - 0
src/api/transform/vertex-gemini-format.ts

@@ -0,0 +1,83 @@
+import { Anthropic } from "@anthropic-ai/sdk"
+import { Content, FunctionCallPart, FunctionResponsePart, InlineDataPart, Part, TextPart } from "@google-cloud/vertexai"
+
+function convertAnthropicContentToVertexGemini(content: Anthropic.Messages.MessageParam["content"]): Part[] {
+	if (typeof content === "string") {
+		return [{ text: content } as TextPart]
+	}
+
+	return content.flatMap((block) => {
+		switch (block.type) {
+			case "text":
+				return { text: block.text } as TextPart
+			case "image":
+				if (block.source.type !== "base64") {
+					throw new Error("Unsupported image source type")
+				}
+				return {
+					inlineData: {
+						data: block.source.data,
+						mimeType: block.source.media_type,
+					},
+				} as InlineDataPart
+			case "tool_use":
+				return {
+					functionCall: {
+						name: block.name,
+						args: block.input,
+					},
+				} as FunctionCallPart
+			case "tool_result":
+				const name = block.tool_use_id.split("-")[0]
+				if (!block.content) {
+					return []
+				}
+				if (typeof block.content === "string") {
+					return {
+						functionResponse: {
+							name,
+							response: {
+								name,
+								content: block.content,
+							},
+						},
+					} as FunctionResponsePart
+				} else {
+					// The only case where tool_result can be an array is when the tool failed and we're providing user feedback, potentially with images
+					const textParts = block.content.filter((part) => part.type === "text")
+					const imageParts = block.content.filter((part) => part.type === "image")
+					const text = textParts.length > 0 ? textParts.map((part) => part.text).join("\n\n") : ""
+					const imageText = imageParts.length > 0 ? "\n\n(See next part for image)" : ""
+					return [
+						{
+							functionResponse: {
+								name,
+								response: {
+									name,
+									content: text + imageText,
+								},
+							},
+						} as FunctionResponsePart,
+						...imageParts.map(
+							(part) =>
+								({
+									inlineData: {
+										data: part.source.data,
+										mimeType: part.source.media_type,
+									},
+								}) as InlineDataPart,
+						),
+					]
+				}
+			default:
+				throw new Error(`Unsupported content block type: ${(block as any).type}`)
+		}
+	})
+}
+
+export function convertAnthropicMessageToVertexGemini(message: Anthropic.Messages.MessageParam): Content {
+	return {
+		role: message.role === "assistant" ? "model" : "user",
+		parts: convertAnthropicContentToVertexGemini(message.content),
+	}
+}
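
To illustrate the mapping (mirroring the unit tests above): assistant messages become role "model", text blocks become text parts, and tool_use blocks become functionCall parts. The sample values here are hypothetical:

import { convertAnthropicMessageToVertexGemini } from "./vertex-gemini-format"

const content = convertAnthropicMessageToVertexGemini({
	role: "assistant",
	content: [
		{ type: "text", text: "Let me calculate that." },
		{ type: "tool_use", id: "calc-123", name: "calculator", input: { operation: "add", numbers: [2, 3] } },
	],
})
// => {
//   role: "model",
//   parts: [
//     { text: "Let me calculate that." },
//     { functionCall: { name: "calculator", args: { operation: "add", numbers: [2, 3] } } },
//   ],
// }

Note the naming convention this relies on: for tool_result blocks, the function name is recovered from the tool_use_id prefix, e.g. "calculator-123".split("-")[0] yields "calculator".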

+ 40 - 0
src/shared/api.ts

@@ -436,6 +436,46 @@ export const openRouterDefaultModelInfo: ModelInfo = {
 export type VertexModelId = keyof typeof vertexModels
 export const vertexDefaultModelId: VertexModelId = "claude-3-7-sonnet@20250219"
 export const vertexModels = {
+	"gemini-2.0-flash-001": {
+		maxTokens: 8192,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0.15,
+		outputPrice: 0.6,
+	},
+	"gemini-2.0-flash-lite-001": {
+		maxTokens: 8192,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0.075,
+		outputPrice: 0.3,
+	},
+	"gemini-2.0-flash-thinking-exp-01-21": {
+		maxTokens: 8192,
+		contextWindow: 32_768,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0,
+		outputPrice: 0,
+	},
+	"gemini-1.5-flash-002": {
+		maxTokens: 8192,
+		contextWindow: 1_048_576,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 0.075,
+		outputPrice: 0.3,
+	},
+	"gemini-1.5-pro-002": {
+		maxTokens: 8192,
+		contextWindow: 2_097_152,
+		supportsImages: true,
+		supportsPromptCache: false,
+		inputPrice: 1.25,
+		outputPrice: 5,
+	},
 	"claude-3-7-sonnet@20250219:thinking": {
 		maxTokens: 64_000,
 		contextWindow: 200_000,