@@ -1,415 +0,0 @@
-import { OpenAI } from "openai"
-import { IEmbedder, EmbeddingResponse, EmbedderInfo } from "../interfaces/embedder"
-import {
-    MAX_BATCH_TOKENS,
-    MAX_ITEM_TOKENS,
-    MAX_BATCH_RETRIES as MAX_RETRIES,
-    INITIAL_RETRY_DELAY_MS as INITIAL_DELAY_MS,
-} from "../constants"
-import { getDefaultModelId, getModelQueryPrefix } from "../../../shared/embeddingModels"
-import { t } from "../../../i18n"
-import { withValidationErrorHandling, HttpError, formatEmbeddingError } from "../shared/validation-helpers"
-import { TelemetryEventName } from "@roo-code/types"
-import { TelemetryService } from "@roo-code/telemetry"
-import { Mutex } from "async-mutex"
-import { handleOpenAIError } from "../../../api/providers/utils/openai-error-handler"
-import { CloudService } from "@roo-code/cloud"
-
-interface EmbeddingItem {
-    embedding: string | number[]
-    [key: string]: any
-}
-
-interface RooEmbeddingResponse {
-    data: EmbeddingItem[]
-    usage?: {
-        prompt_tokens?: number
-        total_tokens?: number
-    }
-}
-
-function getSessionToken(): string {
-    const token = CloudService.hasInstance() ? CloudService.instance.authService?.getSessionToken() : undefined
-    return token ?? "unauthenticated"
-}
-
-/**
- * Roo Code Cloud implementation of the embedder interface with batching and rate limiting.
- * Roo Code Cloud provides access to embedding models through a unified proxy endpoint.
- */
-export class RooEmbedder implements IEmbedder {
-    private embeddingsClient: OpenAI
-    private readonly defaultModelId: string
-    private readonly maxItemTokens: number
-    private readonly baseUrl: string
-
-    // Global rate limiting state shared across all instances
-    private static globalRateLimitState = {
-        isRateLimited: false,
-        rateLimitResetTime: 0,
-        consecutiveRateLimitErrors: 0,
-        lastRateLimitError: 0,
-        // Mutex to ensure thread-safe access to rate limit state
-        mutex: new Mutex(),
-    }
-
-    /**
-     * Creates a new Roo Code Cloud embedder
-     * @param modelId Optional model identifier (defaults to "openai/text-embedding-3-large")
-     * @param maxItemTokens Optional maximum tokens per item (defaults to MAX_ITEM_TOKENS)
-     */
-    constructor(modelId?: string, maxItemTokens?: number) {
-        const sessionToken = getSessionToken()
-
-        this.baseUrl = process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy"
-
-        // Ensure baseURL ends with /v1 for OpenAI client, but don't duplicate it
-        const baseURL = !this.baseUrl.endsWith("/v1") ? `${this.baseUrl}/v1` : this.baseUrl
-
-        // Wrap OpenAI client creation to handle invalid API key characters
-        try {
-            this.embeddingsClient = new OpenAI({
-                baseURL,
-                apiKey: sessionToken,
-                defaultHeaders: {
-                    "HTTP-Referer": "https://github.com/RooCodeInc/Roo-Code",
-                    "X-Title": "Roo Code",
-                },
-            })
-        } catch (error) {
-            // Use the error handler to transform ByteString conversion errors
-            throw handleOpenAIError(error, "Roo Code Cloud")
-        }
-
-        this.defaultModelId = modelId || getDefaultModelId("roo")
-        this.maxItemTokens = maxItemTokens || MAX_ITEM_TOKENS
-    }
-
-    /**
-     * Creates embeddings for the given texts with batching and rate limiting
-     * @param texts Array of text strings to embed
-     * @param model Optional model identifier
-     * @returns Promise resolving to embedding response
-     */
-    async createEmbeddings(texts: string[], model?: string): Promise<EmbeddingResponse> {
-        const modelToUse = model || this.defaultModelId
-
-        // Apply model-specific query prefix if required
-        const queryPrefix = getModelQueryPrefix("roo", modelToUse)
-        const processedTexts = queryPrefix
-            ? texts.map((text, index) => {
-                  // Prevent double-prefixing
-                  if (text.startsWith(queryPrefix)) {
-                      return text
-                  }
-                  const prefixedText = `${queryPrefix}${text}`
-                  const estimatedTokens = Math.ceil(prefixedText.length / 4)
-                  if (estimatedTokens > MAX_ITEM_TOKENS) {
-                      console.warn(
-                          t("embeddings:textWithPrefixExceedsTokenLimit", {
-                              index,
-                              estimatedTokens,
-                              maxTokens: MAX_ITEM_TOKENS,
-                          }),
-                      )
-                      // Return original text if adding prefix would exceed limit
-                      return text
-                  }
-                  return prefixedText
-              })
-            : texts
-
-        const allEmbeddings: number[][] = []
-        const usage = { promptTokens: 0, totalTokens: 0 }
-        const remainingTexts = [...processedTexts]
-
-        while (remainingTexts.length > 0) {
-            const currentBatch: string[] = []
-            let currentBatchTokens = 0
-            const processedIndices: number[] = []
-
-            for (let i = 0; i < remainingTexts.length; i++) {
-                const text = remainingTexts[i]
-                const itemTokens = Math.ceil(text.length / 4)
-
-                if (itemTokens > this.maxItemTokens) {
-                    console.warn(
-                        t("embeddings:textExceedsTokenLimit", {
-                            index: i,
-                            itemTokens,
-                            maxTokens: this.maxItemTokens,
-                        }),
-                    )
-                    processedIndices.push(i)
-                    continue
-                }
-
-                if (currentBatchTokens + itemTokens <= MAX_BATCH_TOKENS) {
-                    currentBatch.push(text)
-                    currentBatchTokens += itemTokens
-                    processedIndices.push(i)
-                } else {
-                    break
-                }
-            }
-
-            // Remove processed items from remainingTexts (in reverse order to maintain correct indices)
-            for (let i = processedIndices.length - 1; i >= 0; i--) {
-                remainingTexts.splice(processedIndices[i], 1)
-            }
-
-            if (currentBatch.length > 0) {
-                const batchResult = await this._embedBatchWithRetries(currentBatch, modelToUse)
-                allEmbeddings.push(...batchResult.embeddings)
-                usage.promptTokens += batchResult.usage.promptTokens
-                usage.totalTokens += batchResult.usage.totalTokens
-            }
-        }
-
-        return { embeddings: allEmbeddings, usage }
-    }
-
-    /**
-     * Helper method to handle batch embedding with retries and exponential backoff
-     * @param batchTexts Array of texts to embed in this batch
-     * @param model Model identifier to use
-     * @returns Promise resolving to embeddings and usage statistics
-     */
-    private async _embedBatchWithRetries(
-        batchTexts: string[],
-        model: string,
-    ): Promise<{ embeddings: number[][]; usage: { promptTokens: number; totalTokens: number } }> {
-        for (let attempts = 0; attempts < MAX_RETRIES; attempts++) {
-            // Check global rate limit before attempting request
-            await this.waitForGlobalRateLimit()
-
-            // Update API key before each request to ensure we use the latest session token
-            this.embeddingsClient.apiKey = getSessionToken()
-
-            try {
-                const response = (await this.embeddingsClient.embeddings.create({
-                    input: batchTexts,
-                    model: model,
-                    // OpenAI package (as of v4.78.1) has a parsing issue that truncates embedding dimensions to 256
-                    // when processing numeric arrays, which breaks compatibility with models using larger dimensions.
-                    // By requesting base64 encoding, we bypass the package's parser and handle decoding ourselves.
-                    encoding_format: "base64",
-                })) as RooEmbeddingResponse
-
-                // Convert base64 embeddings to float32 arrays
-                const processedEmbeddings = response.data.map((item: EmbeddingItem) => {
-                    if (typeof item.embedding === "string") {
-                        const buffer = Buffer.from(item.embedding, "base64")
-
-                        // Create Float32Array view over the buffer
-                        const float32Array = new Float32Array(buffer.buffer, buffer.byteOffset, buffer.byteLength / 4)
-
-                        return {
-                            ...item,
-                            embedding: Array.from(float32Array),
-                        }
-                    }
-                    return item
-                })
-
-                // Replace the original data with processed embeddings
-                response.data = processedEmbeddings
-
-                const embeddings = response.data.map((item) => item.embedding as number[])
-
-                return {
-                    embeddings: embeddings,
-                    usage: {
-                        promptTokens: response.usage?.prompt_tokens || 0,
-                        totalTokens: response.usage?.total_tokens || 0,
-                    },
-                }
-            } catch (error) {
-                // Capture telemetry before error is reformatted
-                TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, {
-                    error: error instanceof Error ? error.message : String(error),
-                    stack: error instanceof Error ? error.stack : undefined,
-                    location: "RooEmbedder:_embedBatchWithRetries",
-                    attempt: attempts + 1,
-                })
-
-                const hasMoreAttempts = attempts < MAX_RETRIES - 1
-
-                // Check if it's a rate limit error
-                const httpError = error as HttpError
-                if (httpError?.status === 429) {
-                    // Update global rate limit state
-                    await this.updateGlobalRateLimitState(httpError)
-
-                    if (hasMoreAttempts) {
-                        // Calculate delay based on global rate limit state
-                        const baseDelay = INITIAL_DELAY_MS * Math.pow(2, attempts)
-                        const globalDelay = await this.getGlobalRateLimitDelay()
-                        const delayMs = Math.max(baseDelay, globalDelay)
-
-                        console.warn(
-                            t("embeddings:rateLimitRetry", {
-                                delayMs,
-                                attempt: attempts + 1,
-                                maxRetries: MAX_RETRIES,
-                            }),
-                        )
-                        await new Promise((resolve) => setTimeout(resolve, delayMs))
-                        continue
-                    }
-                }
-
-                // Log the error for debugging
-                console.error(`Roo Code Cloud embedder error (attempt ${attempts + 1}/${MAX_RETRIES}):`, error)
-
-                // Format and throw the error
-                throw formatEmbeddingError(error, MAX_RETRIES)
-            }
-        }
-
-        throw new Error(t("embeddings:failedMaxAttempts", { attempts: MAX_RETRIES }))
-    }
-
-    /**
-     * Validates the Roo Code Cloud embedder configuration by testing API connectivity
-     * @returns Promise resolving to validation result with success status and optional error message
-     */
-    async validateConfiguration(): Promise<{ valid: boolean; error?: string }> {
-        return withValidationErrorHandling(async () => {
-            // Check if we have a valid session token
-            const sessionToken = getSessionToken()
-            if (!sessionToken || sessionToken === "unauthenticated") {
-                return {
-                    valid: false,
-                    error: "embeddings:validation.rooAuthenticationRequired",
-                }
-            }
-
-            try {
-                // Update API key before validation
-                this.embeddingsClient.apiKey = sessionToken
-
-                // Test with a minimal embedding request
-                const testTexts = ["test"]
-                const modelToUse = this.defaultModelId
-
-                const response = (await this.embeddingsClient.embeddings.create({
-                    input: testTexts,
-                    model: modelToUse,
-                    encoding_format: "base64",
-                })) as RooEmbeddingResponse
-
-                // Check if we got a valid response
-                if (!response?.data || response.data.length === 0) {
-                    return {
-                        valid: false,
-                        error: "embeddings:validation.invalidResponse",
-                    }
-                }
-
-                return { valid: true }
-            } catch (error) {
-                // Capture telemetry for validation errors
-                TelemetryService.instance.captureEvent(TelemetryEventName.CODE_INDEX_ERROR, {
-                    error: error instanceof Error ? error.message : String(error),
-                    stack: error instanceof Error ? error.stack : undefined,
-                    location: "RooEmbedder:validateConfiguration",
-                })
-                throw error
-            }
-        }, "roo")
-    }
-
-    /**
-     * Returns information about this embedder
-     */
-    get embedderInfo(): EmbedderInfo {
-        return {
-            name: "roo",
-        }
-    }
-
-    /**
-     * Waits if there's an active global rate limit
-     */
-    private async waitForGlobalRateLimit(): Promise<void> {
-        const release = await RooEmbedder.globalRateLimitState.mutex.acquire()
-        let mutexReleased = false
-
-        try {
-            const state = RooEmbedder.globalRateLimitState
-
-            if (state.isRateLimited && state.rateLimitResetTime > Date.now()) {
-                const waitTime = state.rateLimitResetTime - Date.now()
-                // Silent wait - no logging to prevent flooding
-                release()
-                mutexReleased = true
-                await new Promise((resolve) => setTimeout(resolve, waitTime))
-                return
-            }
-
-            // Reset rate limit if time has passed
-            if (state.isRateLimited && state.rateLimitResetTime <= Date.now()) {
-                state.isRateLimited = false
-                state.consecutiveRateLimitErrors = 0
-            }
-        } finally {
-            // Only release if we haven't already
-            if (!mutexReleased) {
-                release()
-            }
-        }
-    }
-
-    /**
-     * Updates global rate limit state when a 429 error occurs
-     */
-    private async updateGlobalRateLimitState(error: HttpError): Promise<void> {
-        const release = await RooEmbedder.globalRateLimitState.mutex.acquire()
-        try {
-            const state = RooEmbedder.globalRateLimitState
-            const now = Date.now()
-
-            // Increment consecutive rate limit errors
-            if (now - state.lastRateLimitError < 60000) {
-                // Within 1 minute
-                state.consecutiveRateLimitErrors++
-            } else {
-                state.consecutiveRateLimitErrors = 1
-            }
-
-            state.lastRateLimitError = now
-
-            // Calculate exponential backoff based on consecutive errors
-            const baseDelay = 5000 // 5 seconds base
-            const maxDelay = 300000 // 5 minutes max
-            const exponentialDelay = Math.min(baseDelay * Math.pow(2, state.consecutiveRateLimitErrors - 1), maxDelay)
-
-            // Set global rate limit
-            state.isRateLimited = true
-            state.rateLimitResetTime = now + exponentialDelay
-
-            // Silent rate limit activation - no logging to prevent flooding
-        } finally {
-            release()
-        }
-    }
-
-    /**
-     * Gets the current global rate limit delay
-     */
-    private async getGlobalRateLimitDelay(): Promise<number> {
-        const release = await RooEmbedder.globalRateLimitState.mutex.acquire()
-        try {
-            const state = RooEmbedder.globalRateLimitState
-
-            if (state.isRateLimited && state.rateLimitResetTime > Date.now()) {
-                return state.rateLimitResetTime - Date.now()
-            }
-
-            return 0
-        } finally {
-            release()
-        }
-    }
-}