|
@@ -1,17 +1,16 @@
|
|
|
import axios from "axios"
|
|
import axios from "axios"
|
|
|
import { z } from "zod"
|
|
import { z } from "zod"
|
|
|
-import type { ModelInfo } from "@roo-code/types"
|
|
|
|
|
|
|
+
|
|
|
import {
|
|
import {
|
|
|
|
|
+ type ModelInfo,
|
|
|
HUGGINGFACE_API_URL,
|
|
HUGGINGFACE_API_URL,
|
|
|
HUGGINGFACE_CACHE_DURATION,
|
|
HUGGINGFACE_CACHE_DURATION,
|
|
|
HUGGINGFACE_DEFAULT_MAX_TOKENS,
|
|
HUGGINGFACE_DEFAULT_MAX_TOKENS,
|
|
|
HUGGINGFACE_DEFAULT_CONTEXT_WINDOW,
|
|
HUGGINGFACE_DEFAULT_CONTEXT_WINDOW,
|
|
|
} from "@roo-code/types"
|
|
} from "@roo-code/types"
|
|
|
|
|
+
|
|
|
import type { ModelRecord } from "../../../shared/api"
|
|
import type { ModelRecord } from "../../../shared/api"
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * HuggingFace Provider Schema
|
|
|
|
|
- */
|
|
|
|
|
const huggingFaceProviderSchema = z.object({
|
|
const huggingFaceProviderSchema = z.object({
|
|
|
provider: z.string(),
|
|
provider: z.string(),
|
|
|
status: z.enum(["live", "staging", "error"]),
|
|
status: z.enum(["live", "staging", "error"]),
|
|
@@ -27,7 +26,8 @@ const huggingFaceProviderSchema = z.object({
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
/**
|
|
/**
|
|
|
- * Represents a provider that can serve a HuggingFace model
|
|
|
|
|
|
|
+ * Represents a provider that can serve a HuggingFace model.
|
|
|
|
|
+ *
|
|
|
* @property provider - The provider identifier (e.g., "sambanova", "together")
|
|
* @property provider - The provider identifier (e.g., "sambanova", "together")
|
|
|
* @property status - The current status of the provider
|
|
* @property status - The current status of the provider
|
|
|
* @property supports_tools - Whether the provider supports tool/function calling
|
|
* @property supports_tools - Whether the provider supports tool/function calling
|
|
@@ -37,9 +37,6 @@ const huggingFaceProviderSchema = z.object({
|
|
|
*/
|
|
*/
|
|
|
export type HuggingFaceProvider = z.infer<typeof huggingFaceProviderSchema>
|
|
export type HuggingFaceProvider = z.infer<typeof huggingFaceProviderSchema>
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * HuggingFace Model Schema
|
|
|
|
|
- */
|
|
|
|
|
const huggingFaceModelSchema = z.object({
|
|
const huggingFaceModelSchema = z.object({
|
|
|
id: z.string(),
|
|
id: z.string(),
|
|
|
object: z.literal("model"),
|
|
object: z.literal("model"),
|
|
@@ -50,6 +47,7 @@ const huggingFaceModelSchema = z.object({
|
|
|
|
|
|
|
|
/**
|
|
/**
|
|
|
* Represents a HuggingFace model available through the router API
|
|
* Represents a HuggingFace model available through the router API
|
|
|
|
|
+ *
|
|
|
* @property id - The unique identifier of the model
|
|
* @property id - The unique identifier of the model
|
|
|
* @property object - The object type (always "model")
|
|
* @property object - The object type (always "model")
|
|
|
* @property created - Unix timestamp of when the model was created
|
|
* @property created - Unix timestamp of when the model was created
|
|
@@ -58,26 +56,13 @@ const huggingFaceModelSchema = z.object({
|
|
|
*/
|
|
*/
|
|
|
export type HuggingFaceModel = z.infer<typeof huggingFaceModelSchema>
|
|
export type HuggingFaceModel = z.infer<typeof huggingFaceModelSchema>
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * HuggingFace API Response Schema
|
|
|
|
|
- */
|
|
|
|
|
const huggingFaceApiResponseSchema = z.object({
|
|
const huggingFaceApiResponseSchema = z.object({
|
|
|
object: z.string(),
|
|
object: z.string(),
|
|
|
data: z.array(huggingFaceModelSchema),
|
|
data: z.array(huggingFaceModelSchema),
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * Represents the response from the HuggingFace router API
|
|
|
|
|
- * @property object - The response object type
|
|
|
|
|
- * @property data - Array of available models
|
|
|
|
|
- */
|
|
|
|
|
type HuggingFaceApiResponse = z.infer<typeof huggingFaceApiResponseSchema>
|
|
type HuggingFaceApiResponse = z.infer<typeof huggingFaceApiResponseSchema>
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * Cache entry for storing fetched models
|
|
|
|
|
- * @property data - The cached model records
|
|
|
|
|
- * @property timestamp - Unix timestamp of when the cache was last updated
|
|
|
|
|
- */
|
|
|
|
|
interface CacheEntry {
|
|
interface CacheEntry {
|
|
|
data: ModelRecord
|
|
data: ModelRecord
|
|
|
rawModels?: HuggingFaceModel[]
|
|
rawModels?: HuggingFaceModel[]
|
|
@@ -87,13 +72,14 @@ interface CacheEntry {
|
|
|
let cache: CacheEntry | null = null
|
|
let cache: CacheEntry | null = null
|
|
|
|
|
|
|
|
/**
|
|
/**
|
|
|
- * Parse a HuggingFace model into ModelInfo format
|
|
|
|
|
|
|
+ * Parse a HuggingFace model into ModelInfo format.
|
|
|
|
|
+ *
|
|
|
* @param model - The HuggingFace model to parse
|
|
* @param model - The HuggingFace model to parse
|
|
|
* @param provider - Optional specific provider to use for capabilities
|
|
* @param provider - Optional specific provider to use for capabilities
|
|
|
* @returns ModelInfo object compatible with the application's model system
|
|
* @returns ModelInfo object compatible with the application's model system
|
|
|
*/
|
|
*/
|
|
|
function parseHuggingFaceModel(model: HuggingFaceModel, provider?: HuggingFaceProvider): ModelInfo {
|
|
function parseHuggingFaceModel(model: HuggingFaceModel, provider?: HuggingFaceProvider): ModelInfo {
|
|
|
- // Use provider-specific values if available, otherwise find first provider with values
|
|
|
|
|
|
|
+ // Use provider-specific values if available, otherwise find first provider with values.
|
|
|
const contextLength =
|
|
const contextLength =
|
|
|
provider?.context_length ||
|
|
provider?.context_length ||
|
|
|
model.providers.find((p) => p.context_length)?.context_length ||
|
|
model.providers.find((p) => p.context_length)?.context_length ||
|
|
@@ -101,13 +87,13 @@ function parseHuggingFaceModel(model: HuggingFaceModel, provider?: HuggingFacePr
|
|
|
|
|
|
|
|
const pricing = provider?.pricing || model.providers.find((p) => p.pricing)?.pricing
|
|
const pricing = provider?.pricing || model.providers.find((p) => p.pricing)?.pricing
|
|
|
|
|
|
|
|
- // Include provider name in description if specific provider is given
|
|
|
|
|
|
|
+ // Include provider name in description if specific provider is given.
|
|
|
const description = provider ? `${model.id} via ${provider.provider}` : `${model.id} via HuggingFace`
|
|
const description = provider ? `${model.id} via ${provider.provider}` : `${model.id} via HuggingFace`
|
|
|
|
|
|
|
|
return {
|
|
return {
|
|
|
maxTokens: Math.min(contextLength, HUGGINGFACE_DEFAULT_MAX_TOKENS),
|
|
maxTokens: Math.min(contextLength, HUGGINGFACE_DEFAULT_MAX_TOKENS),
|
|
|
contextWindow: contextLength,
|
|
contextWindow: contextLength,
|
|
|
- supportsImages: false, // HuggingFace API doesn't provide this info yet
|
|
|
|
|
|
|
+ supportsImages: false, // HuggingFace API doesn't provide this info yet.
|
|
|
supportsPromptCache: false,
|
|
supportsPromptCache: false,
|
|
|
supportsComputerUse: false,
|
|
supportsComputerUse: false,
|
|
|
inputPrice: pricing?.input,
|
|
inputPrice: pricing?.input,
|
|
@@ -125,7 +111,6 @@ function parseHuggingFaceModel(model: HuggingFaceModel, provider?: HuggingFacePr
|
|
|
export async function getHuggingFaceModels(): Promise<ModelRecord> {
|
|
export async function getHuggingFaceModels(): Promise<ModelRecord> {
|
|
|
const now = Date.now()
|
|
const now = Date.now()
|
|
|
|
|
|
|
|
- // Check cache
|
|
|
|
|
if (cache && now - cache.timestamp < HUGGINGFACE_CACHE_DURATION) {
|
|
if (cache && now - cache.timestamp < HUGGINGFACE_CACHE_DURATION) {
|
|
|
return cache.data
|
|
return cache.data
|
|
|
}
|
|
}
|
|
@@ -144,7 +129,7 @@ export async function getHuggingFaceModels(): Promise<ModelRecord> {
|
|
|
Pragma: "no-cache",
|
|
Pragma: "no-cache",
|
|
|
"Cache-Control": "no-cache",
|
|
"Cache-Control": "no-cache",
|
|
|
},
|
|
},
|
|
|
- timeout: 10000, // 10 second timeout
|
|
|
|
|
|
|
+ timeout: 10000,
|
|
|
})
|
|
})
|
|
|
|
|
|
|
|
const result = huggingFaceApiResponseSchema.safeParse(response.data)
|
|
const result = huggingFaceApiResponseSchema.safeParse(response.data)
|
|
@@ -157,38 +142,31 @@ export async function getHuggingFaceModels(): Promise<ModelRecord> {
|
|
|
const validModels = result.data.data.filter((model) => model.providers.length > 0)
|
|
const validModels = result.data.data.filter((model) => model.providers.length > 0)
|
|
|
|
|
|
|
|
for (const model of validModels) {
|
|
for (const model of validModels) {
|
|
|
- // Add the base model
|
|
|
|
|
|
|
+ // Add the base model.
|
|
|
models[model.id] = parseHuggingFaceModel(model)
|
|
models[model.id] = parseHuggingFaceModel(model)
|
|
|
|
|
|
|
|
- // Add provider-specific variants for all live providers
|
|
|
|
|
|
|
+ // Add provider-specific variants for all live providers.
|
|
|
for (const provider of model.providers) {
|
|
for (const provider of model.providers) {
|
|
|
if (provider.status === "live") {
|
|
if (provider.status === "live") {
|
|
|
const providerKey = `${model.id}:${provider.provider}`
|
|
const providerKey = `${model.id}:${provider.provider}`
|
|
|
const providerModel = parseHuggingFaceModel(model, provider)
|
|
const providerModel = parseHuggingFaceModel(model, provider)
|
|
|
|
|
|
|
|
- // Always add provider variants to show all available providers
|
|
|
|
|
|
|
+ // Always add provider variants to show all available providers.
|
|
|
models[providerKey] = providerModel
|
|
models[providerKey] = providerModel
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Update cache
|
|
|
|
|
- cache = {
|
|
|
|
|
- data: models,
|
|
|
|
|
- rawModels: validModels,
|
|
|
|
|
- timestamp: now,
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ cache = { data: models, rawModels: validModels, timestamp: now }
|
|
|
|
|
|
|
|
return models
|
|
return models
|
|
|
} catch (error) {
|
|
} catch (error) {
|
|
|
console.error("Error fetching HuggingFace models:", error)
|
|
console.error("Error fetching HuggingFace models:", error)
|
|
|
|
|
|
|
|
- // Return cached data if available
|
|
|
|
|
if (cache) {
|
|
if (cache) {
|
|
|
return cache.data
|
|
return cache.data
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // Re-throw with more context
|
|
|
|
|
if (axios.isAxiosError(error)) {
|
|
if (axios.isAxiosError(error)) {
|
|
|
if (error.response) {
|
|
if (error.response) {
|
|
|
throw new Error(
|
|
throw new Error(
|
|
@@ -208,45 +186,35 @@ export async function getHuggingFaceModels(): Promise<ModelRecord> {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
/**
|
|
|
- * Get cached models without making an API request
|
|
|
|
|
|
|
+ * Get cached models without making an API request.
|
|
|
*/
|
|
*/
|
|
|
export function getCachedHuggingFaceModels(): ModelRecord | null {
|
|
export function getCachedHuggingFaceModels(): ModelRecord | null {
|
|
|
return cache?.data || null
|
|
return cache?.data || null
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
/**
|
|
|
- * Get cached raw models for UI display
|
|
|
|
|
|
|
+ * Get cached raw models for UI display.
|
|
|
*/
|
|
*/
|
|
|
export function getCachedRawHuggingFaceModels(): HuggingFaceModel[] | null {
|
|
export function getCachedRawHuggingFaceModels(): HuggingFaceModel[] | null {
|
|
|
return cache?.rawModels || null
|
|
return cache?.rawModels || null
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * Clear the cache
|
|
|
|
|
- */
|
|
|
|
|
export function clearHuggingFaceCache(): void {
|
|
export function clearHuggingFaceCache(): void {
|
|
|
cache = null
|
|
cache = null
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * HuggingFace Models Response Interface
|
|
|
|
|
- */
|
|
|
|
|
export interface HuggingFaceModelsResponse {
|
|
export interface HuggingFaceModelsResponse {
|
|
|
models: HuggingFaceModel[]
|
|
models: HuggingFaceModel[]
|
|
|
cached: boolean
|
|
cached: boolean
|
|
|
timestamp: number
|
|
timestamp: number
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-/**
|
|
|
|
|
- * Get HuggingFace models with response metadata
|
|
|
|
|
- * This function provides a higher-level API that includes cache status and timestamp
|
|
|
|
|
- */
|
|
|
|
|
export async function getHuggingFaceModelsWithMetadata(): Promise<HuggingFaceModelsResponse> {
|
|
export async function getHuggingFaceModelsWithMetadata(): Promise<HuggingFaceModelsResponse> {
|
|
|
try {
|
|
try {
|
|
|
- // First, trigger the fetch to populate cache
|
|
|
|
|
|
|
+ // First, trigger the fetch to populate cache.
|
|
|
await getHuggingFaceModels()
|
|
await getHuggingFaceModels()
|
|
|
|
|
|
|
|
- // Get the raw models from cache
|
|
|
|
|
|
|
+ // Get the raw models from cache.
|
|
|
const cachedRawModels = getCachedRawHuggingFaceModels()
|
|
const cachedRawModels = getCachedRawHuggingFaceModels()
|
|
|
|
|
|
|
|
if (cachedRawModels) {
|
|
if (cachedRawModels) {
|
|
@@ -257,7 +225,7 @@ export async function getHuggingFaceModelsWithMetadata(): Promise<HuggingFaceMod
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- // If no cached raw models, fetch directly from API
|
|
|
|
|
|
|
+ // If no cached raw models, fetch directly from API.
|
|
|
const response = await axios.get(HUGGINGFACE_API_URL, {
|
|
const response = await axios.get(HUGGINGFACE_API_URL, {
|
|
|
headers: {
|
|
headers: {
|
|
|
"Upgrade-Insecure-Requests": "1",
|
|
"Upgrade-Insecure-Requests": "1",
|
|
@@ -281,10 +249,6 @@ export async function getHuggingFaceModelsWithMetadata(): Promise<HuggingFaceMod
|
|
|
}
|
|
}
|
|
|
} catch (error) {
|
|
} catch (error) {
|
|
|
console.error("Failed to get HuggingFace models:", error)
|
|
console.error("Failed to get HuggingFace models:", error)
|
|
|
- return {
|
|
|
|
|
- models: [],
|
|
|
|
|
- cached: false,
|
|
|
|
|
- timestamp: Date.now(),
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ return { models: [], cached: false, timestamp: Date.now() }
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|