// useOpenRouterModelProviders.ts
  1. import axios from "axios"
  2. import { z } from "zod"
  3. import { useQuery, UseQueryOptions } from "@tanstack/react-query"
  4. import { ModelInfo } from "@roo/shared/api"
  5. import { parseApiPrice } from "@roo/utils/cost"
  /**
   * Sentinel provider label, the literal string `[default]`.
   *
   * NOTE(review): not referenced elsewhere in this file; presumably it stands
   * for "no specific provider selected" in consumers — confirm against callers.
   */
  export const OPENROUTER_DEFAULT_PROVIDER_NAME = "[default]"
  /** A price field as accepted from the payload: string or number, may be absent. */
  const openRouterPriceSchema = z.union([z.string(), z.number()]).optional()

  /** Optional/nullable architecture descriptor attached to a model. */
  const openRouterArchitectureSchema = z
    .object({
      modality: z.string().nullish(),
      tokenizer: z.string().nullish(),
    })
    .nullish()

  /** One entry of the `endpoints` array: a single provider offering for the model. */
  const openRouterEndpointSchema = z.object({
    name: z.string(),
    context_length: z.number(),
    max_completion_tokens: z.number().nullish(),
    pricing: z
      .object({
        prompt: openRouterPriceSchema,
        completion: openRouterPriceSchema,
      })
      .optional(),
  })

  /**
   * Zod schema for the response body of the per-model endpoints request:
   * a `data` envelope holding the model's identity, optional description and
   * architecture, and the list of provider endpoints.
   */
  const openRouterEndpointsSchema = z.object({
    data: z.object({
      id: z.string(),
      name: z.string(),
      description: z.string().optional(),
      architecture: openRouterArchitectureSchema,
      endpoints: z.array(openRouterEndpointSchema),
    }),
  })
  /** Model info for one provider of a model, plus the provider name to display as `label`. */
  type OpenRouterModelProvider = ModelInfo & { label: string }
  /**
   * Hard-coded capability and cache-pricing overrides selected by model id.
   *
   * TODO: This is wrong. We need to fetch the model info from OpenRouter
   * instead of hardcoding it here. The endpoints payload doesn't include this
   * unfortunately, so we need to get it from the main models endpoint.
   *
   * @param modelId - Model id passed in by the caller; used for the prefix checks.
   * @param id - Model id echoed in the API response; used for the `:thinking` check.
   * @returns Fields to overlay on every provider entry built for this model.
   */
  function getHardcodedModelOverrides(modelId: string, id: string): Partial<ModelInfo> {
    if (modelId.startsWith("anthropic/claude-3.7-sonnet")) {
      return {
        supportsComputerUse: true,
        supportsPromptCache: true,
        cacheWritesPrice: 3.75,
        cacheReadsPrice: 0.3,
        maxTokens: id === "anthropic/claude-3.7-sonnet:thinking" ? 64_000 : 8192,
      }
    }

    if (modelId.startsWith("anthropic/claude-3.5-sonnet-20240620")) {
      return {
        supportsPromptCache: true,
        cacheWritesPrice: 3.75,
        cacheReadsPrice: 0.3,
        maxTokens: 8192,
      }
    }

    // Every other model: prompt caching is flagged on with fixed cache prices.
    // `maxTokens` is deliberately absent so the per-endpoint value is kept.
    return {
      supportsPromptCache: true,
      cacheWritesPrice: 0.3,
      cacheReadsPrice: 0.03,
    }
  }

  /**
   * Fetches the provider endpoints for a model and maps them to model info
   * records keyed by provider name.
   *
   * Failures never propagate: a response that fails schema validation, or any
   * error thrown while fetching or mapping, is logged with `console.error` and
   * whatever has been collected so far (possibly an empty object) is returned.
   *
   * @param modelId - Model id, interpolated as-is into the request path.
   * @returns Map from provider name to that provider's model info. If several
   *   endpoints yield the same provider name, the last one wins.
   */
  async function getOpenRouterProvidersForModel(modelId: string): Promise<Record<string, OpenRouterModelProvider>> {
    const models: Record<string, OpenRouterModelProvider> = {}

    try {
      const response = await axios.get(`https://openrouter.ai/api/v1/models/${modelId}/endpoints`)
      const result = openRouterEndpointsSchema.safeParse(response.data)

      if (!result.success) {
        console.error("OpenRouter API response validation failed:", result.error)
        return models
      }

      const { id, description, architecture, endpoints } = result.data.data

      // Both values depend only on the model, not on the endpoint, so they are
      // computed once instead of once per loop iteration.
      const overrides = getHardcodedModelOverrides(modelId, id)
      const supportsImages = architecture?.modality?.includes("image")

      for (const endpoint of endpoints) {
        // Provider name is the text before the first "|". `split` always
        // yields at least one element; the default only satisfies strict
        // indexed-access checking.
        const [rawProviderName = endpoint.name] = endpoint.name.split("|")
        const providerName = rawProviderName.trim()

        const modelInfo: OpenRouterModelProvider = {
          // A missing or zero completion limit falls back to the context length.
          maxTokens: endpoint.max_completion_tokens || endpoint.context_length,
          contextWindow: endpoint.context_length,
          supportsImages,
          supportsPromptCache: false,
          inputPrice: parseApiPrice(endpoint.pricing?.prompt),
          outputPrice: parseApiPrice(endpoint.pricing?.completion),
          description,
          label: providerName,
          // Overrides come last so they replace the defaults above.
          ...overrides,
        }

        models[providerName] = modelInfo
      }
    } catch (error) {
      // NOTE(review): validation above uses `safeParse`, which reports failure
      // through `result.success`; the ZodError branch is kept as a defensive
      // path in case a callee throws one — confirm whether it is reachable.
      if (error instanceof z.ZodError) {
        console.error(`OpenRouter API response validation failed:`, error.errors)
      } else {
        console.error(`Error fetching OpenRouter providers:`, error)
      }
    }

    return models
  }
  /** The provider map produced by the query. */
  type OpenRouterModelProviderMap = Record<string, OpenRouterModelProvider>

  /**
   * Query options a caller may pass to the hook: everything `useQuery` accepts
   * for this result type except `queryKey` and `queryFn`.
   */
  type UseOpenRouterModelProvidersOptions = Omit<UseQueryOptions<OpenRouterModelProviderMap>, "queryKey" | "queryFn">
  /**
   * React Query hook returning the providers available for a model.
   *
   * The query is keyed on `modelId`. When `modelId` is missing or empty the
   * query function resolves to an empty object without making a request.
   *
   * @param modelId - Model id to look up; optional.
   * @param options - Extra query options, spread after the key and function.
   */
  export const useOpenRouterModelProviders = (modelId?: string, options?: UseOpenRouterModelProvidersOptions) => {
    const fetchProviders = () => {
      if (!modelId) {
        return {}
      }

      return getOpenRouterProvidersForModel(modelId)
    }

    return useQuery<Record<string, OpenRouterModelProvider>>({
      queryKey: ["openrouter-model-providers", modelId],
      queryFn: fetchProviders,
      ...options,
    })
  }