Browse Source

feat: add gemini-embedding-001 model to code-index service (#5698)

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Daniel 7 months ago
parent
commit
d7787a2de3

+ 50 - 7
src/services/code-index/__tests__/service-factory.spec.ts

@@ -265,7 +265,7 @@ describe("CodeIndexServiceFactory", () => {
 			expect(() => factory.createEmbedder()).toThrow("serviceFactory.openAiCompatibleConfigMissing")
 			expect(() => factory.createEmbedder()).toThrow("serviceFactory.openAiCompatibleConfigMissing")
 		})
 		})
 
 
-		it("should create GeminiEmbedder when using Gemini provider", () => {
+		it("should create GeminiEmbedder with default model when no modelId specified", () => {
 			// Arrange
 			// Arrange
 			const testConfig = {
 			const testConfig = {
 				embedderProvider: "gemini",
 				embedderProvider: "gemini",
@@ -279,7 +279,25 @@ describe("CodeIndexServiceFactory", () => {
 			factory.createEmbedder()
 			factory.createEmbedder()
 
 
 			// Assert
 			// Assert
-			expect(MockedGeminiEmbedder).toHaveBeenCalledWith("test-gemini-api-key")
+			expect(MockedGeminiEmbedder).toHaveBeenCalledWith("test-gemini-api-key", undefined)
+		})
+
+		it("should create GeminiEmbedder with specified modelId", () => {
+			// Arrange
+			const testConfig = {
+				embedderProvider: "gemini",
+				modelId: "text-embedding-004",
+				geminiOptions: {
+					apiKey: "test-gemini-api-key",
+				},
+			}
+			mockConfigManager.getConfig.mockReturnValue(testConfig as any)
+
+			// Act
+			factory.createEmbedder()
+
+			// Assert
+			expect(MockedGeminiEmbedder).toHaveBeenCalledWith("test-gemini-api-key", "text-embedding-004")
 		})
 		})
 
 
 		it("should throw error when Gemini API key is missing", () => {
 		it("should throw error when Gemini API key is missing", () => {
@@ -507,26 +525,51 @@ describe("CodeIndexServiceFactory", () => {
 			)
 			)
 		})
 		})
 
 
-		it("should use fixed dimension 768 for Gemini provider", () => {
+		it("should use model-specific dimension for Gemini provider", () => {
 			// Arrange
 			// Arrange
 			const testConfig = {
 			const testConfig = {
 				embedderProvider: "gemini",
 				embedderProvider: "gemini",
-				modelId: "text-embedding-004", // This is ignored by Gemini
+				modelId: "gemini-embedding-001",
 				qdrantUrl: "http://localhost:6333",
 				qdrantUrl: "http://localhost:6333",
 				qdrantApiKey: "test-key",
 				qdrantApiKey: "test-key",
 			}
 			}
 			mockConfigManager.getConfig.mockReturnValue(testConfig as any)
 			mockConfigManager.getConfig.mockReturnValue(testConfig as any)
+			mockGetModelDimension.mockReturnValue(3072)
 
 
 			// Act
 			// Act
 			factory.createVectorStore()
 			factory.createVectorStore()
 
 
 			// Assert
 			// Assert
-			// getModelDimension should not be called for Gemini
-			expect(mockGetModelDimension).not.toHaveBeenCalled()
+			expect(mockGetModelDimension).toHaveBeenCalledWith("gemini", "gemini-embedding-001")
 			expect(MockedQdrantVectorStore).toHaveBeenCalledWith(
 			expect(MockedQdrantVectorStore).toHaveBeenCalledWith(
 				"/test/workspace",
 				"/test/workspace",
 				"http://localhost:6333",
 				"http://localhost:6333",
-				768, // Fixed dimension for Gemini
+				3072,
+				"test-key",
+			)
+		})
+
+		it("should use default model dimension for Gemini when modelId not specified", () => {
+			// Arrange
+			const testConfig = {
+				embedderProvider: "gemini",
+				qdrantUrl: "http://localhost:6333",
+				qdrantApiKey: "test-key",
+			}
+			mockConfigManager.getConfig.mockReturnValue(testConfig as any)
+			mockGetDefaultModelId.mockReturnValue("gemini-embedding-001")
+			mockGetModelDimension.mockReturnValue(3072)
+
+			// Act
+			factory.createVectorStore()
+
+			// Assert
+			expect(mockGetDefaultModelId).toHaveBeenCalledWith("gemini")
+			expect(mockGetModelDimension).toHaveBeenCalledWith("gemini", "gemini-embedding-001")
+			expect(MockedQdrantVectorStore).toHaveBeenCalledWith(
+				"/test/workspace",
+				"http://localhost:6333",
+				3072,
 				"test-key",
 				"test-key",
 			)
 			)
 		})
 		})

+ 79 - 3
src/services/code-index/embedders/__tests__/gemini.spec.ts

@@ -25,13 +25,30 @@ describe("GeminiEmbedder", () => {
 	})
 	})
 
 
 	describe("constructor", () => {
 	describe("constructor", () => {
-		it("should create an instance with correct fixed values passed to OpenAICompatibleEmbedder", () => {
+		it("should create an instance with default model when no model specified", () => {
 			// Arrange
 			// Arrange
 			const apiKey = "test-gemini-api-key"
 			const apiKey = "test-gemini-api-key"
 
 
 			// Act
 			// Act
 			embedder = new GeminiEmbedder(apiKey)
 			embedder = new GeminiEmbedder(apiKey)
 
 
+			// Assert
+			expect(MockedOpenAICompatibleEmbedder).toHaveBeenCalledWith(
+				"https://generativelanguage.googleapis.com/v1beta/openai/",
+				apiKey,
+				"gemini-embedding-001",
+				2048,
+			)
+		})
+
+		it("should create an instance with specified model", () => {
+			// Arrange
+			const apiKey = "test-gemini-api-key"
+			const modelId = "text-embedding-004"
+
+			// Act
+			embedder = new GeminiEmbedder(apiKey, modelId)
+
 			// Assert
 			// Assert
 			expect(MockedOpenAICompatibleEmbedder).toHaveBeenCalledWith(
 			expect(MockedOpenAICompatibleEmbedder).toHaveBeenCalledWith(
 				"https://generativelanguage.googleapis.com/v1beta/openai/",
 				"https://generativelanguage.googleapis.com/v1beta/openai/",
@@ -50,7 +67,7 @@ describe("GeminiEmbedder", () => {
 	})
 	})
 
 
 	describe("embedderInfo", () => {
 	describe("embedderInfo", () => {
-		it("should return correct embedder info with dimension 768", () => {
+		it("should return correct embedder info", () => {
 			// Arrange
 			// Arrange
 			embedder = new GeminiEmbedder("test-api-key")
 			embedder = new GeminiEmbedder("test-api-key")
 
 
@@ -61,7 +78,66 @@ describe("GeminiEmbedder", () => {
 			expect(info).toEqual({
 			expect(info).toEqual({
 				name: "gemini",
 				name: "gemini",
 			})
 			})
-			expect(GeminiEmbedder.dimension).toBe(768)
+		})
+
+		describe("createEmbeddings", () => {
+			let mockCreateEmbeddings: any
+
+			beforeEach(() => {
+				mockCreateEmbeddings = vitest.fn()
+				MockedOpenAICompatibleEmbedder.prototype.createEmbeddings = mockCreateEmbeddings
+			})
+
+			it("should use instance model when no model parameter provided", async () => {
+				// Arrange
+				embedder = new GeminiEmbedder("test-api-key")
+				const texts = ["test text 1", "test text 2"]
+				const mockResponse = {
+					embeddings: [
+						[0.1, 0.2],
+						[0.3, 0.4],
+					],
+				}
+				mockCreateEmbeddings.mockResolvedValue(mockResponse)
+
+				// Act
+				const result = await embedder.createEmbeddings(texts)
+
+				// Assert
+				expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, "gemini-embedding-001")
+				expect(result).toEqual(mockResponse)
+			})
+
+			it("should use provided model parameter when specified", async () => {
+				// Arrange
+				embedder = new GeminiEmbedder("test-api-key", "text-embedding-004")
+				const texts = ["test text 1", "test text 2"]
+				const mockResponse = {
+					embeddings: [
+						[0.1, 0.2],
+						[0.3, 0.4],
+					],
+				}
+				mockCreateEmbeddings.mockResolvedValue(mockResponse)
+
+				// Act
+				const result = await embedder.createEmbeddings(texts, "gemini-embedding-001")
+
+				// Assert
+				expect(mockCreateEmbeddings).toHaveBeenCalledWith(texts, "gemini-embedding-001")
+				expect(result).toEqual(mockResponse)
+			})
+
+			it("should handle errors from OpenAICompatibleEmbedder", async () => {
+				// Arrange
+				embedder = new GeminiEmbedder("test-api-key")
+				const texts = ["test text"]
+				const error = new Error("Embedding failed")
+				mockCreateEmbeddings.mockRejectedValue(error)
+
+				// Act & Assert
+				await expect(embedder.createEmbeddings(texts)).rejects.toThrow("Embedding failed")
+			})
 		})
 		})
 	})
 	})
 
 

+ 17 - 20
src/services/code-index/embedders/gemini.ts

@@ -7,33 +7,36 @@ import { TelemetryService } from "@roo-code/telemetry"
 
 
 /**
 /**
  * Gemini embedder implementation that wraps the OpenAI Compatible embedder
  * Gemini embedder implementation that wraps the OpenAI Compatible embedder
- * with fixed configuration for Google's Gemini embedding API.
+ * with configuration for Google's Gemini embedding API.
  *
  *
- * Fixed values:
- * - Base URL: https://generativelanguage.googleapis.com/v1beta/openai/
- * - Model: text-embedding-004
- * - Dimension: 768
+ * Supported models:
+ * - text-embedding-004 (dimension: 768)
+ * - gemini-embedding-001 (dimension: 3072)
  */
  */
 export class GeminiEmbedder implements IEmbedder {
 export class GeminiEmbedder implements IEmbedder {
 	private readonly openAICompatibleEmbedder: OpenAICompatibleEmbedder
 	private readonly openAICompatibleEmbedder: OpenAICompatibleEmbedder
 	private static readonly GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
 	private static readonly GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
-	private static readonly GEMINI_MODEL = "text-embedding-004"
-	private static readonly GEMINI_DIMENSION = 768
+	private static readonly DEFAULT_MODEL = "gemini-embedding-001"
+	private readonly modelId: string
 
 
 	/**
 	/**
 	 * Creates a new Gemini embedder
 	 * Creates a new Gemini embedder
 	 * @param apiKey The Gemini API key for authentication
 	 * @param apiKey The Gemini API key for authentication
+	 * @param modelId The model ID to use (defaults to gemini-embedding-001)
 	 */
 	 */
-	constructor(apiKey: string) {
+	constructor(apiKey: string, modelId?: string) {
 		if (!apiKey) {
 		if (!apiKey) {
 			throw new Error(t("embeddings:validation.apiKeyRequired"))
 			throw new Error(t("embeddings:validation.apiKeyRequired"))
 		}
 		}
 
 
-		// Create an OpenAI Compatible embedder with Gemini's fixed configuration
+		// Use provided model or default
+		this.modelId = modelId || GeminiEmbedder.DEFAULT_MODEL
+
+		// Create an OpenAI Compatible embedder with Gemini's configuration
 		this.openAICompatibleEmbedder = new OpenAICompatibleEmbedder(
 		this.openAICompatibleEmbedder = new OpenAICompatibleEmbedder(
 			GeminiEmbedder.GEMINI_BASE_URL,
 			GeminiEmbedder.GEMINI_BASE_URL,
 			apiKey,
 			apiKey,
-			GeminiEmbedder.GEMINI_MODEL,
+			this.modelId,
 			GEMINI_MAX_ITEM_TOKENS,
 			GEMINI_MAX_ITEM_TOKENS,
 		)
 		)
 	}
 	}
@@ -41,13 +44,14 @@ export class GeminiEmbedder implements IEmbedder {
 	/**
 	/**
 	 * Creates embeddings for the given texts using Gemini's embedding API
 	 * Creates embeddings for the given texts using Gemini's embedding API
 	 * @param texts Array of text strings to embed
 	 * @param texts Array of text strings to embed
-	 * @param model Optional model identifier (ignored - always uses text-embedding-004)
+	 * @param model Optional model identifier (uses constructor model if not provided)
 	 * @returns Promise resolving to embedding response
 	 * @returns Promise resolving to embedding response
 	 */
 	 */
 	async createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse> {
 	async createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse> {
 		try {
 		try {
-			// Always use the fixed Gemini model, ignoring any passed model parameter
-			return await this.openAICompatibleEmbedder.createEmbeddings(texts, GeminiEmbedder.GEMINI_MODEL)
+			// Use the provided model or fall back to the instance's model
+			const modelToUse = model || this.modelId
+			return await this.openAICompatibleEmbedder.createEmbeddings(texts, modelToUse)
 		} catch (error) {
 		} catch (error) {
 			TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, {
 			TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, {
 				error: error instanceof Error ? error.message : String(error),
 				error: error instanceof Error ? error.message : String(error),
@@ -85,11 +89,4 @@ export class GeminiEmbedder implements IEmbedder {
 			name: "gemini",
 			name: "gemini",
 		}
 		}
 	}
 	}
-
-	/**
-	 * Gets the fixed dimension for Gemini embeddings
-	 */
-	static get dimension(): number {
-		return GeminiEmbedder.GEMINI_DIMENSION
-	}
 }
 }

+ 1 - 4
src/services/code-index/service-factory.ts

@@ -63,7 +63,7 @@ export class CodeIndexServiceFactory {
 			if (!config.geminiOptions?.apiKey) {
 			if (!config.geminiOptions?.apiKey) {
 				throw new Error(t("embeddings:serviceFactory.geminiConfigMissing"))
 				throw new Error(t("embeddings:serviceFactory.geminiConfigMissing"))
 			}
 			}
-			return new GeminiEmbedder(config.geminiOptions.apiKey)
+			return new GeminiEmbedder(config.geminiOptions.apiKey, config.modelId)
 		}
 		}
 
 
 		throw new Error(
 		throw new Error(
@@ -111,9 +111,6 @@ export class CodeIndexServiceFactory {
 		// First check if a manual dimension is provided (works for all providers)
 		// First check if a manual dimension is provided (works for all providers)
 		if (config.modelDimension && config.modelDimension > 0) {
 		if (config.modelDimension && config.modelDimension > 0) {
 			vectorSize = config.modelDimension
 			vectorSize = config.modelDimension
-		} else if (provider === "gemini") {
-			// Gemini's text-embedding-004 has a fixed dimension of 768
-			vectorSize = 768
 		} else {
 		} else {
 			// Fall back to model-specific dimension from profiles
 			// Fall back to model-specific dimension from profiles
 			vectorSize = getModelDimension(provider, modelId)
 			vectorSize = getModelDimension(provider, modelId)

+ 2 - 1
src/shared/embeddingModels.ts

@@ -48,6 +48,7 @@ export const EMBEDDING_MODEL_PROFILES: EmbeddingModelProfiles = {
 	},
 	},
 	gemini: {
 	gemini: {
 		"text-embedding-004": { dimension: 768 },
 		"text-embedding-004": { dimension: 768 },
+		"gemini-embedding-001": { dimension: 3072, scoreThreshold: 0.4 },
 	},
 	},
 }
 }
 
 
@@ -134,7 +135,7 @@ export function getDefaultModelId(provider: EmbedderProvider): string {
 		}
 		}
 
 
 		case "gemini":
 		case "gemini":
-			return "text-embedding-004"
+			return "gemini-embedding-001"
 
 
 		default:
 		default:
 			// Fallback for unknown providers
 			// Fallback for unknown providers