// useSelectedModel.ts
import {
	type ProviderName,
	type ProviderSettings,
	type ModelInfo,
	type ModelRecord,
	type RouterModels,
	anthropicModels,
	bedrockModels,
	deepSeekModels,
	moonshotModels,
	minimaxModels,
	geminiModels,
	mistralModels,
	openAiModelInfoSaneDefaults,
	openAiNativeModels,
	vertexModels,
	xaiModels,
	vscodeLlmModels,
	vscodeLlmDefaultModelId,
	openAiCodexModels,
	sambaNovaModels,
	internationalZAiModels,
	mainlandZAiModels,
	fireworksModels,
	basetenModels,
	azureModels,
	qwenCodeModels,
	litellmDefaultModelInfo,
	lMStudioDefaultModelInfo,
	BEDROCK_1M_CONTEXT_MODEL_IDS,
	VERTEX_1M_CONTEXT_MODEL_IDS,
	isDynamicProvider,
	isRetiredProvider,
	getProviderDefaultModelId,
} from "@roo-code/types"

import { useRouterModels } from "./useRouterModels"
import { useOpenRouterModelProviders } from "./useOpenRouterModelProviders"
import { useLmStudioModels } from "./useLmStudioModels"
import { useOllamaModels } from "./useOllamaModels"

/**
 * Helper to get a validated model ID for dynamic providers.
 * Returns the configured model ID if it exists in the available models, otherwise returns the default.
 */
function getValidatedModelId(
	configuredId: string | undefined,
	availableModels: ModelRecord | undefined,
	defaultModelId: string,
): string {
	return configuredId && availableModels?.[configuredId] ? configuredId : defaultModelId
}
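
// Illustrative sketch (not part of the original file): how the fallback in
// `getValidatedModelId` behaves. The model IDs and the map below are hypothetical,
// used only to show the three outcomes.
//
//   const available = { "some-model": { contextWindow: 128_000 } } as unknown as ModelRecord
//   getValidatedModelId("some-model", available, "default-model") // => "some-model" (found in the map)
//   getValidatedModelId("missing-model", available, "default-model") // => "default-model" (not in the map)
//   getValidatedModelId(undefined, undefined, "default-model") // => "default-model" (nothing configured or loaded)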

export const useSelectedModel = (apiConfiguration?: ProviderSettings) => {
	const provider = apiConfiguration?.apiProvider || "anthropic"
	const activeProvider: ProviderName | undefined = isRetiredProvider(provider) ? undefined : provider
	const dynamicProvider = activeProvider && isDynamicProvider(activeProvider) ? activeProvider : undefined

	const openRouterModelId = activeProvider === "openrouter" ? apiConfiguration?.openRouterModelId : undefined
	const lmStudioModelId = activeProvider === "lmstudio" ? apiConfiguration?.lmStudioModelId : undefined
	const ollamaModelId = activeProvider === "ollama" ? apiConfiguration?.ollamaModelId : undefined

	// Only fetch router models for dynamic providers
	const shouldFetchRouterModels = !!dynamicProvider
	const routerModels = useRouterModels({
		provider: dynamicProvider,
		enabled: shouldFetchRouterModels,
	})
	const openRouterModelProviders = useOpenRouterModelProviders(openRouterModelId)
	const lmStudioModels = useLmStudioModels(lmStudioModelId)
	const ollamaModels = useOllamaModels(ollamaModelId)

	// Compute readiness only for the data actually needed for the selected provider
	const needRouterModels = shouldFetchRouterModels
	const needOpenRouterProviders = activeProvider === "openrouter"
	const needLmStudio = typeof lmStudioModelId !== "undefined"
	const needOllama = typeof ollamaModelId !== "undefined"

	const hasValidRouterData =
		needRouterModels && dynamicProvider
			? routerModels.data &&
				routerModels.data[dynamicProvider] !== undefined &&
				typeof routerModels.data[dynamicProvider] === "object" &&
				!routerModels.isLoading
			: true

	const isReady =
		(!needLmStudio || typeof lmStudioModels.data !== "undefined") &&
		(!needOllama || typeof ollamaModels.data !== "undefined") &&
		hasValidRouterData &&
		(!needOpenRouterProviders || typeof openRouterModelProviders.data !== "undefined")

	const { id, info } =
		apiConfiguration && isReady && activeProvider
			? getSelectedModel({
					provider: activeProvider,
					apiConfiguration,
					routerModels: (routerModels.data || {}) as RouterModels,
					openRouterModelProviders: (openRouterModelProviders.data || {}) as Record<string, ModelInfo>,
					lmStudioModels: (lmStudioModels.data || undefined) as ModelRecord | undefined,
					ollamaModels: (ollamaModels.data || undefined) as ModelRecord | undefined,
				})
			: { id: getProviderDefaultModelId(activeProvider ?? "anthropic"), info: undefined }

	return {
		provider,
		id,
		info,
		isLoading:
			(needRouterModels && routerModels.isLoading) ||
			(needOpenRouterProviders && openRouterModelProviders.isLoading) ||
			(needLmStudio && lmStudioModels!.isLoading) ||
			(needOllama && ollamaModels!.isLoading),
		isError:
			(needRouterModels && routerModels.isError) ||
			(needOpenRouterProviders && openRouterModelProviders.isError) ||
			(needLmStudio && lmStudioModels!.isError) ||
			(needOllama && ollamaModels!.isError),
	}
}
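
// Illustrative usage sketch (not part of the original file): consuming the hook from a
// component. Only the return shape ({ provider, id, info, isLoading, isError }) comes from
// the hook above; the surrounding handling is hypothetical.
//
//   const { provider, id, info, isLoading, isError } = useSelectedModel(apiConfiguration)
//   if (isLoading) {
//   	// Model metadata is still being fetched for the selected provider.
//   } else if (isError) {
//   	// One of the underlying queries failed.
//   } else {
//   	// `info` is undefined while data is pending or when the configured model is invalid.
//   	console.log(`${provider}/${id}`, info?.contextWindow)
//   }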

function getSelectedModel({
	provider,
	apiConfiguration,
	routerModels,
	openRouterModelProviders,
	lmStudioModels,
	ollamaModels,
}: {
	provider: ProviderName
	apiConfiguration: ProviderSettings
	routerModels: RouterModels
	openRouterModelProviders: Record<string, ModelInfo>
	lmStudioModels: ModelRecord | undefined
	ollamaModels: ModelRecord | undefined
}): { id: string; info: ModelInfo | undefined } {
	// An `undefined` info is returned deliberately when the user's selection is invalid,
	// so the UI surfaces the invalid selection instead of silently showing the default
	// model. This gives a better UX than falling back to the default.
	const defaultModelId = getProviderDefaultModelId(provider)
	switch (provider) {
		case "openrouter": {
			const id = getValidatedModelId(apiConfiguration.openRouterModelId, routerModels.openrouter, defaultModelId)
			let info = routerModels.openrouter?.[id]
			const specificProvider = apiConfiguration.openRouterSpecificProvider

			if (specificProvider && openRouterModelProviders[specificProvider]) {
				// Overwrite the info with the specific provider's info. Some fields are
				// missing from the model info in `openRouterModelProviders`, so we need
				// to merge the two.
				info = info
					? { ...info, ...openRouterModelProviders[specificProvider] }
					: openRouterModelProviders[specificProvider]
			}

			return { id, info }
		}
		case "requesty": {
			const id = getValidatedModelId(apiConfiguration.requestyModelId, routerModels.requesty, defaultModelId)
			const routerInfo = routerModels.requesty?.[id]
			return { id, info: routerInfo }
		}
		case "litellm": {
			const id = getValidatedModelId(apiConfiguration.litellmModelId, routerModels.litellm, defaultModelId)
			const routerInfo = routerModels.litellm?.[id]
			return { id, info: routerInfo ?? litellmDefaultModelInfo }
		}
		case "xai": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = xaiModels[id as keyof typeof xaiModels]
			return info ? { id, info } : { id, info: undefined }
		}
		case "baseten": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = basetenModels[id as keyof typeof basetenModels]
			return { id, info }
		}
		case "bedrock": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const baseInfo = bedrockModels[id as keyof typeof bedrockModels]

			// Special case for custom ARN.
			if (id === "custom-arn") {
				return {
					id,
					info: { maxTokens: 5000, contextWindow: 128_000, supportsPromptCache: false, supportsImages: true },
				}
			}

			// Apply 1M context for supported Claude 4 models when enabled
			if (BEDROCK_1M_CONTEXT_MODEL_IDS.includes(id as any) && apiConfiguration.awsBedrock1MContext && baseInfo) {
				// Create a new ModelInfo object with updated context window
				const info: ModelInfo = {
					...baseInfo,
					contextWindow: 1_000_000,
				}
				return { id, info }
			}

			return { id, info: baseInfo }
		}
		case "vertex": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const baseInfo = vertexModels[id as keyof typeof vertexModels]

			// Apply 1M context for supported Claude 4 models when enabled
			if (VERTEX_1M_CONTEXT_MODEL_IDS.includes(id as any) && apiConfiguration.vertex1MContext && baseInfo) {
				const modelInfo: ModelInfo = baseInfo
				const tier = modelInfo.tiers?.[0]
				if (tier) {
					const info: ModelInfo = {
						...modelInfo,
						contextWindow: tier.contextWindow,
						inputPrice: tier.inputPrice,
						outputPrice: tier.outputPrice,
						cacheWritesPrice: tier.cacheWritesPrice,
						cacheReadsPrice: tier.cacheReadsPrice,
					}
					return { id, info }
				}
			}

			return { id, info: baseInfo }
		}
		case "gemini": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = geminiModels[id as keyof typeof geminiModels]
			return { id, info }
		}
		case "deepseek": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = deepSeekModels[id as keyof typeof deepSeekModels]
			return { id, info }
		}
		case "moonshot": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = moonshotModels[id as keyof typeof moonshotModels]
			return { id, info }
		}
		case "minimax": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = minimaxModels[id as keyof typeof minimaxModels]
			return { id, info }
		}
		case "zai": {
			const isChina = apiConfiguration.zaiApiLine === "china_coding"
			const models = isChina ? mainlandZAiModels : internationalZAiModels
			const defaultModelId = getProviderDefaultModelId(provider, { isChina })
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = models[id as keyof typeof models]
			return { id, info }
		}
		case "openai-native": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = openAiNativeModels[id as keyof typeof openAiNativeModels]
			return { id, info }
		}
		case "mistral": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = mistralModels[id as keyof typeof mistralModels]
			return { id, info }
		}
		case "openai": {
			const id = apiConfiguration.openAiModelId ?? ""
			const customInfo = apiConfiguration?.openAiCustomModelInfo
			const info = customInfo ?? openAiModelInfoSaneDefaults
			return { id, info }
		}
		case "ollama": {
			const id = apiConfiguration.ollamaModelId ?? ""
			const info = ollamaModels && ollamaModels[apiConfiguration.ollamaModelId!]
			const adjustedInfo =
				info?.contextWindow &&
				apiConfiguration?.ollamaNumCtx &&
				apiConfiguration.ollamaNumCtx < info.contextWindow
					? { ...info, contextWindow: apiConfiguration.ollamaNumCtx }
					: info
			return {
				id,
				info: adjustedInfo || undefined,
			}
		}
		case "lmstudio": {
			const id = apiConfiguration.lmStudioModelId ?? ""
			const modelInfo = lmStudioModels && lmStudioModels[apiConfiguration.lmStudioModelId!]
			return {
				id,
				info: modelInfo ? { ...lMStudioDefaultModelInfo, ...modelInfo } : undefined,
			}
		}
		case "vscode-lm": {
			const id = apiConfiguration?.vsCodeLmModelSelector
				? `${apiConfiguration.vsCodeLmModelSelector.vendor}/${apiConfiguration.vsCodeLmModelSelector.family}`
				: vscodeLlmDefaultModelId
			const modelFamily = apiConfiguration?.vsCodeLmModelSelector?.family ?? vscodeLlmDefaultModelId
			const info = vscodeLlmModels[modelFamily as keyof typeof vscodeLlmModels]
			return { id, info: { ...openAiModelInfoSaneDefaults, ...info, supportsImages: false } } // VSCode LM API currently doesn't support images.
		}
		case "sambanova": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = sambaNovaModels[id as keyof typeof sambaNovaModels]
			return { id, info }
		}
		case "fireworks": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = fireworksModels[id as keyof typeof fireworksModels]
			return { id, info }
		}
		case "roo": {
			const id = getValidatedModelId(apiConfiguration.apiModelId, routerModels.roo, defaultModelId)
			const info = routerModels.roo?.[id]
			return { id, info }
		}
		case "qwen-code": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = qwenCodeModels[id as keyof typeof qwenCodeModels]
			return { id, info }
		}
		case "openai-codex": {
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const info = openAiCodexModels[id as keyof typeof openAiCodexModels]
			return { id, info }
		}
		case "vercel-ai-gateway": {
			const id = getValidatedModelId(
				apiConfiguration.vercelAiGatewayModelId,
				routerModels["vercel-ai-gateway"],
				defaultModelId,
			)
			const info = routerModels["vercel-ai-gateway"]?.[id]
			return { id, info }
		}
		case "azure": {
			// apiModelId holds the base model selection (from model picker).
			// azureDeploymentName is the deployment name sent to the Azure API.
			// Only use apiModelId if it matches a known Azure model (prevents stale values from other providers).
			const explicitModelId = apiConfiguration.apiModelId
			const matchesAzureModel = explicitModelId && azureModels[explicitModelId as keyof typeof azureModels]
			const id = matchesAzureModel ? explicitModelId : defaultModelId
			const info = azureModels[id as keyof typeof azureModels]
			return { id, info: info || undefined }
		}
		// case "anthropic":
		// case "fake-ai":
		default: {
			provider satisfies "anthropic" | "gemini-cli" | "fake-ai"
			const id = apiConfiguration.apiModelId ?? defaultModelId
			const baseInfo = anthropicModels[id as keyof typeof anthropicModels]

			// Apply 1M context beta tier pricing for supported Claude 4 models
			if (
				provider === "anthropic" &&
				(id === "claude-sonnet-4-20250514" || id === "claude-sonnet-4-5" || id === "claude-opus-4-6") &&
				apiConfiguration.anthropicBeta1MContext &&
				baseInfo
			) {
				// Type assertion since we know these 1M-context-capable models define tiers
				const modelWithTiers = baseInfo as typeof baseInfo & {
					tiers?: Array<{
						contextWindow: number
						inputPrice?: number
						outputPrice?: number
						cacheWritesPrice?: number
						cacheReadsPrice?: number
					}>
				}
				const tier = modelWithTiers.tiers?.[0]
				if (tier) {
					// Create a new ModelInfo object with updated values
					const info: ModelInfo = {
						...baseInfo,
						contextWindow: tier.contextWindow,
						inputPrice: tier.inputPrice ?? baseInfo.inputPrice,
						outputPrice: tier.outputPrice ?? baseInfo.outputPrice,
						cacheWritesPrice: tier.cacheWritesPrice ?? baseInfo.cacheWritesPrice,
						cacheReadsPrice: tier.cacheReadsPrice ?? baseInfo.cacheReadsPrice,
					}
					return { id, info }
				}
			}

			return { id, info: baseInfo }
		}
	}
}