ProviderSettingsManager.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. import { ExtensionContext } from "vscode"
  2. import { z, ZodError } from "zod"
  3. import {
  4. type ProviderSettingsEntry,
  5. providerSettingsSchema,
  6. providerSettingsSchemaDiscriminated,
  7. } from "@roo-code/types"
  8. import { TelemetryService } from "@roo-code/telemetry"
  9. import { Mode, modes } from "../../shared/modes"
  // Optional profile identifier carried by every stored provider settings object.
  const profileIdShape = { id: z.string().optional() }

  // Full (non-discriminated) provider settings schema plus the optional profile id.
  const providerSettingsWithIdSchema = providerSettingsSchema.extend(profileIdShape)

  // Discriminated provider settings schema plus the optional profile id; used when exporting so
  // that properties belonging to other providers are not leaked.
  const discriminatedProviderSettingsWithIdSchema = providerSettingsSchemaDiscriminated.and(z.object(profileIdShape))

  type ProviderSettingsWithId = z.infer<typeof providerSettingsWithIdSchema>
  // Flags recording which one-shot data migrations have already been applied.
  const providerProfilesMigrationsSchema = z
    .object({
      rateLimitSecondsMigrated: z.boolean().optional(),
      diffSettingsMigrated: z.boolean().optional(),
      openAiHeadersMigrated: z.boolean().optional(),
    })
    .optional()

  /**
   * Shape of the persisted provider-profiles document: the active profile name, every profile
   * keyed by name, an optional mode-slug -> profile-id map, and the migration flags.
   */
  export const providerProfilesSchema = z.object({
    currentApiConfigName: z.string(),
    apiConfigs: z.record(z.string(), providerSettingsWithIdSchema),
    modeApiConfigs: z.record(z.string(), z.string()).optional(),
    migrations: providerProfilesMigrationsSchema,
  })

  export type ProviderProfiles = z.infer<typeof providerProfilesSchema>
  /**
   * Persists provider profiles (named API configurations) in VS Code secret storage, runs one-shot
   * migrations on startup, and tracks which profile each mode uses.
   *
   * All profiles are stored as one JSON document under a single secrets key. Operations that read or
   * write that document run through `lock()`, so read-modify-write sequences never interleave.
   */
  export class ProviderSettingsManager {
    private static readonly SCOPE_PREFIX = "roo_cline_config_"

    // ID of the "default" profile used when nothing has been stored yet.
    private readonly defaultConfigId = this.generateId()

    // Initially every known mode points at the default profile.
    private readonly defaultModeApiConfigs: Record<string, string> = Object.fromEntries(
      modes.map((mode) => [mode.slug, this.defaultConfigId]),
    )

    // Template used when secret storage is empty. It is never handed to callers directly:
    // `load()` returns a deep copy so mutations (saveConfig, setModeConfig, ...) cannot leak into
    // the defaults that are served again after `resetAllConfigs()`.
    private readonly defaultProviderProfiles: ProviderProfiles = {
      currentApiConfigName: "default",
      apiConfigs: { default: { id: this.defaultConfigId } },
      modeApiConfigs: this.defaultModeApiConfigs,
      migrations: {
        rateLimitSecondsMigrated: true, // Mark as migrated on fresh installs
        diffSettingsMigrated: true, // Mark as migrated on fresh installs
        openAiHeadersMigrated: true, // Mark as migrated on fresh installs
      },
    }

    private readonly context: ExtensionContext

    constructor(context: ExtensionContext) {
      this.context = context

      // TODO: We really shouldn't have async methods in the constructor.
      this.initialize().catch(console.error)
    }

    /** Returns a short random base-36 identifier for a new profile (Math.random based, not cryptographically secure). */
    public generateId() {
      return Math.random().toString(36).substring(2, 15)
    }

    // Tail of the operation queue; every `lock()` call chains onto it.
    private _lock = Promise.resolve()

    /**
     * Synchronize readConfig/writeConfig operations to avoid data loss: runs `cb` only after every
     * previously queued operation has settled.
     *
     * @returns The result of `cb`. A rejection is propagated to this caller but does not block
     *          operations queued afterwards.
     */
    private lock<T>(cb: () => Promise<T>): Promise<T> {
      const next = this._lock.then(cb)
      // Only the queue tail ignores the outcome, so one failure cannot poison later operations.
      this._lock = next.then(
        () => undefined,
        () => undefined,
      )
      return next
    }

    /**
     * Initialize config if it doesn't exist and run migrations.
     *
     * Backfills the per-mode profile map and missing profile IDs, then runs each pending migration
     * once. The document is written back only when something changed.
     *
     * @throws Error prefixed with "Failed to initialize config:" when loading or storing fails.
     */
    public async initialize() {
      try {
        return await this.lock(async () => {
          const providerProfiles = await this.load()

          // Defensive guard: `load()` falls back to the defaults rather than returning nothing.
          if (!providerProfiles) {
            await this.store(this.defaultProviderProfiles)
            return
          }

          let isDirty = false

          // Migrate existing installs to have per-mode API config map.
          if (!providerProfiles.modeApiConfigs) {
            // Use the currently selected config for all modes initially.
            const currentName = providerProfiles.currentApiConfigName
            const seedId =
              this.getOwnConfig(providerProfiles, currentName)?.id ??
              Object.values(providerProfiles.apiConfigs)[0]?.id ??
              this.defaultConfigId
            providerProfiles.modeApiConfigs = Object.fromEntries(modes.map((m) => [m.slug, seedId]))
            isDirty = true
          }

          // Ensure all configs have IDs.
          for (const apiConfig of Object.values(providerProfiles.apiConfigs)) {
            if (!apiConfig.id) {
              apiConfig.id = this.generateId()
              isDirty = true
            }
          }

          // Ensure migrations field exists; existing installs without it have run none yet.
          if (!providerProfiles.migrations) {
            providerProfiles.migrations = {
              rateLimitSecondsMigrated: false,
              diffSettingsMigrated: false,
              openAiHeadersMigrated: false,
            }
            isDirty = true
          }

          if (!providerProfiles.migrations.rateLimitSecondsMigrated) {
            await this.migrateRateLimitSeconds(providerProfiles)
            providerProfiles.migrations.rateLimitSecondsMigrated = true
            isDirty = true
          }

          if (!providerProfiles.migrations.diffSettingsMigrated) {
            await this.migrateDiffSettings(providerProfiles)
            providerProfiles.migrations.diffSettingsMigrated = true
            isDirty = true
          }

          if (!providerProfiles.migrations.openAiHeadersMigrated) {
            await this.migrateOpenAiHeaders(providerProfiles)
            providerProfiles.migrations.openAiHeadersMigrated = true
            isDirty = true
          }

          if (isDirty) {
            await this.store(providerProfiles)
          }
        })
      } catch (error) {
        throw new Error(`Failed to initialize config: ${error}`)
      }
    }

    /**
     * Copies the global `rateLimitSeconds` setting (default 0) into every profile that lacks one.
     * Mutates `providerProfiles` in place; failures are logged, not thrown.
     */
    private async migrateRateLimitSeconds(providerProfiles: ProviderProfiles) {
      try {
        let rateLimitSeconds: number | undefined

        try {
          rateLimitSeconds = await this.context.globalState.get<number>("rateLimitSeconds")
        } catch (error) {
          console.error("[MigrateRateLimitSeconds] Error getting global rate limit:", error)
        }

        if (rateLimitSeconds === undefined) {
          // Failed to get the existing value, use the default.
          rateLimitSeconds = 0
        }

        for (const apiConfig of Object.values(providerProfiles.apiConfigs)) {
          if (apiConfig.rateLimitSeconds === undefined) {
            apiConfig.rateLimitSeconds = rateLimitSeconds
          }
        }
      } catch (error) {
        console.error(`[MigrateRateLimitSeconds] Failed to migrate rate limit settings:`, error)
      }
    }

    /**
     * Copies the global `diffEnabled` (default true) and `fuzzyMatchThreshold` (default 1.0)
     * settings into every profile that lacks them. Mutates in place; failures are logged.
     */
    private async migrateDiffSettings(providerProfiles: ProviderProfiles) {
      try {
        let diffEnabled: boolean | undefined
        let fuzzyMatchThreshold: number | undefined

        try {
          diffEnabled = await this.context.globalState.get<boolean>("diffEnabled")
          fuzzyMatchThreshold = await this.context.globalState.get<number>("fuzzyMatchThreshold")
        } catch (error) {
          console.error("[MigrateDiffSettings] Error getting global diff settings:", error)
        }

        if (diffEnabled === undefined) {
          // Failed to get the existing value, use the default.
          diffEnabled = true
        }

        if (fuzzyMatchThreshold === undefined) {
          // Failed to get the existing value, use the default.
          fuzzyMatchThreshold = 1.0
        }

        for (const apiConfig of Object.values(providerProfiles.apiConfigs)) {
          if (apiConfig.diffEnabled === undefined) {
            apiConfig.diffEnabled = diffEnabled
          }
          if (apiConfig.fuzzyMatchThreshold === undefined) {
            apiConfig.fuzzyMatchThreshold = fuzzyMatchThreshold
          }
        }
      } catch (error) {
        console.error(`[MigrateDiffSettings] Failed to migrate diff settings:`, error)
      }
    }

    /**
     * Converts the deprecated `openAiHostHeader` field into `openAiHeaders: { Host }` for profiles
     * that have no custom headers yet. Mutates in place; failures are logged.
     */
    private async migrateOpenAiHeaders(providerProfiles: ProviderProfiles) {
      try {
        for (const apiConfig of Object.values(providerProfiles.apiConfigs)) {
          // Widen the type to reach the deprecated property without resorting to `any`.
          const legacyConfig = apiConfig as ProviderSettingsWithId & { openAiHostHeader?: string }
          const hasHeaders = Object.keys(apiConfig.openAiHeaders ?? {}).length > 0

          if (legacyConfig.openAiHostHeader && !hasHeaders) {
            apiConfig.openAiHeaders = { Host: legacyConfig.openAiHostHeader }
            // Clear the old property to prevent re-migration, so the header does not
            // reappear after the user deletes it.
            legacyConfig.openAiHostHeader = undefined
          }
        }
      } catch (error) {
        console.error(`[MigrateOpenAiHeaders] Failed to migrate OpenAI headers:`, error)
      }
    }

    /**
     * List all available configs with metadata (name, id, provider).
     */
    public async listConfig(): Promise<ProviderSettingsEntry[]> {
      try {
        return await this.lock(async () => {
          const providerProfiles = await this.load()

          return Object.entries(providerProfiles.apiConfigs).map(([name, apiConfig]) => ({
            name,
            id: apiConfig.id || "",
            apiProvider: apiConfig.apiProvider,
          }))
        })
      } catch (error) {
        throw new Error(`Failed to list configs: ${error}`)
      }
    }

    /**
     * Save a config with the given name.
     * Preserves the ID from the input 'config' object if it exists, otherwise reuses the ID of the
     * existing profile with that name, otherwise generates a new one (for creation scenarios).
     *
     * @returns The ID under which the profile was stored.
     */
    public async saveConfig(name: string, config: ProviderSettingsWithId): Promise<string> {
      try {
        return await this.lock(async () => {
          const providerProfiles = await this.load()
          // Preserve the existing ID if this is an update to an existing config.
          const existingId = this.getOwnConfig(providerProfiles, name)?.id
          const id = config.id || existingId || this.generateId()

          // Filter out settings from other providers.
          const filteredConfig = providerSettingsSchemaDiscriminated.parse(config)
          providerProfiles.apiConfigs[name] = { ...filteredConfig, id }
          await this.store(providerProfiles)
          return id
        })
      } catch (error) {
        throw new Error(`Failed to save config: ${error}`)
      }
    }

    /**
     * Look up a profile by name or by ID.
     *
     * @throws Error "Failed to get profile: Config with name|ID '...' not found" when absent.
     */
    public async getProfile(
      params: { name: string } | { id: string },
    ): Promise<ProviderSettingsWithId & { name: string }> {
      try {
        return await this.lock(async () => {
          const providerProfiles = await this.load()
          let name: string
          let providerSettings: ProviderSettingsWithId

          if ("name" in params) {
            name = params.name
            const config = this.getOwnConfig(providerProfiles, name)

            if (!config) {
              throw new Error(`Config with name '${name}' not found`)
            }

            providerSettings = config
          } else {
            const id = params.id
            const entry = Object.entries(providerProfiles.apiConfigs).find(
              ([_, apiConfig]) => apiConfig.id === id,
            )

            if (!entry) {
              throw new Error(`Config with ID '${id}' not found`)
            }

            name = entry[0]
            providerSettings = entry[1]
          }

          return { name, ...providerSettings }
        })
      } catch (error) {
        throw new Error(`Failed to get profile: ${error instanceof Error ? error.message : error}`)
      }
    }

    /**
     * Activate a profile by name or ID, making it the current API config.
     * The lookup and the write run as two separate locked operations.
     */
    public async activateProfile(
      params: { name: string } | { id: string },
    ): Promise<ProviderSettingsWithId & { name: string }> {
      const { name, ...providerSettings } = await this.getProfile(params)

      try {
        return await this.lock(async () => {
          const providerProfiles = await this.load()
          providerProfiles.currentApiConfigName = name
          await this.store(providerProfiles)
          return { name, ...providerSettings }
        })
      } catch (error) {
        throw new Error(`Failed to activate profile: ${error instanceof Error ? error.message : error}`)
      }
    }

    /**
     * Delete a config by name. Refuses to delete the last remaining config.
     */
    public async deleteConfig(name: string) {
      try {
        return await this.lock(async () => {
          const providerProfiles = await this.load()

          if (!this.getOwnConfig(providerProfiles, name)) {
            throw new Error(`Config '${name}' not found`)
          }

          if (Object.keys(providerProfiles.apiConfigs).length === 1) {
            throw new Error(`Cannot delete the last remaining configuration`)
          }

          delete providerProfiles.apiConfigs[name]
          await this.store(providerProfiles)
        })
      } catch (error) {
        throw new Error(`Failed to delete config: ${error}`)
      }
    }

    /**
     * Check if a config exists by name.
     */
    public async hasConfig(name: string) {
      try {
        return await this.lock(async () => {
          const providerProfiles = await this.load()
          return this.getOwnConfig(providerProfiles, name) !== undefined
        })
      } catch (error) {
        throw new Error(`Failed to check config existence: ${error}`)
      }
    }

    /**
     * Set the API config (profile ID) for a specific mode.
     */
    public async setModeConfig(mode: Mode, configId: string) {
      try {
        return await this.lock(async () => {
          const providerProfiles = await this.load()

          // Ensure the per-mode config map exists.
          if (!providerProfiles.modeApiConfigs) {
            providerProfiles.modeApiConfigs = {}
          }

          // Assign the chosen config ID to this mode.
          providerProfiles.modeApiConfigs[mode] = configId
          await this.store(providerProfiles)
        })
      } catch (error) {
        throw new Error(`Failed to set mode config: ${error}`)
      }
    }

    /**
     * Get the API config ID for a specific mode, or undefined if none is mapped.
     */
    public async getModeConfigId(mode: Mode) {
      try {
        return await this.lock(async () => {
          const { modeApiConfigs } = await this.load()
          return modeApiConfigs?.[mode]
        })
      } catch (error) {
        throw new Error(`Failed to get mode config: ${error}`)
      }
    }

    /**
     * Export all profiles, re-parsing each one with the discriminated schema so properties from
     * other providers are not included.
     */
    public async export() {
      try {
        return await this.lock(async () => {
          const profiles = providerProfilesSchema.parse(await this.load())
          const configs = profiles.apiConfigs

          for (const name in configs) {
            // Avoid leaking properties from other providers.
            configs[name] = discriminatedProviderSettingsWithIdSchema.parse(configs[name])
          }

          return profiles
        })
      } catch (error) {
        throw new Error(`Failed to export provider profiles: ${error}`)
      }
    }

    /**
     * Replace the stored profiles document with `providerProfiles` (stored as given, not validated here).
     */
    public async import(providerProfiles: ProviderProfiles) {
      try {
        return await this.lock(() => this.store(providerProfiles))
      } catch (error) {
        throw new Error(`Failed to import provider profiles: ${error}`)
      }
    }

    /**
     * Reset provider profiles by deleting them from secrets.
     */
    public async resetAllConfigs() {
      return await this.lock(async () => {
        await this.context.secrets.delete(this.secretsKey)
      })
    }

    private get secretsKey() {
      return `${ProviderSettingsManager.SCOPE_PREFIX}api_config`
    }

    /**
     * Looks up a profile by name using an own-property check, so names that collide with
     * `Object.prototype` members (e.g. "constructor", "toString") are not mistaken for profiles.
     */
    private getOwnConfig(providerProfiles: ProviderProfiles, name: string): ProviderSettingsWithId | undefined {
      const { apiConfigs } = providerProfiles
      return Object.prototype.hasOwnProperty.call(apiConfigs, name) ? apiConfigs[name] : undefined
    }

    /** Returns a deep copy of the default profiles so callers can mutate the result freely. */
    private cloneDefaultProviderProfiles(): ProviderProfiles {
      const copy: ProviderProfiles = JSON.parse(JSON.stringify(this.defaultProviderProfiles))
      return copy
    }

    /**
     * Read and validate the profiles document from secret storage.
     *
     * Returns a fresh copy of the defaults when nothing is stored. Individual profiles that fail
     * validation are dropped rather than failing the whole load.
     *
     * @throws Error prefixed with "Failed to read provider profiles from secrets:"; schema
     *         validation errors are also reported to telemetry.
     */
    private async load(): Promise<ProviderProfiles> {
      try {
        const content = await this.context.secrets.get(this.secretsKey)

        if (!content) {
          return this.cloneDefaultProviderProfiles()
        }

        // Validate the envelope strictly but accept any value per profile; each profile is
        // validated on its own below.
        const providerProfiles = providerProfilesSchema
          .extend({
            apiConfigs: z.record(z.string(), z.any()),
          })
          .parse(JSON.parse(content))

        const apiConfigs: Record<string, ProviderSettingsWithId> = Object.fromEntries(
          Object.entries(providerProfiles.apiConfigs).flatMap(
            ([name, apiConfig]): Array<[string, ProviderSettingsWithId]> => {
              const result = providerSettingsWithIdSchema.safeParse(apiConfig)
              return result.success ? [[name, result.data]] : []
            },
          ),
        )

        return { ...providerProfiles, apiConfigs }
      } catch (error) {
        if (error instanceof ZodError) {
          TelemetryService.instance.captureSchemaValidationError({
            schemaName: "ProviderProfiles",
            error,
          })
        }

        throw new Error(`Failed to read provider profiles from secrets: ${error}`)
      }
    }

    /** Serialize the whole profiles document (pretty-printed JSON) into secret storage. */
    private async store(providerProfiles: ProviderProfiles) {
      try {
        await this.context.secrets.store(this.secretsKey, JSON.stringify(providerProfiles, null, 2))
      } catch (error) {
        throw new Error(`Failed to write provider profiles to secrets: ${error}`)
      }
    }
  }