load_test.go 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248
  1. package config
  2. import (
  3. "io"
  4. "log/slog"
  5. "os"
  6. "path/filepath"
  7. "strings"
  8. "testing"
  9. "github.com/charmbracelet/catwalk/pkg/catwalk"
  10. "github.com/charmbracelet/crush/internal/csync"
  11. "github.com/charmbracelet/crush/internal/env"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. )
  // TestMain replaces the package-wide default slog logger with one that
  // discards all output, runs every test in the package and exits with the
  // resulting status code.
  func TestMain(m *testing.M) {
  	quiet := slog.NewTextHandler(io.Discard, nil)
  	slog.SetDefault(slog.New(quiet))
  	os.Exit(m.Run())
  }
  20. func TestConfig_LoadFromReaders(t *testing.T) {
  21. data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  22. data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  23. data3 := strings.NewReader(`{"providers": {"openai": {}}}`)
  24. loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})
  25. require.NoError(t, err)
  26. require.NotNil(t, loadedConfig)
  27. require.Equal(t, 1, loadedConfig.Providers.Len())
  28. pc, _ := loadedConfig.Providers.Get("openai")
  29. require.Equal(t, "key2", pc.APIKey)
  30. require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  31. }
  32. func TestConfig_setDefaults(t *testing.T) {
  33. cfg := &Config{}
  34. cfg.setDefaults("/tmp", "")
  35. require.NotNil(t, cfg.Options)
  36. require.NotNil(t, cfg.Options.TUI)
  37. require.NotNil(t, cfg.Options.ContextPaths)
  38. require.NotNil(t, cfg.Providers)
  39. require.NotNil(t, cfg.Models)
  40. require.NotNil(t, cfg.LSP)
  41. require.NotNil(t, cfg.MCP)
  42. require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  43. for _, path := range defaultContextPaths {
  44. require.Contains(t, cfg.Options.ContextPaths, path)
  45. }
  46. require.Equal(t, "/tmp", cfg.workingDir)
  47. }
  48. func TestConfig_configureProviders(t *testing.T) {
  49. knownProviders := []catwalk.Provider{
  50. {
  51. ID: "openai",
  52. APIKey: "$OPENAI_API_KEY",
  53. APIEndpoint: "https://api.openai.com/v1",
  54. Models: []catwalk.Model{{
  55. ID: "test-model",
  56. }},
  57. },
  58. }
  59. cfg := &Config{}
  60. cfg.setDefaults("/tmp", "")
  61. env := env.NewFromMap(map[string]string{
  62. "OPENAI_API_KEY": "test-key",
  63. })
  64. resolver := NewEnvironmentVariableResolver(env)
  65. err := cfg.configureProviders(env, resolver, knownProviders)
  66. require.NoError(t, err)
  67. require.Equal(t, 1, cfg.Providers.Len())
  68. // We want to make sure that we keep the configured API key as a placeholder
  69. pc, _ := cfg.Providers.Get("openai")
  70. require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  71. }
  72. func TestConfig_configureProvidersWithOverride(t *testing.T) {
  73. knownProviders := []catwalk.Provider{
  74. {
  75. ID: "openai",
  76. APIKey: "$OPENAI_API_KEY",
  77. APIEndpoint: "https://api.openai.com/v1",
  78. Models: []catwalk.Model{{
  79. ID: "test-model",
  80. }},
  81. },
  82. }
  83. cfg := &Config{
  84. Providers: csync.NewMap[string, ProviderConfig](),
  85. }
  86. cfg.Providers.Set("openai", ProviderConfig{
  87. APIKey: "xyz",
  88. BaseURL: "https://api.openai.com/v2",
  89. Models: []catwalk.Model{
  90. {
  91. ID: "test-model",
  92. Name: "Updated",
  93. },
  94. {
  95. ID: "another-model",
  96. },
  97. },
  98. })
  99. cfg.setDefaults("/tmp", "")
  100. env := env.NewFromMap(map[string]string{
  101. "OPENAI_API_KEY": "test-key",
  102. })
  103. resolver := NewEnvironmentVariableResolver(env)
  104. err := cfg.configureProviders(env, resolver, knownProviders)
  105. require.NoError(t, err)
  106. require.Equal(t, 1, cfg.Providers.Len())
  107. // We want to make sure that we keep the configured API key as a placeholder
  108. pc, _ := cfg.Providers.Get("openai")
  109. require.Equal(t, "xyz", pc.APIKey)
  110. require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  111. require.Len(t, pc.Models, 2)
  112. require.Equal(t, "Updated", pc.Models[0].Name)
  113. }
  114. func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
  115. knownProviders := []catwalk.Provider{
  116. {
  117. ID: "openai",
  118. APIKey: "$OPENAI_API_KEY",
  119. APIEndpoint: "https://api.openai.com/v1",
  120. Models: []catwalk.Model{{
  121. ID: "test-model",
  122. }},
  123. },
  124. }
  125. cfg := &Config{
  126. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  127. "custom": {
  128. APIKey: "xyz",
  129. BaseURL: "https://api.someendpoint.com/v2",
  130. Models: []catwalk.Model{
  131. {
  132. ID: "test-model",
  133. },
  134. },
  135. },
  136. }),
  137. }
  138. cfg.setDefaults("/tmp", "")
  139. env := env.NewFromMap(map[string]string{
  140. "OPENAI_API_KEY": "test-key",
  141. })
  142. resolver := NewEnvironmentVariableResolver(env)
  143. err := cfg.configureProviders(env, resolver, knownProviders)
  144. require.NoError(t, err)
  145. // Should be to because of the env variable
  146. require.Equal(t, cfg.Providers.Len(), 2)
  147. // We want to make sure that we keep the configured API key as a placeholder
  148. pc, _ := cfg.Providers.Get("custom")
  149. require.Equal(t, "xyz", pc.APIKey)
  150. // Make sure we set the ID correctly
  151. require.Equal(t, "custom", pc.ID)
  152. require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
  153. require.Len(t, pc.Models, 1)
  154. _, ok := cfg.Providers.Get("openai")
  155. require.True(t, ok, "OpenAI provider should still be present")
  156. }
  157. func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
  158. knownProviders := []catwalk.Provider{
  159. {
  160. ID: catwalk.InferenceProviderBedrock,
  161. APIKey: "",
  162. APIEndpoint: "",
  163. Models: []catwalk.Model{{
  164. ID: "anthropic.claude-sonnet-4-20250514-v1:0",
  165. }},
  166. },
  167. }
  168. cfg := &Config{}
  169. cfg.setDefaults("/tmp", "")
  170. env := env.NewFromMap(map[string]string{
  171. "AWS_ACCESS_KEY_ID": "test-key-id",
  172. "AWS_SECRET_ACCESS_KEY": "test-secret-key",
  173. })
  174. resolver := NewEnvironmentVariableResolver(env)
  175. err := cfg.configureProviders(env, resolver, knownProviders)
  176. require.NoError(t, err)
  177. require.Equal(t, cfg.Providers.Len(), 1)
  178. bedrockProvider, ok := cfg.Providers.Get("bedrock")
  179. require.True(t, ok, "Bedrock provider should be present")
  180. require.Len(t, bedrockProvider.Models, 1)
  181. require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
  182. }
  183. func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
  184. knownProviders := []catwalk.Provider{
  185. {
  186. ID: catwalk.InferenceProviderBedrock,
  187. APIKey: "",
  188. APIEndpoint: "",
  189. Models: []catwalk.Model{{
  190. ID: "anthropic.claude-sonnet-4-20250514-v1:0",
  191. }},
  192. },
  193. }
  194. cfg := &Config{}
  195. cfg.setDefaults("/tmp", "")
  196. env := env.NewFromMap(map[string]string{})
  197. resolver := NewEnvironmentVariableResolver(env)
  198. err := cfg.configureProviders(env, resolver, knownProviders)
  199. require.NoError(t, err)
  200. // Provider should not be configured without credentials
  201. require.Equal(t, cfg.Providers.Len(), 0)
  202. }
  203. func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
  204. knownProviders := []catwalk.Provider{
  205. {
  206. ID: catwalk.InferenceProviderBedrock,
  207. APIKey: "",
  208. APIEndpoint: "",
  209. Models: []catwalk.Model{{
  210. ID: "some-random-model",
  211. }},
  212. },
  213. }
  214. cfg := &Config{}
  215. cfg.setDefaults("/tmp", "")
  216. env := env.NewFromMap(map[string]string{
  217. "AWS_ACCESS_KEY_ID": "test-key-id",
  218. "AWS_SECRET_ACCESS_KEY": "test-secret-key",
  219. })
  220. resolver := NewEnvironmentVariableResolver(env)
  221. err := cfg.configureProviders(env, resolver, knownProviders)
  222. require.Error(t, err)
  223. }
  224. func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
  225. knownProviders := []catwalk.Provider{
  226. {
  227. ID: catwalk.InferenceProviderVertexAI,
  228. APIKey: "",
  229. APIEndpoint: "",
  230. Models: []catwalk.Model{{
  231. ID: "gemini-pro",
  232. }},
  233. },
  234. }
  235. cfg := &Config{}
  236. cfg.setDefaults("/tmp", "")
  237. env := env.NewFromMap(map[string]string{
  238. "VERTEXAI_PROJECT": "test-project",
  239. "VERTEXAI_LOCATION": "us-central1",
  240. })
  241. resolver := NewEnvironmentVariableResolver(env)
  242. err := cfg.configureProviders(env, resolver, knownProviders)
  243. require.NoError(t, err)
  244. require.Equal(t, cfg.Providers.Len(), 1)
  245. vertexProvider, ok := cfg.Providers.Get("vertexai")
  246. require.True(t, ok, "VertexAI provider should be present")
  247. require.Len(t, vertexProvider.Models, 1)
  248. require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
  249. require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
  250. require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
  251. }
  252. func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
  253. knownProviders := []catwalk.Provider{
  254. {
  255. ID: catwalk.InferenceProviderVertexAI,
  256. APIKey: "",
  257. APIEndpoint: "",
  258. Models: []catwalk.Model{{
  259. ID: "gemini-pro",
  260. }},
  261. },
  262. }
  263. cfg := &Config{}
  264. cfg.setDefaults("/tmp", "")
  265. env := env.NewFromMap(map[string]string{
  266. "GOOGLE_GENAI_USE_VERTEXAI": "false",
  267. "GOOGLE_CLOUD_PROJECT": "test-project",
  268. "GOOGLE_CLOUD_LOCATION": "us-central1",
  269. })
  270. resolver := NewEnvironmentVariableResolver(env)
  271. err := cfg.configureProviders(env, resolver, knownProviders)
  272. require.NoError(t, err)
  273. // Provider should not be configured without proper credentials
  274. require.Equal(t, cfg.Providers.Len(), 0)
  275. }
  276. func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
  277. knownProviders := []catwalk.Provider{
  278. {
  279. ID: catwalk.InferenceProviderVertexAI,
  280. APIKey: "",
  281. APIEndpoint: "",
  282. Models: []catwalk.Model{{
  283. ID: "gemini-pro",
  284. }},
  285. },
  286. }
  287. cfg := &Config{}
  288. cfg.setDefaults("/tmp", "")
  289. env := env.NewFromMap(map[string]string{
  290. "GOOGLE_GENAI_USE_VERTEXAI": "true",
  291. "GOOGLE_CLOUD_LOCATION": "us-central1",
  292. })
  293. resolver := NewEnvironmentVariableResolver(env)
  294. err := cfg.configureProviders(env, resolver, knownProviders)
  295. require.NoError(t, err)
  296. // Provider should not be configured without project
  297. require.Equal(t, cfg.Providers.Len(), 0)
  298. }
  299. func TestConfig_configureProvidersSetProviderID(t *testing.T) {
  300. knownProviders := []catwalk.Provider{
  301. {
  302. ID: "openai",
  303. APIKey: "$OPENAI_API_KEY",
  304. APIEndpoint: "https://api.openai.com/v1",
  305. Models: []catwalk.Model{{
  306. ID: "test-model",
  307. }},
  308. },
  309. }
  310. cfg := &Config{}
  311. cfg.setDefaults("/tmp", "")
  312. env := env.NewFromMap(map[string]string{
  313. "OPENAI_API_KEY": "test-key",
  314. })
  315. resolver := NewEnvironmentVariableResolver(env)
  316. err := cfg.configureProviders(env, resolver, knownProviders)
  317. require.NoError(t, err)
  318. require.Equal(t, cfg.Providers.Len(), 1)
  319. // Provider ID should be set
  320. pc, _ := cfg.Providers.Get("openai")
  321. require.Equal(t, "openai", pc.ID)
  322. }
  323. func TestConfig_EnabledProviders(t *testing.T) {
  324. t.Run("all providers enabled", func(t *testing.T) {
  325. cfg := &Config{
  326. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  327. "openai": {
  328. ID: "openai",
  329. APIKey: "key1",
  330. Disable: false,
  331. },
  332. "anthropic": {
  333. ID: "anthropic",
  334. APIKey: "key2",
  335. Disable: false,
  336. },
  337. }),
  338. }
  339. enabled := cfg.EnabledProviders()
  340. require.Len(t, enabled, 2)
  341. })
  342. t.Run("some providers disabled", func(t *testing.T) {
  343. cfg := &Config{
  344. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  345. "openai": {
  346. ID: "openai",
  347. APIKey: "key1",
  348. Disable: false,
  349. },
  350. "anthropic": {
  351. ID: "anthropic",
  352. APIKey: "key2",
  353. Disable: true,
  354. },
  355. }),
  356. }
  357. enabled := cfg.EnabledProviders()
  358. require.Len(t, enabled, 1)
  359. require.Equal(t, "openai", enabled[0].ID)
  360. })
  361. t.Run("empty providers map", func(t *testing.T) {
  362. cfg := &Config{
  363. Providers: csync.NewMap[string, ProviderConfig](),
  364. }
  365. enabled := cfg.EnabledProviders()
  366. require.Len(t, enabled, 0)
  367. })
  368. }
  369. func TestConfig_IsConfigured(t *testing.T) {
  370. t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
  371. cfg := &Config{
  372. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  373. "openai": {
  374. ID: "openai",
  375. APIKey: "key1",
  376. Disable: false,
  377. },
  378. }),
  379. }
  380. require.True(t, cfg.IsConfigured())
  381. })
  382. t.Run("returns false when no providers are configured", func(t *testing.T) {
  383. cfg := &Config{
  384. Providers: csync.NewMap[string, ProviderConfig](),
  385. }
  386. require.False(t, cfg.IsConfigured())
  387. })
  388. t.Run("returns false when all providers are disabled", func(t *testing.T) {
  389. cfg := &Config{
  390. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  391. "openai": {
  392. ID: "openai",
  393. APIKey: "key1",
  394. Disable: true,
  395. },
  396. "anthropic": {
  397. ID: "anthropic",
  398. APIKey: "key2",
  399. Disable: true,
  400. },
  401. }),
  402. }
  403. require.False(t, cfg.IsConfigured())
  404. })
  405. }
  406. func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
  407. cfg := &Config{
  408. Options: &Options{
  409. DisabledTools: []string{},
  410. },
  411. }
  412. cfg.SetupAgents()
  413. coderAgent, ok := cfg.Agents[AgentCoder]
  414. require.True(t, ok)
  415. assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
  416. taskAgent, ok := cfg.Agents[AgentTask]
  417. require.True(t, ok)
  418. assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
  419. }
  420. func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
  421. cfg := &Config{
  422. Options: &Options{
  423. DisabledTools: []string{
  424. "edit",
  425. "download",
  426. "grep",
  427. },
  428. },
  429. }
  430. cfg.SetupAgents()
  431. coderAgent, ok := cfg.Agents[AgentCoder]
  432. require.True(t, ok)
  433. assert.Equal(t, []string{"agent", "bash", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)
  434. taskAgent, ok := cfg.Agents[AgentTask]
  435. require.True(t, ok)
  436. assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
  437. }
  438. func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
  439. cfg := &Config{
  440. Options: &Options{
  441. DisabledTools: []string{
  442. "glob",
  443. "grep",
  444. "ls",
  445. "sourcegraph",
  446. "view",
  447. },
  448. },
  449. }
  450. cfg.SetupAgents()
  451. coderAgent, ok := cfg.Agents[AgentCoder]
  452. require.True(t, ok)
  453. assert.Equal(t, []string{"agent", "bash", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "write"}, coderAgent.AllowedTools)
  454. taskAgent, ok := cfg.Agents[AgentTask]
  455. require.True(t, ok)
  456. assert.Equal(t, []string{}, taskAgent.AllowedTools)
  457. }
  458. func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
  459. knownProviders := []catwalk.Provider{
  460. {
  461. ID: "openai",
  462. APIKey: "$OPENAI_API_KEY",
  463. APIEndpoint: "https://api.openai.com/v1",
  464. Models: []catwalk.Model{{
  465. ID: "test-model",
  466. }},
  467. },
  468. }
  469. cfg := &Config{
  470. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  471. "openai": {
  472. Disable: true,
  473. },
  474. }),
  475. }
  476. cfg.setDefaults("/tmp", "")
  477. env := env.NewFromMap(map[string]string{
  478. "OPENAI_API_KEY": "test-key",
  479. })
  480. resolver := NewEnvironmentVariableResolver(env)
  481. err := cfg.configureProviders(env, resolver, knownProviders)
  482. require.NoError(t, err)
  483. require.Equal(t, cfg.Providers.Len(), 1)
  484. prov, exists := cfg.Providers.Get("openai")
  485. require.True(t, exists)
  486. require.True(t, prov.Disable)
  487. }
  488. func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
  489. t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
  490. cfg := &Config{
  491. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  492. "custom": {
  493. BaseURL: "https://api.custom.com/v1",
  494. Models: []catwalk.Model{{
  495. ID: "test-model",
  496. }},
  497. },
  498. "openai": {
  499. APIKey: "$MISSING",
  500. },
  501. }),
  502. }
  503. cfg.setDefaults("/tmp", "")
  504. env := env.NewFromMap(map[string]string{})
  505. resolver := NewEnvironmentVariableResolver(env)
  506. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  507. require.NoError(t, err)
  508. require.Equal(t, cfg.Providers.Len(), 1)
  509. _, exists := cfg.Providers.Get("custom")
  510. require.True(t, exists)
  511. })
  512. t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
  513. cfg := &Config{
  514. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  515. "custom": {
  516. APIKey: "test-key",
  517. Models: []catwalk.Model{{
  518. ID: "test-model",
  519. }},
  520. },
  521. }),
  522. }
  523. cfg.setDefaults("/tmp", "")
  524. env := env.NewFromMap(map[string]string{})
  525. resolver := NewEnvironmentVariableResolver(env)
  526. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  527. require.NoError(t, err)
  528. require.Equal(t, cfg.Providers.Len(), 0)
  529. _, exists := cfg.Providers.Get("custom")
  530. require.False(t, exists)
  531. })
  532. t.Run("custom provider with no models is removed", func(t *testing.T) {
  533. cfg := &Config{
  534. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  535. "custom": {
  536. APIKey: "test-key",
  537. BaseURL: "https://api.custom.com/v1",
  538. Models: []catwalk.Model{},
  539. },
  540. }),
  541. }
  542. cfg.setDefaults("/tmp", "")
  543. env := env.NewFromMap(map[string]string{})
  544. resolver := NewEnvironmentVariableResolver(env)
  545. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  546. require.NoError(t, err)
  547. require.Equal(t, cfg.Providers.Len(), 0)
  548. _, exists := cfg.Providers.Get("custom")
  549. require.False(t, exists)
  550. })
  551. t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
  552. cfg := &Config{
  553. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  554. "custom": {
  555. APIKey: "test-key",
  556. BaseURL: "https://api.custom.com/v1",
  557. Type: "unsupported",
  558. Models: []catwalk.Model{{
  559. ID: "test-model",
  560. }},
  561. },
  562. }),
  563. }
  564. cfg.setDefaults("/tmp", "")
  565. env := env.NewFromMap(map[string]string{})
  566. resolver := NewEnvironmentVariableResolver(env)
  567. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  568. require.NoError(t, err)
  569. require.Equal(t, cfg.Providers.Len(), 0)
  570. _, exists := cfg.Providers.Get("custom")
  571. require.False(t, exists)
  572. })
  573. t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
  574. cfg := &Config{
  575. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  576. "custom": {
  577. APIKey: "test-key",
  578. BaseURL: "https://api.custom.com/v1",
  579. Type: catwalk.TypeOpenAI,
  580. Models: []catwalk.Model{{
  581. ID: "test-model",
  582. }},
  583. },
  584. }),
  585. }
  586. cfg.setDefaults("/tmp", "")
  587. env := env.NewFromMap(map[string]string{})
  588. resolver := NewEnvironmentVariableResolver(env)
  589. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  590. require.NoError(t, err)
  591. require.Equal(t, cfg.Providers.Len(), 1)
  592. customProvider, exists := cfg.Providers.Get("custom")
  593. require.True(t, exists)
  594. require.Equal(t, "custom", customProvider.ID)
  595. require.Equal(t, "test-key", customProvider.APIKey)
  596. require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
  597. })
  598. t.Run("custom anthropic provider is supported", func(t *testing.T) {
  599. cfg := &Config{
  600. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  601. "custom-anthropic": {
  602. APIKey: "test-key",
  603. BaseURL: "https://api.anthropic.com/v1",
  604. Type: catwalk.TypeAnthropic,
  605. Models: []catwalk.Model{{
  606. ID: "claude-3-sonnet",
  607. }},
  608. },
  609. }),
  610. }
  611. cfg.setDefaults("/tmp", "")
  612. env := env.NewFromMap(map[string]string{})
  613. resolver := NewEnvironmentVariableResolver(env)
  614. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  615. require.NoError(t, err)
  616. require.Equal(t, cfg.Providers.Len(), 1)
  617. customProvider, exists := cfg.Providers.Get("custom-anthropic")
  618. require.True(t, exists)
  619. require.Equal(t, "custom-anthropic", customProvider.ID)
  620. require.Equal(t, "test-key", customProvider.APIKey)
  621. require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
  622. require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
  623. })
  624. t.Run("disabled custom provider is removed", func(t *testing.T) {
  625. cfg := &Config{
  626. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  627. "custom": {
  628. APIKey: "test-key",
  629. BaseURL: "https://api.custom.com/v1",
  630. Type: catwalk.TypeOpenAI,
  631. Disable: true,
  632. Models: []catwalk.Model{{
  633. ID: "test-model",
  634. }},
  635. },
  636. }),
  637. }
  638. cfg.setDefaults("/tmp", "")
  639. env := env.NewFromMap(map[string]string{})
  640. resolver := NewEnvironmentVariableResolver(env)
  641. err := cfg.configureProviders(env, resolver, []catwalk.Provider{})
  642. require.NoError(t, err)
  643. require.Equal(t, cfg.Providers.Len(), 0)
  644. _, exists := cfg.Providers.Get("custom")
  645. require.False(t, exists)
  646. })
  647. }
  648. func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
  649. t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
  650. knownProviders := []catwalk.Provider{
  651. {
  652. ID: catwalk.InferenceProviderVertexAI,
  653. APIKey: "",
  654. APIEndpoint: "",
  655. Models: []catwalk.Model{{
  656. ID: "gemini-pro",
  657. }},
  658. },
  659. }
  660. cfg := &Config{
  661. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  662. "vertexai": {
  663. BaseURL: "custom-url",
  664. },
  665. }),
  666. }
  667. cfg.setDefaults("/tmp", "")
  668. env := env.NewFromMap(map[string]string{
  669. "GOOGLE_GENAI_USE_VERTEXAI": "false",
  670. })
  671. resolver := NewEnvironmentVariableResolver(env)
  672. err := cfg.configureProviders(env, resolver, knownProviders)
  673. require.NoError(t, err)
  674. require.Equal(t, cfg.Providers.Len(), 0)
  675. _, exists := cfg.Providers.Get("vertexai")
  676. require.False(t, exists)
  677. })
  678. t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
  679. knownProviders := []catwalk.Provider{
  680. {
  681. ID: catwalk.InferenceProviderBedrock,
  682. APIKey: "",
  683. APIEndpoint: "",
  684. Models: []catwalk.Model{{
  685. ID: "anthropic.claude-sonnet-4-20250514-v1:0",
  686. }},
  687. },
  688. }
  689. cfg := &Config{
  690. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  691. "bedrock": {
  692. BaseURL: "custom-url",
  693. },
  694. }),
  695. }
  696. cfg.setDefaults("/tmp", "")
  697. env := env.NewFromMap(map[string]string{})
  698. resolver := NewEnvironmentVariableResolver(env)
  699. err := cfg.configureProviders(env, resolver, knownProviders)
  700. require.NoError(t, err)
  701. require.Equal(t, cfg.Providers.Len(), 0)
  702. _, exists := cfg.Providers.Get("bedrock")
  703. require.False(t, exists)
  704. })
  705. t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
  706. knownProviders := []catwalk.Provider{
  707. {
  708. ID: "openai",
  709. APIKey: "$MISSING_API_KEY",
  710. APIEndpoint: "https://api.openai.com/v1",
  711. Models: []catwalk.Model{{
  712. ID: "test-model",
  713. }},
  714. },
  715. }
  716. cfg := &Config{
  717. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  718. "openai": {
  719. BaseURL: "custom-url",
  720. },
  721. }),
  722. }
  723. cfg.setDefaults("/tmp", "")
  724. env := env.NewFromMap(map[string]string{})
  725. resolver := NewEnvironmentVariableResolver(env)
  726. err := cfg.configureProviders(env, resolver, knownProviders)
  727. require.NoError(t, err)
  728. require.Equal(t, cfg.Providers.Len(), 0)
  729. _, exists := cfg.Providers.Get("openai")
  730. require.False(t, exists)
  731. })
  732. t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
  733. knownProviders := []catwalk.Provider{
  734. {
  735. ID: "openai",
  736. APIKey: "$OPENAI_API_KEY",
  737. APIEndpoint: "$MISSING_ENDPOINT",
  738. Models: []catwalk.Model{{
  739. ID: "test-model",
  740. }},
  741. },
  742. }
  743. cfg := &Config{
  744. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  745. "openai": {
  746. APIKey: "test-key",
  747. },
  748. }),
  749. }
  750. cfg.setDefaults("/tmp", "")
  751. env := env.NewFromMap(map[string]string{
  752. "OPENAI_API_KEY": "test-key",
  753. })
  754. resolver := NewEnvironmentVariableResolver(env)
  755. err := cfg.configureProviders(env, resolver, knownProviders)
  756. require.NoError(t, err)
  757. require.Equal(t, cfg.Providers.Len(), 1)
  758. _, exists := cfg.Providers.Get("openai")
  759. require.True(t, exists)
  760. })
  761. }
  762. func TestConfig_defaultModelSelection(t *testing.T) {
  763. t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
  764. knownProviders := []catwalk.Provider{
  765. {
  766. ID: "openai",
  767. APIKey: "abc",
  768. DefaultLargeModelID: "large-model",
  769. DefaultSmallModelID: "small-model",
  770. Models: []catwalk.Model{
  771. {
  772. ID: "large-model",
  773. DefaultMaxTokens: 1000,
  774. },
  775. {
  776. ID: "small-model",
  777. DefaultMaxTokens: 500,
  778. },
  779. },
  780. },
  781. }
  782. cfg := &Config{}
  783. cfg.setDefaults("/tmp", "")
  784. env := env.NewFromMap(map[string]string{})
  785. resolver := NewEnvironmentVariableResolver(env)
  786. err := cfg.configureProviders(env, resolver, knownProviders)
  787. require.NoError(t, err)
  788. large, small, err := cfg.defaultModelSelection(knownProviders)
  789. require.NoError(t, err)
  790. require.Equal(t, "large-model", large.Model)
  791. require.Equal(t, "openai", large.Provider)
  792. require.Equal(t, int64(1000), large.MaxTokens)
  793. require.Equal(t, "small-model", small.Model)
  794. require.Equal(t, "openai", small.Provider)
  795. require.Equal(t, int64(500), small.MaxTokens)
  796. })
  797. t.Run("should error if no providers configured", func(t *testing.T) {
  798. knownProviders := []catwalk.Provider{
  799. {
  800. ID: "openai",
  801. APIKey: "$MISSING_KEY",
  802. DefaultLargeModelID: "large-model",
  803. DefaultSmallModelID: "small-model",
  804. Models: []catwalk.Model{
  805. {
  806. ID: "large-model",
  807. DefaultMaxTokens: 1000,
  808. },
  809. {
  810. ID: "small-model",
  811. DefaultMaxTokens: 500,
  812. },
  813. },
  814. },
  815. }
  816. cfg := &Config{}
  817. cfg.setDefaults("/tmp", "")
  818. env := env.NewFromMap(map[string]string{})
  819. resolver := NewEnvironmentVariableResolver(env)
  820. err := cfg.configureProviders(env, resolver, knownProviders)
  821. require.NoError(t, err)
  822. _, _, err = cfg.defaultModelSelection(knownProviders)
  823. require.Error(t, err)
  824. })
  825. t.Run("should error if model is missing", func(t *testing.T) {
  826. knownProviders := []catwalk.Provider{
  827. {
  828. ID: "openai",
  829. APIKey: "abc",
  830. DefaultLargeModelID: "large-model",
  831. DefaultSmallModelID: "small-model",
  832. Models: []catwalk.Model{
  833. {
  834. ID: "not-large-model",
  835. DefaultMaxTokens: 1000,
  836. },
  837. {
  838. ID: "small-model",
  839. DefaultMaxTokens: 500,
  840. },
  841. },
  842. },
  843. }
  844. cfg := &Config{}
  845. cfg.setDefaults("/tmp", "")
  846. env := env.NewFromMap(map[string]string{})
  847. resolver := NewEnvironmentVariableResolver(env)
  848. err := cfg.configureProviders(env, resolver, knownProviders)
  849. require.NoError(t, err)
  850. _, _, err = cfg.defaultModelSelection(knownProviders)
  851. require.Error(t, err)
  852. })
  853. t.Run("should configure the default models with a custom provider", func(t *testing.T) {
  854. knownProviders := []catwalk.Provider{
  855. {
  856. ID: "openai",
  857. APIKey: "$MISSING", // will not be included in the config
  858. DefaultLargeModelID: "large-model",
  859. DefaultSmallModelID: "small-model",
  860. Models: []catwalk.Model{
  861. {
  862. ID: "not-large-model",
  863. DefaultMaxTokens: 1000,
  864. },
  865. {
  866. ID: "small-model",
  867. DefaultMaxTokens: 500,
  868. },
  869. },
  870. },
  871. }
  872. cfg := &Config{
  873. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  874. "custom": {
  875. APIKey: "test-key",
  876. BaseURL: "https://api.custom.com/v1",
  877. Models: []catwalk.Model{
  878. {
  879. ID: "model",
  880. DefaultMaxTokens: 600,
  881. },
  882. },
  883. },
  884. }),
  885. }
  886. cfg.setDefaults("/tmp", "")
  887. env := env.NewFromMap(map[string]string{})
  888. resolver := NewEnvironmentVariableResolver(env)
  889. err := cfg.configureProviders(env, resolver, knownProviders)
  890. require.NoError(t, err)
  891. large, small, err := cfg.defaultModelSelection(knownProviders)
  892. require.NoError(t, err)
  893. require.Equal(t, "model", large.Model)
  894. require.Equal(t, "custom", large.Provider)
  895. require.Equal(t, int64(600), large.MaxTokens)
  896. require.Equal(t, "model", small.Model)
  897. require.Equal(t, "custom", small.Provider)
  898. require.Equal(t, int64(600), small.MaxTokens)
  899. })
  900. t.Run("should fail if no model configured", func(t *testing.T) {
  901. knownProviders := []catwalk.Provider{
  902. {
  903. ID: "openai",
  904. APIKey: "$MISSING", // will not be included in the config
  905. DefaultLargeModelID: "large-model",
  906. DefaultSmallModelID: "small-model",
  907. Models: []catwalk.Model{
  908. {
  909. ID: "not-large-model",
  910. DefaultMaxTokens: 1000,
  911. },
  912. {
  913. ID: "small-model",
  914. DefaultMaxTokens: 500,
  915. },
  916. },
  917. },
  918. }
  919. cfg := &Config{
  920. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  921. "custom": {
  922. APIKey: "test-key",
  923. BaseURL: "https://api.custom.com/v1",
  924. Models: []catwalk.Model{},
  925. },
  926. }),
  927. }
  928. cfg.setDefaults("/tmp", "")
  929. env := env.NewFromMap(map[string]string{})
  930. resolver := NewEnvironmentVariableResolver(env)
  931. err := cfg.configureProviders(env, resolver, knownProviders)
  932. require.NoError(t, err)
  933. _, _, err = cfg.defaultModelSelection(knownProviders)
  934. require.Error(t, err)
  935. })
  936. t.Run("should use the default provider first", func(t *testing.T) {
  937. knownProviders := []catwalk.Provider{
  938. {
  939. ID: "openai",
  940. APIKey: "set",
  941. DefaultLargeModelID: "large-model",
  942. DefaultSmallModelID: "small-model",
  943. Models: []catwalk.Model{
  944. {
  945. ID: "large-model",
  946. DefaultMaxTokens: 1000,
  947. },
  948. {
  949. ID: "small-model",
  950. DefaultMaxTokens: 500,
  951. },
  952. },
  953. },
  954. }
  955. cfg := &Config{
  956. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  957. "custom": {
  958. APIKey: "test-key",
  959. BaseURL: "https://api.custom.com/v1",
  960. Models: []catwalk.Model{
  961. {
  962. ID: "large-model",
  963. DefaultMaxTokens: 1000,
  964. },
  965. },
  966. },
  967. }),
  968. }
  969. cfg.setDefaults("/tmp", "")
  970. env := env.NewFromMap(map[string]string{})
  971. resolver := NewEnvironmentVariableResolver(env)
  972. err := cfg.configureProviders(env, resolver, knownProviders)
  973. require.NoError(t, err)
  974. large, small, err := cfg.defaultModelSelection(knownProviders)
  975. require.NoError(t, err)
  976. require.Equal(t, "large-model", large.Model)
  977. require.Equal(t, "openai", large.Provider)
  978. require.Equal(t, int64(1000), large.MaxTokens)
  979. require.Equal(t, "small-model", small.Model)
  980. require.Equal(t, "openai", small.Provider)
  981. require.Equal(t, int64(500), small.MaxTokens)
  982. })
  983. }
  984. func TestConfig_configureSelectedModels(t *testing.T) {
  985. t.Run("should override defaults", func(t *testing.T) {
  986. knownProviders := []catwalk.Provider{
  987. {
  988. ID: "openai",
  989. APIKey: "abc",
  990. DefaultLargeModelID: "large-model",
  991. DefaultSmallModelID: "small-model",
  992. Models: []catwalk.Model{
  993. {
  994. ID: "larger-model",
  995. DefaultMaxTokens: 2000,
  996. },
  997. {
  998. ID: "large-model",
  999. DefaultMaxTokens: 1000,
  1000. },
  1001. {
  1002. ID: "small-model",
  1003. DefaultMaxTokens: 500,
  1004. },
  1005. },
  1006. },
  1007. }
  1008. cfg := &Config{
  1009. Models: map[SelectedModelType]SelectedModel{
  1010. "large": {
  1011. Model: "larger-model",
  1012. },
  1013. },
  1014. }
  1015. cfg.setDefaults("/tmp", "")
  1016. env := env.NewFromMap(map[string]string{})
  1017. resolver := NewEnvironmentVariableResolver(env)
  1018. err := cfg.configureProviders(env, resolver, knownProviders)
  1019. require.NoError(t, err)
  1020. err = cfg.configureSelectedModels(knownProviders)
  1021. require.NoError(t, err)
  1022. large := cfg.Models[SelectedModelTypeLarge]
  1023. small := cfg.Models[SelectedModelTypeSmall]
  1024. require.Equal(t, "larger-model", large.Model)
  1025. require.Equal(t, "openai", large.Provider)
  1026. require.Equal(t, int64(2000), large.MaxTokens)
  1027. require.Equal(t, "small-model", small.Model)
  1028. require.Equal(t, "openai", small.Provider)
  1029. require.Equal(t, int64(500), small.MaxTokens)
  1030. })
  1031. t.Run("should be possible to use multiple providers", func(t *testing.T) {
  1032. knownProviders := []catwalk.Provider{
  1033. {
  1034. ID: "openai",
  1035. APIKey: "abc",
  1036. DefaultLargeModelID: "large-model",
  1037. DefaultSmallModelID: "small-model",
  1038. Models: []catwalk.Model{
  1039. {
  1040. ID: "large-model",
  1041. DefaultMaxTokens: 1000,
  1042. },
  1043. {
  1044. ID: "small-model",
  1045. DefaultMaxTokens: 500,
  1046. },
  1047. },
  1048. },
  1049. {
  1050. ID: "anthropic",
  1051. APIKey: "abc",
  1052. DefaultLargeModelID: "a-large-model",
  1053. DefaultSmallModelID: "a-small-model",
  1054. Models: []catwalk.Model{
  1055. {
  1056. ID: "a-large-model",
  1057. DefaultMaxTokens: 1000,
  1058. },
  1059. {
  1060. ID: "a-small-model",
  1061. DefaultMaxTokens: 200,
  1062. },
  1063. },
  1064. },
  1065. }
  1066. cfg := &Config{
  1067. Models: map[SelectedModelType]SelectedModel{
  1068. "small": {
  1069. Model: "a-small-model",
  1070. Provider: "anthropic",
  1071. MaxTokens: 300,
  1072. },
  1073. },
  1074. }
  1075. cfg.setDefaults("/tmp", "")
  1076. env := env.NewFromMap(map[string]string{})
  1077. resolver := NewEnvironmentVariableResolver(env)
  1078. err := cfg.configureProviders(env, resolver, knownProviders)
  1079. require.NoError(t, err)
  1080. err = cfg.configureSelectedModels(knownProviders)
  1081. require.NoError(t, err)
  1082. large := cfg.Models[SelectedModelTypeLarge]
  1083. small := cfg.Models[SelectedModelTypeSmall]
  1084. require.Equal(t, "large-model", large.Model)
  1085. require.Equal(t, "openai", large.Provider)
  1086. require.Equal(t, int64(1000), large.MaxTokens)
  1087. require.Equal(t, "a-small-model", small.Model)
  1088. require.Equal(t, "anthropic", small.Provider)
  1089. require.Equal(t, int64(300), small.MaxTokens)
  1090. })
  1091. t.Run("should override the max tokens only", func(t *testing.T) {
  1092. knownProviders := []catwalk.Provider{
  1093. {
  1094. ID: "openai",
  1095. APIKey: "abc",
  1096. DefaultLargeModelID: "large-model",
  1097. DefaultSmallModelID: "small-model",
  1098. Models: []catwalk.Model{
  1099. {
  1100. ID: "large-model",
  1101. DefaultMaxTokens: 1000,
  1102. },
  1103. {
  1104. ID: "small-model",
  1105. DefaultMaxTokens: 500,
  1106. },
  1107. },
  1108. },
  1109. }
  1110. cfg := &Config{
  1111. Models: map[SelectedModelType]SelectedModel{
  1112. "large": {
  1113. MaxTokens: 100,
  1114. },
  1115. },
  1116. }
  1117. cfg.setDefaults("/tmp", "")
  1118. env := env.NewFromMap(map[string]string{})
  1119. resolver := NewEnvironmentVariableResolver(env)
  1120. err := cfg.configureProviders(env, resolver, knownProviders)
  1121. require.NoError(t, err)
  1122. err = cfg.configureSelectedModels(knownProviders)
  1123. require.NoError(t, err)
  1124. large := cfg.Models[SelectedModelTypeLarge]
  1125. require.Equal(t, "large-model", large.Model)
  1126. require.Equal(t, "openai", large.Provider)
  1127. require.Equal(t, int64(100), large.MaxTokens)
  1128. })
  1129. }