|
|
@@ -1,45 +1,41 @@
|
|
|
-import { GeminiHandler } from "../gemini"
|
|
|
+// npx jest src/api/providers/__tests__/gemini.test.ts
|
|
|
+
|
|
|
import { Anthropic } from "@anthropic-ai/sdk"
|
|
|
-import { GoogleGenerativeAI } from "@google/generative-ai"
|
|
|
-
|
|
|
-// Mock the Google Generative AI SDK
|
|
|
-jest.mock("@google/generative-ai", () => ({
|
|
|
- GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
|
|
|
- getGenerativeModel: jest.fn().mockReturnValue({
|
|
|
- generateContentStream: jest.fn(),
|
|
|
- generateContent: jest.fn().mockResolvedValue({
|
|
|
- response: {
|
|
|
- text: () => "Test response",
|
|
|
- },
|
|
|
- }),
|
|
|
- }),
|
|
|
- })),
|
|
|
-}))
|
|
|
+
|
|
|
+import { GeminiHandler } from "../gemini"
|
|
|
+import { geminiDefaultModelId } from "../../../shared/api"
|
|
|
+
|
|
|
+const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"
|
|
|
|
|
|
describe("GeminiHandler", () => {
|
|
|
let handler: GeminiHandler
|
|
|
|
|
|
beforeEach(() => {
|
|
|
+ // Create mock functions
|
|
|
+ const mockGenerateContentStream = jest.fn()
|
|
|
+ const mockGenerateContent = jest.fn()
|
|
|
+ const mockGetGenerativeModel = jest.fn()
|
|
|
+
|
|
|
handler = new GeminiHandler({
|
|
|
apiKey: "test-key",
|
|
|
- apiModelId: "gemini-2.0-flash-thinking-exp-1219",
|
|
|
+ apiModelId: GEMINI_20_FLASH_THINKING_NAME,
|
|
|
geminiApiKey: "test-key",
|
|
|
})
|
|
|
+
|
|
|
+ // Replace the client with our mock
|
|
|
+ handler["client"] = {
|
|
|
+ models: {
|
|
|
+ generateContentStream: mockGenerateContentStream,
|
|
|
+ generateContent: mockGenerateContent,
|
|
|
+ getGenerativeModel: mockGetGenerativeModel,
|
|
|
+ },
|
|
|
+ } as any
|
|
|
})
|
|
|
|
|
|
describe("constructor", () => {
|
|
|
it("should initialize with provided config", () => {
|
|
|
expect(handler["options"].geminiApiKey).toBe("test-key")
|
|
|
- expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219")
|
|
|
- })
|
|
|
-
|
|
|
- it.skip("should throw if API key is missing", () => {
|
|
|
- expect(() => {
|
|
|
- new GeminiHandler({
|
|
|
- apiModelId: "gemini-2.0-flash-thinking-exp-1219",
|
|
|
- geminiApiKey: "",
|
|
|
- })
|
|
|
- }).toThrow("API key is required for Google Gemini")
|
|
|
+ expect(handler["options"].apiModelId).toBe(GEMINI_20_FLASH_THINKING_NAME)
|
|
|
})
|
|
|
})
|
|
|
|
|
|
@@ -58,25 +54,15 @@ describe("GeminiHandler", () => {
|
|
|
const systemPrompt = "You are a helpful assistant"
|
|
|
|
|
|
it("should handle text messages correctly", async () => {
|
|
|
- // Mock the stream response
|
|
|
- const mockStream = {
|
|
|
- stream: [{ text: () => "Hello" }, { text: () => " world!" }],
|
|
|
- response: {
|
|
|
- usageMetadata: {
|
|
|
- promptTokenCount: 10,
|
|
|
- candidatesTokenCount: 5,
|
|
|
- },
|
|
|
+ // Setup the mock implementation to return an async generator
|
|
|
+ ;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({
|
|
|
+ [Symbol.asyncIterator]: async function* () {
|
|
|
+ yield { text: "Hello" }
|
|
|
+ yield { text: " world!" }
|
|
|
+ yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
|
|
|
},
|
|
|
- }
|
|
|
-
|
|
|
- // Setup the mock implementation
|
|
|
- const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream)
|
|
|
- const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
|
|
- generateContentStream: mockGenerateContentStream,
|
|
|
})
|
|
|
|
|
|
- ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
|
|
-
|
|
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
|
|
const chunks = []
|
|
|
|
|
|
@@ -100,35 +86,21 @@ describe("GeminiHandler", () => {
|
|
|
outputTokens: 5,
|
|
|
})
|
|
|
|
|
|
- // Verify the model configuration
|
|
|
- expect(mockGetGenerativeModel).toHaveBeenCalledWith(
|
|
|
- {
|
|
|
- model: "gemini-2.0-flash-thinking-exp-1219",
|
|
|
- systemInstruction: systemPrompt,
|
|
|
- },
|
|
|
- {
|
|
|
- baseUrl: undefined,
|
|
|
- },
|
|
|
- )
|
|
|
-
|
|
|
- // Verify generation config
|
|
|
- expect(mockGenerateContentStream).toHaveBeenCalledWith(
|
|
|
+ // Verify the call to generateContentStream
|
|
|
+ expect(handler["client"].models.generateContentStream).toHaveBeenCalledWith(
|
|
|
expect.objectContaining({
|
|
|
- generationConfig: {
|
|
|
+ model: GEMINI_20_FLASH_THINKING_NAME,
|
|
|
+ config: expect.objectContaining({
|
|
|
temperature: 0,
|
|
|
- },
|
|
|
+ systemInstruction: systemPrompt,
|
|
|
+ }),
|
|
|
}),
|
|
|
)
|
|
|
})
|
|
|
|
|
|
it("should handle API errors", async () => {
|
|
|
const mockError = new Error("Gemini API error")
|
|
|
- const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError)
|
|
|
- const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
|
|
- generateContentStream: mockGenerateContentStream,
|
|
|
- })
|
|
|
-
|
|
|
- ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
|
|
+ ;(handler["client"].models.generateContentStream as jest.Mock).mockRejectedValue(mockError)
|
|
|
|
|
|
const stream = handler.createMessage(systemPrompt, mockMessages)
|
|
|
|
|
|
@@ -136,35 +108,26 @@ describe("GeminiHandler", () => {
|
|
|
for await (const chunk of stream) {
|
|
|
// Should throw before yielding any chunks
|
|
|
}
|
|
|
- }).rejects.toThrow("Gemini API error")
|
|
|
+ }).rejects.toThrow()
|
|
|
})
|
|
|
})
|
|
|
|
|
|
describe("completePrompt", () => {
|
|
|
it("should complete prompt successfully", async () => {
|
|
|
- const mockGenerateContent = jest.fn().mockResolvedValue({
|
|
|
- response: {
|
|
|
- text: () => "Test response",
|
|
|
- },
|
|
|
- })
|
|
|
- const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
|
|
- generateContent: mockGenerateContent,
|
|
|
+ // Mock the response with text property
|
|
|
+ ;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({
|
|
|
+ text: "Test response",
|
|
|
})
|
|
|
- ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
|
|
|
|
|
const result = await handler.completePrompt("Test prompt")
|
|
|
expect(result).toBe("Test response")
|
|
|
- expect(mockGetGenerativeModel).toHaveBeenCalledWith(
|
|
|
- {
|
|
|
- model: "gemini-2.0-flash-thinking-exp-1219",
|
|
|
- },
|
|
|
- {
|
|
|
- baseUrl: undefined,
|
|
|
- },
|
|
|
- )
|
|
|
- expect(mockGenerateContent).toHaveBeenCalledWith({
|
|
|
+
|
|
|
+ // Verify the call to generateContent
|
|
|
+ expect(handler["client"].models.generateContent).toHaveBeenCalledWith({
|
|
|
+ model: GEMINI_20_FLASH_THINKING_NAME,
|
|
|
contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
|
|
|
- generationConfig: {
|
|
|
+ config: {
|
|
|
+ httpOptions: undefined,
|
|
|
temperature: 0,
|
|
|
},
|
|
|
})
|
|
|
@@ -172,11 +135,7 @@ describe("GeminiHandler", () => {
|
|
|
|
|
|
it("should handle API errors", async () => {
|
|
|
const mockError = new Error("Gemini API error")
|
|
|
- const mockGenerateContent = jest.fn().mockRejectedValue(mockError)
|
|
|
- const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
|
|
- generateContent: mockGenerateContent,
|
|
|
- })
|
|
|
- ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
|
|
+ ;(handler["client"].models.generateContent as jest.Mock).mockRejectedValue(mockError)
|
|
|
|
|
|
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
|
|
|
"Gemini completion error: Gemini API error",
|
|
|
@@ -184,15 +143,10 @@ describe("GeminiHandler", () => {
|
|
|
})
|
|
|
|
|
|
it("should handle empty response", async () => {
|
|
|
- const mockGenerateContent = jest.fn().mockResolvedValue({
|
|
|
- response: {
|
|
|
- text: () => "",
|
|
|
- },
|
|
|
- })
|
|
|
- const mockGetGenerativeModel = jest.fn().mockReturnValue({
|
|
|
- generateContent: mockGenerateContent,
|
|
|
+ // Mock the response with empty text
|
|
|
+ ;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({
|
|
|
+ text: "",
|
|
|
})
|
|
|
- ;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
|
|
|
|
|
|
const result = await handler.completePrompt("Test prompt")
|
|
|
expect(result).toBe("")
|
|
|
@@ -202,7 +156,7 @@ describe("GeminiHandler", () => {
|
|
|
describe("getModel", () => {
|
|
|
it("should return correct model info", () => {
|
|
|
const modelInfo = handler.getModel()
|
|
|
- expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219")
|
|
|
+ expect(modelInfo.id).toBe(GEMINI_20_FLASH_THINKING_NAME)
|
|
|
expect(modelInfo.info).toBeDefined()
|
|
|
expect(modelInfo.info.maxTokens).toBe(8192)
|
|
|
expect(modelInfo.info.contextWindow).toBe(32_767)
|
|
|
@@ -214,7 +168,7 @@ describe("GeminiHandler", () => {
|
|
|
geminiApiKey: "test-key",
|
|
|
})
|
|
|
const modelInfo = invalidHandler.getModel()
|
|
|
- expect(modelInfo.id).toBe("gemini-2.0-flash-001") // Default model
|
|
|
+ expect(modelInfo.id).toBe(geminiDefaultModelId) // Default model
|
|
|
})
|
|
|
})
|
|
|
})
|