|
|
@@ -6,6 +6,7 @@ import { BetaThinkingConfigParam } from "@anthropic-ai/sdk/resources/beta"
|
|
|
|
|
|
import { VertexHandler } from "../vertex"
|
|
|
import { ApiStreamChunk } from "../../transform/stream"
|
|
|
+import { VertexAI } from "@google-cloud/vertexai"
|
|
|
|
|
|
// Mock Vertex SDK
|
|
|
jest.mock("@anthropic-ai/vertex-sdk", () => ({
|
|
|
@@ -49,24 +50,100 @@ jest.mock("@anthropic-ai/vertex-sdk", () => ({
|
|
|
})),
|
|
|
}))
|
|
|
|
|
|
-describe("VertexHandler", () => {
|
|
|
- let handler: VertexHandler
|
|
|
+// Mock Vertex Gemini SDK
|
|
|
+jest.mock("@google-cloud/vertexai", () => {
|
|
|
+ const mockGenerateContentStream = jest.fn().mockImplementation(() => {
|
|
|
+ return {
|
|
|
+ stream: {
|
|
|
+ async *[Symbol.asyncIterator]() {
|
|
|
+ yield {
|
|
|
+ candidates: [
|
|
|
+ {
|
|
|
+ content: {
|
|
|
+ parts: [{ text: "Test Gemini response" }],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ }
|
|
|
+ },
|
|
|
+ },
|
|
|
+ response: {
|
|
|
+ usageMetadata: {
|
|
|
+ promptTokenCount: 5,
|
|
|
+ candidatesTokenCount: 10,
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+ })
|
|
|
|
|
|
- beforeEach(() => {
|
|
|
- handler = new VertexHandler({
|
|
|
- apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
- vertexProjectId: "test-project",
|
|
|
- vertexRegion: "us-central1",
|
|
|
- })
|
|
|
+ const mockGenerateContent = jest.fn().mockResolvedValue({
|
|
|
+ response: {
|
|
|
+ candidates: [
|
|
|
+ {
|
|
|
+ content: {
|
|
|
+ parts: [{ text: "Test Gemini response" }],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ })
|
|
|
+
|
|
|
+ const mockGenerativeModel = jest.fn().mockImplementation(() => {
|
|
|
+ return {
|
|
|
+ generateContentStream: mockGenerateContentStream,
|
|
|
+ generateContent: mockGenerateContent,
|
|
|
+ }
|
|
|
})
|
|
|
|
|
|
+ return {
|
|
|
+ VertexAI: jest.fn().mockImplementation(() => {
|
|
|
+ return {
|
|
|
+ getGenerativeModel: mockGenerativeModel,
|
|
|
+ }
|
|
|
+ }),
|
|
|
+ GenerativeModel: mockGenerativeModel,
|
|
|
+ }
|
|
|
+})
|
|
|
+
|
|
|
+describe("VertexHandler", () => {
|
|
|
+ let handler: VertexHandler
|
|
|
+
|
|
|
describe("constructor", () => {
|
|
|
- it("should initialize with provided config", () => {
|
|
|
+ it("should initialize with provided config for Claude", () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
expect(AnthropicVertex).toHaveBeenCalledWith({
|
|
|
projectId: "test-project",
|
|
|
region: "us-central1",
|
|
|
})
|
|
|
})
|
|
|
+
|
|
|
+ it("should initialize with provided config for Gemini", () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "gemini-1.5-pro-001",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
+ expect(VertexAI).toHaveBeenCalledWith({
|
|
|
+ project: "test-project",
|
|
|
+ location: "us-central1",
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should throw error for invalid model", () => {
|
|
|
+ expect(() => {
|
|
|
+ new VertexHandler({
|
|
|
+ apiModelId: "invalid-model",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+ }).toThrow("Unknown model ID: invalid-model")
|
|
|
+ })
|
|
|
})
|
|
|
|
|
|
describe("createMessage", () => {
|
|
|
@@ -83,7 +160,13 @@ describe("VertexHandler", () => {
|
|
|
|
|
|
const systemPrompt = "You are a helpful assistant"
|
|
|
|
|
|
- it("should handle streaming responses correctly", async () => {
|
|
|
+ it("should handle streaming responses correctly for Claude", async () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const mockStream = [
|
|
|
{
|
|
|
type: "message_start",
|
|
|
@@ -127,7 +210,7 @@ describe("VertexHandler", () => {
|
|
|
}
|
|
|
|
|
|
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
|
|
- ;(handler["client"].messages as any).create = mockCreate
|
|
|
+ ;(handler["anthropicClient"].messages as any).create = mockCreate
|
|
|
|
|
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
|
|
const chunks: ApiStreamChunk[] = []
|
|
|
@@ -187,7 +270,58 @@ describe("VertexHandler", () => {
|
|
|
})
|
|
|
})
|
|
|
|
|
|
- it("should handle multiple content blocks with line breaks", async () => {
|
|
|
+ it("should handle streaming responses correctly for Gemini", async () => {
|
|
|
+ const mockGemini = require("@google-cloud/vertexai")
|
|
|
+ const mockGenerateContentStream = mockGemini.VertexAI().getGenerativeModel().generateContentStream
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "gemini-1.5-pro-001",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
+ const stream = handler.createMessage(systemPrompt, mockMessages)
|
|
|
+ const chunks: ApiStreamChunk[] = []
|
|
|
+
|
|
|
+ for await (const chunk of stream) {
|
|
|
+ chunks.push(chunk)
|
|
|
+ }
|
|
|
+
|
|
|
+ expect(chunks.length).toBe(2)
|
|
|
+ expect(chunks[0]).toEqual({
|
|
|
+ type: "text",
|
|
|
+ text: "Test Gemini response",
|
|
|
+ })
|
|
|
+ expect(chunks[1]).toEqual({
|
|
|
+ type: "usage",
|
|
|
+ inputTokens: 5,
|
|
|
+ outputTokens: 10,
|
|
|
+ })
|
|
|
+
|
|
|
+ expect(mockGenerateContentStream).toHaveBeenCalledWith({
|
|
|
+ contents: [
|
|
|
+ {
|
|
|
+ role: "user",
|
|
|
+ parts: [{ text: "Hello" }],
|
|
|
+ },
|
|
|
+ {
|
|
|
+ role: "model",
|
|
|
+ parts: [{ text: "Hi there!" }],
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ generationConfig: {
|
|
|
+ maxOutputTokens: 16384,
|
|
|
+ temperature: 0,
|
|
|
+ },
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle multiple content blocks with line breaks for Claude", async () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const mockStream = [
|
|
|
{
|
|
|
type: "content_block_start",
|
|
|
@@ -216,7 +350,7 @@ describe("VertexHandler", () => {
|
|
|
}
|
|
|
|
|
|
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
|
|
- ;(handler["client"].messages as any).create = mockCreate
|
|
|
+ ;(handler["anthropicClient"].messages as any).create = mockCreate
|
|
|
|
|
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
|
|
const chunks: ApiStreamChunk[] = []
|
|
|
@@ -240,10 +374,16 @@ describe("VertexHandler", () => {
|
|
|
})
|
|
|
})
|
|
|
|
|
|
- it("should handle API errors", async () => {
|
|
|
+ it("should handle API errors for Claude", async () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const mockError = new Error("Vertex API error")
|
|
|
const mockCreate = jest.fn().mockRejectedValue(mockError)
|
|
|
- ;(handler["client"].messages as any).create = mockCreate
|
|
|
+ ;(handler["anthropicClient"].messages as any).create = mockCreate
|
|
|
|
|
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
|
|
|
|
|
@@ -254,7 +394,13 @@ describe("VertexHandler", () => {
|
|
|
}).rejects.toThrow("Vertex API error")
|
|
|
})
|
|
|
|
|
|
- it("should handle prompt caching for supported models", async () => {
|
|
|
+ it("should handle prompt caching for supported models for Claude", async () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const mockStream = [
|
|
|
{
|
|
|
type: "message_start",
|
|
|
@@ -299,7 +445,7 @@ describe("VertexHandler", () => {
|
|
|
}
|
|
|
|
|
|
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
|
|
- ;(handler["client"].messages as any).create = mockCreate
|
|
|
+ ;(handler["anthropicClient"].messages as any).create = mockCreate
|
|
|
|
|
|
const stream = handler.createMessage(systemPrompt, [
|
|
|
{
|
|
|
@@ -383,7 +529,13 @@ describe("VertexHandler", () => {
|
|
|
)
|
|
|
})
|
|
|
|
|
|
- it("should handle cache-related usage metrics", async () => {
|
|
|
+ it("should handle cache-related usage metrics for Claude", async () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const mockStream = [
|
|
|
{
|
|
|
type: "message_start",
|
|
|
@@ -415,7 +567,7 @@ describe("VertexHandler", () => {
|
|
|
}
|
|
|
|
|
|
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
|
|
- ;(handler["client"].messages as any).create = mockCreate
|
|
|
+ ;(handler["anthropicClient"].messages as any).create = mockCreate
|
|
|
|
|
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
|
|
const chunks: ApiStreamChunk[] = []
|
|
|
@@ -442,7 +594,13 @@ describe("VertexHandler", () => {
|
|
|
|
|
|
const systemPrompt = "You are a helpful assistant"
|
|
|
|
|
|
- it("should handle thinking content blocks and deltas", async () => {
|
|
|
+ it("should handle thinking content blocks and deltas for Claude", async () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const mockStream = [
|
|
|
{
|
|
|
type: "message_start",
|
|
|
@@ -488,7 +646,7 @@ describe("VertexHandler", () => {
|
|
|
}
|
|
|
|
|
|
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
|
|
- ;(handler["client"].messages as any).create = mockCreate
|
|
|
+ ;(handler["anthropicClient"].messages as any).create = mockCreate
|
|
|
|
|
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
|
|
const chunks: ApiStreamChunk[] = []
|
|
|
@@ -510,7 +668,13 @@ describe("VertexHandler", () => {
|
|
|
expect(textChunks[1].text).toBe("Here's my answer:")
|
|
|
})
|
|
|
|
|
|
- it("should handle multiple thinking blocks with line breaks", async () => {
|
|
|
+ it("should handle multiple thinking blocks with line breaks for Claude", async () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const mockStream = [
|
|
|
{
|
|
|
type: "content_block_start",
|
|
|
@@ -539,7 +703,7 @@ describe("VertexHandler", () => {
|
|
|
}
|
|
|
|
|
|
const mockCreate = jest.fn().mockResolvedValue(asyncIterator)
|
|
|
- ;(handler["client"].messages as any).create = mockCreate
|
|
|
+ ;(handler["anthropicClient"].messages as any).create = mockCreate
|
|
|
|
|
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
|
|
const chunks: ApiStreamChunk[] = []
|
|
|
@@ -565,10 +729,16 @@ describe("VertexHandler", () => {
|
|
|
})
|
|
|
|
|
|
describe("completePrompt", () => {
|
|
|
- it("should complete prompt successfully", async () => {
|
|
|
+ it("should complete prompt successfully for Claude", async () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const result = await handler.completePrompt("Test prompt")
|
|
|
expect(result).toBe("Test response")
|
|
|
- expect(handler["client"].messages.create).toHaveBeenCalledWith({
|
|
|
+ expect(handler["anthropicClient"].messages.create).toHaveBeenCalledWith({
|
|
|
model: "claude-3-5-sonnet-v2@20241022",
|
|
|
max_tokens: 8192,
|
|
|
temperature: 0,
|
|
|
@@ -583,31 +753,109 @@ describe("VertexHandler", () => {
|
|
|
})
|
|
|
})
|
|
|
|
|
|
- it("should handle API errors", async () => {
|
|
|
+ it("should complete prompt successfully for Gemini", async () => {
|
|
|
+ const mockGemini = require("@google-cloud/vertexai")
|
|
|
+ const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent
|
|
|
+
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "gemini-1.5-pro-001",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
+ const result = await handler.completePrompt("Test prompt")
|
|
|
+ expect(result).toBe("Test Gemini response")
|
|
|
+ expect(mockGenerateContent).toHaveBeenCalled()
|
|
|
+ expect(mockGenerateContent).toHaveBeenCalledWith({
|
|
|
+ contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
|
|
|
+ generationConfig: {
|
|
|
+ temperature: 0,
|
|
|
+ },
|
|
|
+ })
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle API errors for Claude", async () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const mockError = new Error("Vertex API error")
|
|
|
const mockCreate = jest.fn().mockRejectedValue(mockError)
|
|
|
- ;(handler["client"].messages as any).create = mockCreate
|
|
|
+ ;(handler["anthropicClient"].messages as any).create = mockCreate
|
|
|
+
|
|
|
+ await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
|
|
+ "Vertex completion error: Vertex API error",
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle API errors for Gemini", async () => {
|
|
|
+ const mockGemini = require("@google-cloud/vertexai")
|
|
|
+ const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent
|
|
|
+ mockGenerateContent.mockRejectedValue(new Error("Vertex API error"))
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "gemini-1.5-pro-001",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
|
|
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
|
|
"Vertex completion error: Vertex API error",
|
|
|
)
|
|
|
})
|
|
|
|
|
|
- it("should handle non-text content", async () => {
|
|
|
+ it("should handle non-text content for Claude", async () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const mockCreate = jest.fn().mockResolvedValue({
|
|
|
content: [{ type: "image" }],
|
|
|
})
|
|
|
- ;(handler["client"].messages as any).create = mockCreate
|
|
|
+ ;(handler["anthropicClient"].messages as any).create = mockCreate
|
|
|
|
|
|
const result = await handler.completePrompt("Test prompt")
|
|
|
expect(result).toBe("")
|
|
|
})
|
|
|
|
|
|
- it("should handle empty response", async () => {
|
|
|
+ it("should handle empty response for Claude", async () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const mockCreate = jest.fn().mockResolvedValue({
|
|
|
content: [{ type: "text", text: "" }],
|
|
|
})
|
|
|
- ;(handler["client"].messages as any).create = mockCreate
|
|
|
+ ;(handler["anthropicClient"].messages as any).create = mockCreate
|
|
|
+
|
|
|
+ const result = await handler.completePrompt("Test prompt")
|
|
|
+ expect(result).toBe("")
|
|
|
+ })
|
|
|
+
|
|
|
+ it("should handle empty response for Gemini", async () => {
|
|
|
+ const mockGemini = require("@google-cloud/vertexai")
|
|
|
+ const mockGenerateContent = mockGemini.VertexAI().getGenerativeModel().generateContent
|
|
|
+ mockGenerateContent.mockResolvedValue({
|
|
|
+ response: {
|
|
|
+ candidates: [
|
|
|
+ {
|
|
|
+ content: {
|
|
|
+ parts: [{ text: "" }],
|
|
|
+ },
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ },
|
|
|
+ })
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "gemini-1.5-pro-001",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
|
|
|
const result = await handler.completePrompt("Test prompt")
|
|
|
expect(result).toBe("")
|
|
|
@@ -615,7 +863,13 @@ describe("VertexHandler", () => {
|
|
|
})
|
|
|
|
|
|
describe("getModel", () => {
|
|
|
- it("should return correct model info", () => {
|
|
|
+ it("should return correct model info for Claude", () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "claude-3-5-sonnet-v2@20241022",
|
|
|
+ vertexProjectId: "test-project",
|
|
|
+ vertexRegion: "us-central1",
|
|
|
+ })
|
|
|
+
|
|
|
const modelInfo = handler.getModel()
|
|
|
expect(modelInfo.id).toBe("claude-3-5-sonnet-v2@20241022")
|
|
|
expect(modelInfo.info).toBeDefined()
|
|
|
@@ -623,14 +877,18 @@ describe("VertexHandler", () => {
|
|
|
expect(modelInfo.info.contextWindow).toBe(200_000)
|
|
|
})
|
|
|
|
|
|
- it("should return default model if invalid model specified", () => {
|
|
|
- const invalidHandler = new VertexHandler({
|
|
|
- apiModelId: "invalid-model",
|
|
|
+ it("should return correct model info for Gemini", () => {
|
|
|
+ handler = new VertexHandler({
|
|
|
+ apiModelId: "gemini-2.0-flash-001",
|
|
|
vertexProjectId: "test-project",
|
|
|
vertexRegion: "us-central1",
|
|
|
})
|
|
|
- const modelInfo = invalidHandler.getModel()
|
|
|
- expect(modelInfo.id).toBe("claude-3-7-sonnet@20250219") // Default model
|
|
|
+
|
|
|
+ const modelInfo = handler.getModel()
|
|
|
+ expect(modelInfo.id).toBe("gemini-2.0-flash-001")
|
|
|
+ expect(modelInfo.info).toBeDefined()
|
|
|
+ expect(modelInfo.info.maxTokens).toBe(8192)
|
|
|
+ expect(modelInfo.info.contextWindow).toBe(1048576)
|
|
|
})
|
|
|
})
|
|
|
|
|
|
@@ -724,7 +982,7 @@ describe("VertexHandler", () => {
|
|
|
},
|
|
|
}
|
|
|
})
|
|
|
- ;(thinkingHandler["client"].messages as any).create = mockCreate
|
|
|
+ ;(thinkingHandler["anthropicClient"].messages as any).create = mockCreate
|
|
|
|
|
|
await thinkingHandler
|
|
|
.createMessage("You are a helpful assistant", [{ role: "user", content: "Hello" }])
|