load_test.go 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186
  1. package config
  2. import (
  3. "io"
  4. "log/slog"
  5. "os"
  6. "path/filepath"
  7. "strings"
  8. "testing"
  9. "github.com/charmbracelet/catwalk/pkg/catwalk"
  10. "github.com/charmbracelet/crush/internal/csync"
  11. "github.com/charmbracelet/crush/internal/env"
  12. "github.com/stretchr/testify/require"
  13. )
  // TestMain replaces the default slog logger with one that discards all
  // output, so configuration loading does not flood the test logs, then runs
  // the package's tests and exits with the runner's status code.
  func TestMain(m *testing.M) {
  	discard := slog.NewTextHandler(io.Discard, nil)
  	slog.SetDefault(slog.New(discard))
  	os.Exit(m.Run())
  }
  19. func TestConfig_LoadFromReaders(t *testing.T) {
  20. data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  21. data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  22. data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  23. loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  24. require.NoError(t, err)
  25. require.NotNil(t, loadedConfig)
  26. require.Equal(t, 1, loadedConfig.Providers.Len())
  27. pc, _ := loadedConfig.Providers.Get("openai")
  28. require.Equal(t, "key2", pc.APIKey)
  29. require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  30. }
  31. func TestConfig_setDefaults(t *testing.T) {
  32. cfg := &Config{}
  33. cfg.setDefaults("/tmp")
  34. require.NotNil(t, cfg.Options)
  35. require.NotNil(t, cfg.Options.TUI)
  36. require.NotNil(t, cfg.Options.ContextPaths)
  37. require.NotNil(t, cfg.Providers)
  38. require.NotNil(t, cfg.Models)
  39. require.NotNil(t, cfg.LSP)
  40. require.NotNil(t, cfg.MCP)
  41. require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  42. for _, path := range defaultContextPaths {
  43. require.Contains(t, cfg.Options.ContextPaths, path)
  44. }
  45. require.Equal(t, "/tmp", cfg.workingDir)
  46. }
  47. func TestConfig_configureProviders(t *testing.T) {
  48. knownProviders := []catwalk.Provider{
  49. {
  50. ID: "openai",
  51. APIKey: "$OPENAI_API_KEY",
  52. APIEndpoint: "https://api.openai.com/v1",
  53. Models: []catwalk.Model{{
  54. ID: "test-model",
  55. }},
  56. },
  57. }
  58. cfg := &Config{}
  59. cfg.setDefaults("/tmp")
  60. env := env.NewFromMap(map[string]string{
  61. "OPENAI_API_KEY": "test-key",
  62. })
  63. resolver := NewEnvironmentVariableResolver(env)
  64. err := cfg.configureProviders(env, resolver, knownProviders)
  65. require.NoError(t, err)
  66. require.Equal(t, 1, cfg.Providers.Len())
  67. // We want to make sure that we keep the configured API key as a placeholder
  68. pc, _ := cfg.Providers.Get("openai")
  69. require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  70. }
  71. func TestConfig_configureProvidersWithOverride(t *testing.T) {
  72. knownProviders := []catwalk.Provider{
  73. {
  74. ID: "openai",
  75. APIKey: "$OPENAI_API_KEY",
  76. APIEndpoint: "https://api.openai.com/v1",
  77. Models: []catwalk.Model{{
  78. ID: "test-model",
  79. }},
  80. },
  81. }
  82. cfg := &Config{
  83. Providers: csync.NewMap[string, ProviderConfig](),
  84. }
  85. cfg.Providers.Set("openai", ProviderConfig{
  86. APIKey: "xyz",
  87. BaseURL: "https://api.openai.com/v2",
  88. Models: []catwalk.Model{
  89. {
  90. ID: "test-model",
  91. Name: "Updated",
  92. },
  93. {
  94. ID: "another-model",
  95. },
  96. },
  97. })
  98. cfg.setDefaults("/tmp")
  99. env := env.NewFromMap(map[string]string{
  100. "OPENAI_API_KEY": "test-key",
  101. })
  102. resolver := NewEnvironmentVariableResolver(env)
  103. err := cfg.configureProviders(env, resolver, knownProviders)
  104. require.NoError(t, err)
  105. require.Equal(t, 1, cfg.Providers.Len())
  106. // We want to make sure that we keep the configured API key as a placeholder
  107. pc, _ := cfg.Providers.Get("openai")
  108. require.Equal(t, "xyz", pc.APIKey)
  109. require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  110. require.Len(t, pc.Models, 2)
  111. require.Equal(t, "Updated", pc.Models[0].Name)
  112. }
  113. func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
  114. knownProviders := []catwalk.Provider{
  115. {
  116. ID: "openai",
  117. APIKey: "$OPENAI_API_KEY",
  118. APIEndpoint: "https://api.openai.com/v1",
  119. Models: []catwalk.Model{{
  120. ID: "test-model",
  121. }},
  122. },
  123. }
  124. cfg := &Config{
  125. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  126. "custom": {
  127. APIKey: "xyz",
  128. BaseURL: "https://api.someendpoint.com/v2",
  129. Models: []catwalk.Model{
  130. {
  131. ID: "test-model",
  132. },
  133. },
  134. },
  135. }),
  136. }
  137. cfg.setDefaults("/tmp")
  138. env := env.NewFromMap(map[string]string{
  139. "OPENAI_API_KEY": "test-key",
  140. })
  141. resolver := NewEnvironmentVariableResolver(env)
  142. err := cfg.configureProviders(env, resolver, knownProviders)
  143. require.NoError(t, err)
  144. // Should be to because of the env variable
  145. require.Equal(t, cfg.Providers.Len(), 2)
  146. // We want to make sure that we keep the configured API key as a placeholder
  147. pc, _ := cfg.Providers.Get("custom")
  148. require.Equal(t, "xyz", pc.APIKey)
  149. // Make sure we set the ID correctly
  150. require.Equal(t, "custom", pc.ID)
  151. require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
  152. require.Len(t, pc.Models, 1)
  153. _, ok := cfg.Providers.Get("openai")
  154. require.True(t, ok, "OpenAI provider should still be present")
  155. }
  156. func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
  157. knownProviders := []catwalk.Provider{
  158. {
  159. ID: catwalk.InferenceProviderBedrock,
  160. APIKey: "",
  161. APIEndpoint: "",
  162. Models: []catwalk.Model{{
  163. ID: "anthropic.claude-sonnet-4-20250514-v1:0",
  164. }},
  165. },
  166. }
  167. cfg := &Config{}
  168. cfg.setDefaults("/tmp")
  169. env := env.NewFromMap(map[string]string{
  170. "AWS_ACCESS_KEY_ID": "test-key-id",
  171. "AWS_SECRET_ACCESS_KEY": "test-secret-key",
  172. })
  173. resolver := NewEnvironmentVariableResolver(env)
  174. err := cfg.configureProviders(env, resolver, knownProviders)
  175. require.NoError(t, err)
  176. require.Equal(t, cfg.Providers.Len(), 1)
  177. bedrockProvider, ok := cfg.Providers.Get("bedrock")
  178. require.True(t, ok, "Bedrock provider should be present")
  179. require.Len(t, bedrockProvider.Models, 1)
  180. require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
  181. }
  182. func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
  183. knownProviders := []catwalk.Provider{
  184. {
  185. ID: catwalk.InferenceProviderBedrock,
  186. APIKey: "",
  187. APIEndpoint: "",
  188. Models: []catwalk.Model{{
  189. ID: "anthropic.claude-sonnet-4-20250514-v1:0",
  190. }},
  191. },
  192. }
  193. cfg := &Config{}
  194. cfg.setDefaults("/tmp")
  195. env := env.NewFromMap(map[string]string{})
  196. resolver := NewEnvironmentVariableResolver(env)
  197. err := cfg.configureProviders(env, resolver, knownProviders)
  198. require.NoError(t, err)
  199. // Provider should not be configured without credentials
  200. require.Equal(t, cfg.Providers.Len(), 0)
  201. }
  202. func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
  203. knownProviders := []catwalk.Provider{
  204. {
  205. ID: catwalk.InferenceProviderBedrock,
  206. APIKey: "",
  207. APIEndpoint: "",
  208. Models: []catwalk.Model{{
  209. ID: "some-random-model",
  210. }},
  211. },
  212. }
  213. cfg := &Config{}
  214. cfg.setDefaults("/tmp")
  215. env := env.NewFromMap(map[string]string{
  216. "AWS_ACCESS_KEY_ID": "test-key-id",
  217. "AWS_SECRET_ACCESS_KEY": "test-secret-key",
  218. })
  219. resolver := NewEnvironmentVariableResolver(env)
  220. err := cfg.configureProviders(env, resolver, knownProviders)
  221. require.Error(t, err)
  222. }
  223. func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
  224. knownProviders := []catwalk.Provider{
  225. {
  226. ID: catwalk.InferenceProviderVertexAI,
  227. APIKey: "",
  228. APIEndpoint: "",
  229. Models: []catwalk.Model{{
  230. ID: "gemini-pro",
  231. }},
  232. },
  233. }
  234. cfg := &Config{}
  235. cfg.setDefaults("/tmp")
  236. env := env.NewFromMap(map[string]string{
  237. "VERTEXAI_PROJECT": "test-project",
  238. "VERTEXAI_LOCATION": "us-central1",
  239. })
  240. resolver := NewEnvironmentVariableResolver(env)
  241. err := cfg.configureProviders(env, resolver, knownProviders)
  242. require.NoError(t, err)
  243. require.Equal(t, cfg.Providers.Len(), 1)
  244. vertexProvider, ok := cfg.Providers.Get("vertexai")
  245. require.True(t, ok, "VertexAI provider should be present")
  246. require.Len(t, vertexProvider.Models, 1)
  247. require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
  248. require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
  249. require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
  250. }
  251. func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
  252. knownProviders := []catwalk.Provider{
  253. {
  254. ID: catwalk.InferenceProviderVertexAI,
  255. APIKey: "",
  256. APIEndpoint: "",
  257. Models: []catwalk.Model{{
  258. ID: "gemini-pro",
  259. }},
  260. },
  261. }
  262. cfg := &Config{}
  263. cfg.setDefaults("/tmp")
  264. env := env.NewFromMap(map[string]string{
  265. "GOOGLE_GENAI_USE_VERTEXAI": "false",
  266. "GOOGLE_CLOUD_PROJECT": "test-project",
  267. "GOOGLE_CLOUD_LOCATION": "us-central1",
  268. })
  269. resolver := NewEnvironmentVariableResolver(env)
  270. err := cfg.configureProviders(env, resolver, knownProviders)
  271. require.NoError(t, err)
  272. // Provider should not be configured without proper credentials
  273. require.Equal(t, cfg.Providers.Len(), 0)
  274. }
  275. func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
  276. knownProviders := []catwalk.Provider{
  277. {
  278. ID: catwalk.InferenceProviderVertexAI,
  279. APIKey: "",
  280. APIEndpoint: "",
  281. Models: []catwalk.Model{{
  282. ID: "gemini-pro",
  283. }},
  284. },
  285. }
  286. cfg := &Config{}
  287. cfg.setDefaults("/tmp")
  288. env := env.NewFromMap(map[string]string{
  289. "GOOGLE_GENAI_USE_VERTEXAI": "true",
  290. "GOOGLE_CLOUD_LOCATION": "us-central1",
  291. })
  292. resolver := NewEnvironmentVariableResolver(env)
  293. err := cfg.configureProviders(env, resolver, knownProviders)
  294. require.NoError(t, err)
  295. // Provider should not be configured without project
  296. require.Equal(t, cfg.Providers.Len(), 0)
  297. }
  298. func TestConfig_configureProvidersSetProviderID(t *testing.T) {
  299. knownProviders := []catwalk.Provider{
  300. {
  301. ID: "openai",
  302. APIKey: "$OPENAI_API_KEY",
  303. APIEndpoint: "https://api.openai.com/v1",
  304. Models: []catwalk.Model{{
  305. ID: "test-model",
  306. }},
  307. },
  308. }
  309. cfg := &Config{}
  310. cfg.setDefaults("/tmp")
  311. env := env.NewFromMap(map[string]string{
  312. "OPENAI_API_KEY": "test-key",
  313. })
  314. resolver := NewEnvironmentVariableResolver(env)
  315. err := cfg.configureProviders(env, resolver, knownProviders)
  316. require.NoError(t, err)
  317. require.Equal(t, cfg.Providers.Len(), 1)
  318. // Provider ID should be set
  319. pc, _ := cfg.Providers.Get("openai")
  320. require.Equal(t, "openai", pc.ID)
  321. }
  322. func TestConfig_EnabledProviders(t *testing.T) {
  323. t.Run("all providers enabled", func(t *testing.T) {
  324. cfg := &Config{
  325. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  326. "openai": {
  327. ID: "openai",
  328. APIKey: "key1",
  329. Disable: false,
  330. },
  331. "anthropic": {
  332. ID: "anthropic",
  333. APIKey: "key2",
  334. Disable: false,
  335. },
  336. }),
  337. }
  338. enabled := cfg.EnabledProviders()
  339. require.Len(t, enabled, 2)
  340. })
  341. t.Run("some providers disabled", func(t *testing.T) {
  342. cfg := &Config{
  343. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  344. "openai": {
  345. ID: "openai",
  346. APIKey: "key1",
  347. Disable: false,
  348. },
  349. "anthropic": {
  350. ID: "anthropic",
  351. APIKey: "key2",
  352. Disable: true,
  353. },
  354. }),
  355. }
  356. enabled := cfg.EnabledProviders()
  357. require.Len(t, enabled, 1)
  358. require.Equal(t, "openai", enabled[0].ID)
  359. })
  360. t.Run("empty providers map", func(t *testing.T) {
  361. cfg := &Config{
  362. Providers: csync.NewMap[string, ProviderConfig](),
  363. }
  364. enabled := cfg.EnabledProviders()
  365. require.Len(t, enabled, 0)
  366. })
  367. }
  368. func TestConfig_IsConfigured(t *testing.T) {
  369. t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
  370. cfg := &Config{
  371. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  372. "openai": {
  373. ID: "openai",
  374. APIKey: "key1",
  375. Disable: false,
  376. },
  377. }),
  378. }
  379. require.True(t, cfg.IsConfigured())
  380. })
  381. t.Run("returns false when no providers are configured", func(t *testing.T) {
  382. cfg := &Config{
  383. Providers: csync.NewMap[string, ProviderConfig](),
  384. }
  385. require.False(t, cfg.IsConfigured())
  386. })
  387. t.Run("returns false when all providers are disabled", func(t *testing.T) {
  388. cfg := &Config{
  389. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  390. "openai": {
  391. ID: "openai",
  392. APIKey: "key1",
  393. Disable: true,
  394. },
  395. "anthropic": {
  396. ID: "anthropic",
  397. APIKey: "key2",
  398. Disable: true,
  399. },
  400. }),
  401. }
  402. require.False(t, cfg.IsConfigured())
  403. })
  404. }
  405. func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
  406. knownProviders := []catwalk.Provider{
  407. {
  408. ID: "openai",
  409. APIKey: "$OPENAI_API_KEY",
  410. APIEndpoint: "https://api.openai.com/v1",
  411. Models: []catwalk.Model{{
  412. ID: "test-model",
  413. }},
  414. },
  415. }
  416. cfg := &Config{
  417. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  418. "openai": {
  419. Disable: true,
  420. },
  421. }),
  422. }
  423. cfg.setDefaults("/tmp")
  424. env := env.NewFromMap(map[string]string{
  425. "OPENAI_API_KEY": "test-key",
  426. })
  427. resolver := NewEnvironmentVariableResolver(env)
  428. err := cfg.configureProviders(env, resolver, knownProviders)
  429. require.NoError(t, err)
  430. // Provider should be removed from config when disabled
  431. require.Equal(t, cfg.Providers.Len(), 0)
  432. _, exists := cfg.Providers.Get("openai")
  433. require.False(t, exists)
  434. }
  435. func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
  436. t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
  437. cfg := &Config{
  438. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  439. "custom": {
  440. BaseURL: "https://api.custom.com/v1",
  441. Models: []catwalk.Model{{
  442. ID: "test-model",
  443. }},
  444. },
  445. "openai": {
  446. APIKey: "$MISSING",
  447. },
  448. }),
  449. }
  450. cfg.setDefaults("/tmp")
  451. env := env.NewFromMap(map[string]string{})
  452. resolver := NewEnvironmentVariableResolver(env)
  453. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  454. require.NoError(t, err)
  455. require.Equal(t, cfg.Providers.Len(), 1)
  456. _, exists := cfg.Providers.Get("custom")
  457. require.True(t, exists)
  458. })
  459. t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
  460. cfg := &Config{
  461. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  462. "custom": {
  463. APIKey: "test-key",
  464. Models: []catwalk.Model{{
  465. ID: "test-model",
  466. }},
  467. },
  468. }),
  469. }
  470. cfg.setDefaults("/tmp")
  471. env := env.NewFromMap(map[string]string{})
  472. resolver := NewEnvironmentVariableResolver(env)
  473. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  474. require.NoError(t, err)
  475. require.Equal(t, cfg.Providers.Len(), 0)
  476. _, exists := cfg.Providers.Get("custom")
  477. require.False(t, exists)
  478. })
  479. t.Run("custom provider with no models is removed", func(t *testing.T) {
  480. cfg := &Config{
  481. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  482. "custom": {
  483. APIKey: "test-key",
  484. BaseURL: "https://api.custom.com/v1",
  485. Models: []catwalk.Model{},
  486. },
  487. }),
  488. }
  489. cfg.setDefaults("/tmp")
  490. env := env.NewFromMap(map[string]string{})
  491. resolver := NewEnvironmentVariableResolver(env)
  492. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  493. require.NoError(t, err)
  494. require.Equal(t, cfg.Providers.Len(), 0)
  495. _, exists := cfg.Providers.Get("custom")
  496. require.False(t, exists)
  497. })
  498. t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
  499. cfg := &Config{
  500. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  501. "custom": {
  502. APIKey: "test-key",
  503. BaseURL: "https://api.custom.com/v1",
  504. Type: "unsupported",
  505. Models: []catwalk.Model{{
  506. ID: "test-model",
  507. }},
  508. },
  509. }),
  510. }
  511. cfg.setDefaults("/tmp")
  512. env := env.NewFromMap(map[string]string{})
  513. resolver := NewEnvironmentVariableResolver(env)
  514. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  515. require.NoError(t, err)
  516. require.Equal(t, cfg.Providers.Len(), 0)
  517. _, exists := cfg.Providers.Get("custom")
  518. require.False(t, exists)
  519. })
  520. t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
  521. cfg := &Config{
  522. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  523. "custom": {
  524. APIKey: "test-key",
  525. BaseURL: "https://api.custom.com/v1",
  526. Type: catwalk.TypeOpenAI,
  527. Models: []catwalk.Model{{
  528. ID: "test-model",
  529. }},
  530. },
  531. }),
  532. }
  533. cfg.setDefaults("/tmp")
  534. env := env.NewFromMap(map[string]string{})
  535. resolver := NewEnvironmentVariableResolver(env)
  536. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  537. require.NoError(t, err)
  538. require.Equal(t, cfg.Providers.Len(), 1)
  539. customProvider, exists := cfg.Providers.Get("custom")
  540. require.True(t, exists)
  541. require.Equal(t, "custom", customProvider.ID)
  542. require.Equal(t, "test-key", customProvider.APIKey)
  543. require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
  544. })
  545. t.Run("custom anthropic provider is supported", func(t *testing.T) {
  546. cfg := &Config{
  547. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  548. "custom-anthropic": {
  549. APIKey: "test-key",
  550. BaseURL: "https://api.anthropic.com/v1",
  551. Type: catwalk.TypeAnthropic,
  552. Models: []catwalk.Model{{
  553. ID: "claude-3-sonnet",
  554. }},
  555. },
  556. }),
  557. }
  558. cfg.setDefaults("/tmp")
  559. env := env.NewFromMap(map[string]string{})
  560. resolver := NewEnvironmentVariableResolver(env)
  561. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  562. require.NoError(t, err)
  563. require.Equal(t, cfg.Providers.Len(), 1)
  564. customProvider, exists := cfg.Providers.Get("custom-anthropic")
  565. require.True(t, exists)
  566. require.Equal(t, "custom-anthropic", customProvider.ID)
  567. require.Equal(t, "test-key", customProvider.APIKey)
  568. require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
  569. require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
  570. })
  571. t.Run("disabled custom provider is removed", func(t *testing.T) {
  572. cfg := &Config{
  573. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  574. "custom": {
  575. APIKey: "test-key",
  576. BaseURL: "https://api.custom.com/v1",
  577. Type: catwalk.TypeOpenAI,
  578. Disable: true,
  579. Models: []catwalk.Model{{
  580. ID: "test-model",
  581. }},
  582. },
  583. }),
  584. }
  585. cfg.setDefaults("/tmp")
  586. env := env.NewFromMap(map[string]string{})
  587. resolver := NewEnvironmentVariableResolver(env)
  588. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  589. require.NoError(t, err)
  590. require.Equal(t, cfg.Providers.Len(), 0)
  591. _, exists := cfg.Providers.Get("custom")
  592. require.False(t, exists)
  593. })
  594. }
  595. func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
  596. t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
  597. knownProviders := []catwalk.Provider{
  598. {
  599. ID: catwalk.InferenceProviderVertexAI,
  600. APIKey: "",
  601. APIEndpoint: "",
  602. Models: []catwalk.Model{{
  603. ID: "gemini-pro",
  604. }},
  605. },
  606. }
  607. cfg := &Config{
  608. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  609. "vertexai": {
  610. BaseURL: "custom-url",
  611. },
  612. }),
  613. }
  614. cfg.setDefaults("/tmp")
  615. env := env.NewFromMap(map[string]string{
  616. "GOOGLE_GENAI_USE_VERTEXAI": "false",
  617. })
  618. resolver := NewEnvironmentVariableResolver(env)
  619. err := cfg.configureProviders(env, resolver, knownProviders)
  620. require.NoError(t, err)
  621. require.Equal(t, cfg.Providers.Len(), 0)
  622. _, exists := cfg.Providers.Get("vertexai")
  623. require.False(t, exists)
  624. })
  625. t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
  626. knownProviders := []catwalk.Provider{
  627. {
  628. ID: catwalk.InferenceProviderBedrock,
  629. APIKey: "",
  630. APIEndpoint: "",
  631. Models: []catwalk.Model{{
  632. ID: "anthropic.claude-sonnet-4-20250514-v1:0",
  633. }},
  634. },
  635. }
  636. cfg := &Config{
  637. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  638. "bedrock": {
  639. BaseURL: "custom-url",
  640. },
  641. }),
  642. }
  643. cfg.setDefaults("/tmp")
  644. env := env.NewFromMap(map[string]string{})
  645. resolver := NewEnvironmentVariableResolver(env)
  646. err := cfg.configureProviders(env, resolver, knownProviders)
  647. require.NoError(t, err)
  648. require.Equal(t, cfg.Providers.Len(), 0)
  649. _, exists := cfg.Providers.Get("bedrock")
  650. require.False(t, exists)
  651. })
  652. t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
  653. knownProviders := []catwalk.Provider{
  654. {
  655. ID: "openai",
  656. APIKey: "$MISSING_API_KEY",
  657. APIEndpoint: "https://api.openai.com/v1",
  658. Models: []catwalk.Model{{
  659. ID: "test-model",
  660. }},
  661. },
  662. }
  663. cfg := &Config{
  664. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  665. "openai": {
  666. BaseURL: "custom-url",
  667. },
  668. }),
  669. }
  670. cfg.setDefaults("/tmp")
  671. env := env.NewFromMap(map[string]string{})
  672. resolver := NewEnvironmentVariableResolver(env)
  673. err := cfg.configureProviders(env, resolver, knownProviders)
  674. require.NoError(t, err)
  675. require.Equal(t, cfg.Providers.Len(), 0)
  676. _, exists := cfg.Providers.Get("openai")
  677. require.False(t, exists)
  678. })
  679. t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
  680. knownProviders := []catwalk.Provider{
  681. {
  682. ID: "openai",
  683. APIKey: "$OPENAI_API_KEY",
  684. APIEndpoint: "$MISSING_ENDPOINT",
  685. Models: []catwalk.Model{{
  686. ID: "test-model",
  687. }},
  688. },
  689. }
  690. cfg := &Config{
  691. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  692. "openai": {
  693. APIKey: "test-key",
  694. },
  695. }),
  696. }
  697. cfg.setDefaults("/tmp")
  698. env := env.NewFromMap(map[string]string{
  699. "OPENAI_API_KEY": "test-key",
  700. })
  701. resolver := NewEnvironmentVariableResolver(env)
  702. err := cfg.configureProviders(env, resolver, knownProviders)
  703. require.NoError(t, err)
  704. require.Equal(t, cfg.Providers.Len(), 1)
  705. _, exists := cfg.Providers.Get("openai")
  706. require.True(t, exists)
  707. })
  708. }
  709. func TestConfig_defaultModelSelection(t *testing.T) {
  710. t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
  711. knownProviders := []catwalk.Provider{
  712. {
  713. ID: "openai",
  714. APIKey: "abc",
  715. DefaultLargeModelID: "large-model",
  716. DefaultSmallModelID: "small-model",
  717. Models: []catwalk.Model{
  718. {
  719. ID: "large-model",
  720. DefaultMaxTokens: 1000,
  721. },
  722. {
  723. ID: "small-model",
  724. DefaultMaxTokens: 500,
  725. },
  726. },
  727. },
  728. }
  729. cfg := &Config{}
  730. cfg.setDefaults("/tmp")
  731. env := env.NewFromMap(map[string]string{})
  732. resolver := NewEnvironmentVariableResolver(env)
  733. err := cfg.configureProviders(env, resolver, knownProviders)
  734. require.NoError(t, err)
  735. large, small, err := cfg.defaultModelSelection(knownProviders)
  736. require.NoError(t, err)
  737. require.Equal(t, "large-model", large.Model)
  738. require.Equal(t, "openai", large.Provider)
  739. require.Equal(t, int64(1000), large.MaxTokens)
  740. require.Equal(t, "small-model", small.Model)
  741. require.Equal(t, "openai", small.Provider)
  742. require.Equal(t, int64(500), small.MaxTokens)
  743. })
  744. t.Run("should error if no providers configured", func(t *testing.T) {
  745. knownProviders := []catwalk.Provider{
  746. {
  747. ID: "openai",
  748. APIKey: "$MISSING_KEY",
  749. DefaultLargeModelID: "large-model",
  750. DefaultSmallModelID: "small-model",
  751. Models: []catwalk.Model{
  752. {
  753. ID: "large-model",
  754. DefaultMaxTokens: 1000,
  755. },
  756. {
  757. ID: "small-model",
  758. DefaultMaxTokens: 500,
  759. },
  760. },
  761. },
  762. }
  763. cfg := &Config{}
  764. cfg.setDefaults("/tmp")
  765. env := env.NewFromMap(map[string]string{})
  766. resolver := NewEnvironmentVariableResolver(env)
  767. err := cfg.configureProviders(env, resolver, knownProviders)
  768. require.NoError(t, err)
  769. _, _, err = cfg.defaultModelSelection(knownProviders)
  770. require.Error(t, err)
  771. })
  772. t.Run("should error if model is missing", func(t *testing.T) {
  773. knownProviders := []catwalk.Provider{
  774. {
  775. ID: "openai",
  776. APIKey: "abc",
  777. DefaultLargeModelID: "large-model",
  778. DefaultSmallModelID: "small-model",
  779. Models: []catwalk.Model{
  780. {
  781. ID: "not-large-model",
  782. DefaultMaxTokens: 1000,
  783. },
  784. {
  785. ID: "small-model",
  786. DefaultMaxTokens: 500,
  787. },
  788. },
  789. },
  790. }
  791. cfg := &Config{}
  792. cfg.setDefaults("/tmp")
  793. env := env.NewFromMap(map[string]string{})
  794. resolver := NewEnvironmentVariableResolver(env)
  795. err := cfg.configureProviders(env, resolver, knownProviders)
  796. require.NoError(t, err)
  797. _, _, err = cfg.defaultModelSelection(knownProviders)
  798. require.Error(t, err)
  799. })
  800. t.Run("should configure the default models with a custom provider", func(t *testing.T) {
  801. knownProviders := []catwalk.Provider{
  802. {
  803. ID: "openai",
  804. APIKey: "$MISSING", // will not be included in the config
  805. DefaultLargeModelID: "large-model",
  806. DefaultSmallModelID: "small-model",
  807. Models: []catwalk.Model{
  808. {
  809. ID: "not-large-model",
  810. DefaultMaxTokens: 1000,
  811. },
  812. {
  813. ID: "small-model",
  814. DefaultMaxTokens: 500,
  815. },
  816. },
  817. },
  818. }
  819. cfg := &Config{
  820. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  821. "custom": {
  822. APIKey: "test-key",
  823. BaseURL: "https://api.custom.com/v1",
  824. Models: []catwalk.Model{
  825. {
  826. ID: "model",
  827. DefaultMaxTokens: 600,
  828. },
  829. },
  830. },
  831. }),
  832. }
  833. cfg.setDefaults("/tmp")
  834. env := env.NewFromMap(map[string]string{})
  835. resolver := NewEnvironmentVariableResolver(env)
  836. err := cfg.configureProviders(env, resolver, knownProviders)
  837. require.NoError(t, err)
  838. large, small, err := cfg.defaultModelSelection(knownProviders)
  839. require.NoError(t, err)
  840. require.Equal(t, "model", large.Model)
  841. require.Equal(t, "custom", large.Provider)
  842. require.Equal(t, int64(600), large.MaxTokens)
  843. require.Equal(t, "model", small.Model)
  844. require.Equal(t, "custom", small.Provider)
  845. require.Equal(t, int64(600), small.MaxTokens)
  846. })
  847. t.Run("should fail if no model configured", func(t *testing.T) {
  848. knownProviders := []catwalk.Provider{
  849. {
  850. ID: "openai",
  851. APIKey: "$MISSING", // will not be included in the config
  852. DefaultLargeModelID: "large-model",
  853. DefaultSmallModelID: "small-model",
  854. Models: []catwalk.Model{
  855. {
  856. ID: "not-large-model",
  857. DefaultMaxTokens: 1000,
  858. },
  859. {
  860. ID: "small-model",
  861. DefaultMaxTokens: 500,
  862. },
  863. },
  864. },
  865. }
  866. cfg := &Config{
  867. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  868. "custom": {
  869. APIKey: "test-key",
  870. BaseURL: "https://api.custom.com/v1",
  871. Models: []catwalk.Model{},
  872. },
  873. }),
  874. }
  875. cfg.setDefaults("/tmp")
  876. env := env.NewFromMap(map[string]string{})
  877. resolver := NewEnvironmentVariableResolver(env)
  878. err := cfg.configureProviders(env, resolver, knownProviders)
  879. require.NoError(t, err)
  880. _, _, err = cfg.defaultModelSelection(knownProviders)
  881. require.Error(t, err)
  882. })
  883. t.Run("should use the default provider first", func(t *testing.T) {
  884. knownProviders := []catwalk.Provider{
  885. {
  886. ID: "openai",
  887. APIKey: "set",
  888. DefaultLargeModelID: "large-model",
  889. DefaultSmallModelID: "small-model",
  890. Models: []catwalk.Model{
  891. {
  892. ID: "large-model",
  893. DefaultMaxTokens: 1000,
  894. },
  895. {
  896. ID: "small-model",
  897. DefaultMaxTokens: 500,
  898. },
  899. },
  900. },
  901. }
  902. cfg := &Config{
  903. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  904. "custom": {
  905. APIKey: "test-key",
  906. BaseURL: "https://api.custom.com/v1",
  907. Models: []catwalk.Model{
  908. {
  909. ID: "large-model",
  910. DefaultMaxTokens: 1000,
  911. },
  912. },
  913. },
  914. }),
  915. }
  916. cfg.setDefaults("/tmp")
  917. env := env.NewFromMap(map[string]string{})
  918. resolver := NewEnvironmentVariableResolver(env)
  919. err := cfg.configureProviders(env, resolver, knownProviders)
  920. require.NoError(t, err)
  921. large, small, err := cfg.defaultModelSelection(knownProviders)
  922. require.NoError(t, err)
  923. require.Equal(t, "large-model", large.Model)
  924. require.Equal(t, "openai", large.Provider)
  925. require.Equal(t, int64(1000), large.MaxTokens)
  926. require.Equal(t, "small-model", small.Model)
  927. require.Equal(t, "openai", small.Provider)
  928. require.Equal(t, int64(500), small.MaxTokens)
  929. })
  930. }
  931. func TestConfig_configureSelectedModels(t *testing.T) {
  932. t.Run("should override defaults", func(t *testing.T) {
  933. knownProviders := []catwalk.Provider{
  934. {
  935. ID: "openai",
  936. APIKey: "abc",
  937. DefaultLargeModelID: "large-model",
  938. DefaultSmallModelID: "small-model",
  939. Models: []catwalk.Model{
  940. {
  941. ID: "larger-model",
  942. DefaultMaxTokens: 2000,
  943. },
  944. {
  945. ID: "large-model",
  946. DefaultMaxTokens: 1000,
  947. },
  948. {
  949. ID: "small-model",
  950. DefaultMaxTokens: 500,
  951. },
  952. },
  953. },
  954. }
  955. cfg := &Config{
  956. Models: map[SelectedModelType]SelectedModel{
  957. "large": {
  958. Model: "larger-model",
  959. },
  960. },
  961. }
  962. cfg.setDefaults("/tmp")
  963. env := env.NewFromMap(map[string]string{})
  964. resolver := NewEnvironmentVariableResolver(env)
  965. err := cfg.configureProviders(env, resolver, knownProviders)
  966. require.NoError(t, err)
  967. err = cfg.configureSelectedModels(knownProviders)
  968. require.NoError(t, err)
  969. large := cfg.Models[SelectedModelTypeLarge]
  970. small := cfg.Models[SelectedModelTypeSmall]
  971. require.Equal(t, "larger-model", large.Model)
  972. require.Equal(t, "openai", large.Provider)
  973. require.Equal(t, int64(2000), large.MaxTokens)
  974. require.Equal(t, "small-model", small.Model)
  975. require.Equal(t, "openai", small.Provider)
  976. require.Equal(t, int64(500), small.MaxTokens)
  977. })
  978. t.Run("should be possible to use multiple providers", func(t *testing.T) {
  979. knownProviders := []catwalk.Provider{
  980. {
  981. ID: "openai",
  982. APIKey: "abc",
  983. DefaultLargeModelID: "large-model",
  984. DefaultSmallModelID: "small-model",
  985. Models: []catwalk.Model{
  986. {
  987. ID: "large-model",
  988. DefaultMaxTokens: 1000,
  989. },
  990. {
  991. ID: "small-model",
  992. DefaultMaxTokens: 500,
  993. },
  994. },
  995. },
  996. {
  997. ID: "anthropic",
  998. APIKey: "abc",
  999. DefaultLargeModelID: "a-large-model",
  1000. DefaultSmallModelID: "a-small-model",
  1001. Models: []catwalk.Model{
  1002. {
  1003. ID: "a-large-model",
  1004. DefaultMaxTokens: 1000,
  1005. },
  1006. {
  1007. ID: "a-small-model",
  1008. DefaultMaxTokens: 200,
  1009. },
  1010. },
  1011. },
  1012. }
  1013. cfg := &Config{
  1014. Models: map[SelectedModelType]SelectedModel{
  1015. "small": {
  1016. Model: "a-small-model",
  1017. Provider: "anthropic",
  1018. MaxTokens: 300,
  1019. },
  1020. },
  1021. }
  1022. cfg.setDefaults("/tmp")
  1023. env := env.NewFromMap(map[string]string{})
  1024. resolver := NewEnvironmentVariableResolver(env)
  1025. err := cfg.configureProviders(env, resolver, knownProviders)
  1026. require.NoError(t, err)
  1027. err = cfg.configureSelectedModels(knownProviders)
  1028. require.NoError(t, err)
  1029. large := cfg.Models[SelectedModelTypeLarge]
  1030. small := cfg.Models[SelectedModelTypeSmall]
  1031. require.Equal(t, "large-model", large.Model)
  1032. require.Equal(t, "openai", large.Provider)
  1033. require.Equal(t, int64(1000), large.MaxTokens)
  1034. require.Equal(t, "a-small-model", small.Model)
  1035. require.Equal(t, "anthropic", small.Provider)
  1036. require.Equal(t, int64(300), small.MaxTokens)
  1037. })
  1038. t.Run("should override the max tokens only", func(t *testing.T) {
  1039. knownProviders := []catwalk.Provider{
  1040. {
  1041. ID: "openai",
  1042. APIKey: "abc",
  1043. DefaultLargeModelID: "large-model",
  1044. DefaultSmallModelID: "small-model",
  1045. Models: []catwalk.Model{
  1046. {
  1047. ID: "large-model",
  1048. DefaultMaxTokens: 1000,
  1049. },
  1050. {
  1051. ID: "small-model",
  1052. DefaultMaxTokens: 500,
  1053. },
  1054. },
  1055. },
  1056. }
  1057. cfg := &Config{
  1058. Models: map[SelectedModelType]SelectedModel{
  1059. "large": {
  1060. MaxTokens: 100,
  1061. },
  1062. },
  1063. }
  1064. cfg.setDefaults("/tmp")
  1065. env := env.NewFromMap(map[string]string{})
  1066. resolver := NewEnvironmentVariableResolver(env)
  1067. err := cfg.configureProviders(env, resolver, knownProviders)
  1068. require.NoError(t, err)
  1069. err = cfg.configureSelectedModels(knownProviders)
  1070. require.NoError(t, err)
  1071. large := cfg.Models[SelectedModelTypeLarge]
  1072. require.Equal(t, "large-model", large.Model)
  1073. require.Equal(t, "openai", large.Provider)
  1074. require.Equal(t, int64(100), large.MaxTokens)
  1075. })
  1076. }