load_test.go 40 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464
  1. package config
  2. import (
  3. "io"
  4. "log/slog"
  5. "os"
  6. "path/filepath"
  7. "testing"
  8. "charm.land/catwalk/pkg/catwalk"
  9. "github.com/charmbracelet/crush/internal/csync"
  10. "github.com/charmbracelet/crush/internal/env"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. )
  14. func TestMain(m *testing.M) {
  15. slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
  16. exitVal := m.Run()
  17. os.Exit(exitVal)
  18. }
  19. func TestConfig_LoadFromBytes(t *testing.T) {
  20. data1 := []byte(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
  21. data2 := []byte(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
  22. data3 := []byte(`{"providers": {"openai": {}}}`)
  23. loadedConfig, err := loadFromBytes([][]byte{data1, data2, data3})
  24. require.NoError(t, err)
  25. require.NotNil(t, loadedConfig)
  26. require.Equal(t, 1, loadedConfig.Providers.Len())
  27. pc, _ := loadedConfig.Providers.Get("openai")
  28. require.Equal(t, "key2", pc.APIKey)
  29. require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  30. }
  31. // testStore wraps a Config in a minimal ConfigStore for testing.
  32. func testStore(cfg *Config) *ConfigStore {
  33. return &ConfigStore{config: cfg}
  34. }
  35. func TestConfig_setDefaults(t *testing.T) {
  36. cfg := &Config{}
  37. cfg.setDefaults("/tmp", "")
  38. require.NotNil(t, cfg.Options)
  39. require.NotNil(t, cfg.Options.TUI)
  40. require.NotNil(t, cfg.Options.ContextPaths)
  41. require.NotNil(t, cfg.Providers)
  42. require.NotNil(t, cfg.Models)
  43. require.NotNil(t, cfg.LSP)
  44. require.NotNil(t, cfg.MCP)
  45. require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
  46. require.Equal(t, "AGENTS.md", cfg.Options.InitializeAs)
  47. for _, path := range defaultContextPaths {
  48. require.Contains(t, cfg.Options.ContextPaths, path)
  49. }
  50. }
  51. func TestConfig_configureProviders(t *testing.T) {
  52. knownProviders := []catwalk.Provider{
  53. {
  54. ID: "openai",
  55. APIKey: "$OPENAI_API_KEY",
  56. APIEndpoint: "https://api.openai.com/v1",
  57. Models: []catwalk.Model{{
  58. ID: "test-model",
  59. }},
  60. },
  61. }
  62. cfg := &Config{}
  63. cfg.setDefaults("/tmp", "")
  64. env := env.NewFromMap(map[string]string{
  65. "OPENAI_API_KEY": "test-key",
  66. })
  67. resolver := NewEnvironmentVariableResolver(env)
  68. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  69. require.NoError(t, err)
  70. require.Equal(t, 1, cfg.Providers.Len())
  71. // We want to make sure that we keep the configured API key as a placeholder
  72. pc, _ := cfg.Providers.Get("openai")
  73. require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
  74. }
  75. func TestConfig_configureProvidersWithOverride(t *testing.T) {
  76. knownProviders := []catwalk.Provider{
  77. {
  78. ID: "openai",
  79. APIKey: "$OPENAI_API_KEY",
  80. APIEndpoint: "https://api.openai.com/v1",
  81. Models: []catwalk.Model{{
  82. ID: "test-model",
  83. }},
  84. },
  85. }
  86. cfg := &Config{
  87. Providers: csync.NewMap[string, ProviderConfig](),
  88. }
  89. cfg.Providers.Set("openai", ProviderConfig{
  90. APIKey: "xyz",
  91. BaseURL: "https://api.openai.com/v2",
  92. Models: []catwalk.Model{
  93. {
  94. ID: "test-model",
  95. Name: "Updated",
  96. },
  97. {
  98. ID: "another-model",
  99. },
  100. },
  101. })
  102. cfg.setDefaults("/tmp", "")
  103. env := env.NewFromMap(map[string]string{
  104. "OPENAI_API_KEY": "test-key",
  105. })
  106. resolver := NewEnvironmentVariableResolver(env)
  107. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  108. require.NoError(t, err)
  109. require.Equal(t, 1, cfg.Providers.Len())
  110. // We want to make sure that we keep the configured API key as a placeholder
  111. pc, _ := cfg.Providers.Get("openai")
  112. require.Equal(t, "xyz", pc.APIKey)
  113. require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
  114. require.Len(t, pc.Models, 2)
  115. require.Equal(t, "Updated", pc.Models[0].Name)
  116. }
  117. func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
  118. knownProviders := []catwalk.Provider{
  119. {
  120. ID: "openai",
  121. APIKey: "$OPENAI_API_KEY",
  122. APIEndpoint: "https://api.openai.com/v1",
  123. Models: []catwalk.Model{{
  124. ID: "test-model",
  125. }},
  126. },
  127. }
  128. cfg := &Config{
  129. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  130. "custom": {
  131. APIKey: "xyz",
  132. BaseURL: "https://api.someendpoint.com/v2",
  133. Models: []catwalk.Model{
  134. {
  135. ID: "test-model",
  136. },
  137. },
  138. },
  139. }),
  140. }
  141. cfg.setDefaults("/tmp", "")
  142. env := env.NewFromMap(map[string]string{
  143. "OPENAI_API_KEY": "test-key",
  144. })
  145. resolver := NewEnvironmentVariableResolver(env)
  146. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  147. require.NoError(t, err)
  148. // Should be to because of the env variable
  149. require.Equal(t, cfg.Providers.Len(), 2)
  150. // We want to make sure that we keep the configured API key as a placeholder
  151. pc, _ := cfg.Providers.Get("custom")
  152. require.Equal(t, "xyz", pc.APIKey)
  153. // Make sure we set the ID correctly
  154. require.Equal(t, "custom", pc.ID)
  155. require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
  156. require.Len(t, pc.Models, 1)
  157. _, ok := cfg.Providers.Get("openai")
  158. require.True(t, ok, "OpenAI provider should still be present")
  159. }
  160. func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
  161. knownProviders := []catwalk.Provider{
  162. {
  163. ID: catwalk.InferenceProviderBedrock,
  164. APIKey: "",
  165. APIEndpoint: "",
  166. Models: []catwalk.Model{{
  167. ID: "anthropic.claude-sonnet-4-20250514-v1:0",
  168. }},
  169. },
  170. }
  171. cfg := &Config{}
  172. cfg.setDefaults("/tmp", "")
  173. env := env.NewFromMap(map[string]string{
  174. "AWS_ACCESS_KEY_ID": "test-key-id",
  175. "AWS_SECRET_ACCESS_KEY": "test-secret-key",
  176. })
  177. resolver := NewEnvironmentVariableResolver(env)
  178. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  179. require.NoError(t, err)
  180. require.Equal(t, cfg.Providers.Len(), 1)
  181. bedrockProvider, ok := cfg.Providers.Get("bedrock")
  182. require.True(t, ok, "Bedrock provider should be present")
  183. require.Len(t, bedrockProvider.Models, 1)
  184. require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
  185. }
  186. func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
  187. knownProviders := []catwalk.Provider{
  188. {
  189. ID: catwalk.InferenceProviderBedrock,
  190. APIKey: "",
  191. APIEndpoint: "",
  192. Models: []catwalk.Model{{
  193. ID: "anthropic.claude-sonnet-4-20250514-v1:0",
  194. }},
  195. },
  196. }
  197. cfg := &Config{}
  198. cfg.setDefaults("/tmp", "")
  199. env := env.NewFromMap(map[string]string{})
  200. resolver := NewEnvironmentVariableResolver(env)
  201. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  202. require.NoError(t, err)
  203. // Provider should not be configured without credentials
  204. require.Equal(t, cfg.Providers.Len(), 0)
  205. }
  206. func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
  207. knownProviders := []catwalk.Provider{
  208. {
  209. ID: catwalk.InferenceProviderBedrock,
  210. APIKey: "",
  211. APIEndpoint: "",
  212. Models: []catwalk.Model{{
  213. ID: "some-random-model",
  214. }},
  215. },
  216. }
  217. cfg := &Config{}
  218. cfg.setDefaults("/tmp", "")
  219. env := env.NewFromMap(map[string]string{
  220. "AWS_ACCESS_KEY_ID": "test-key-id",
  221. "AWS_SECRET_ACCESS_KEY": "test-secret-key",
  222. })
  223. resolver := NewEnvironmentVariableResolver(env)
  224. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  225. require.Error(t, err)
  226. }
  227. func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
  228. knownProviders := []catwalk.Provider{
  229. {
  230. ID: catwalk.InferenceProviderVertexAI,
  231. APIKey: "",
  232. APIEndpoint: "",
  233. Models: []catwalk.Model{{
  234. ID: "gemini-pro",
  235. }},
  236. },
  237. }
  238. cfg := &Config{}
  239. cfg.setDefaults("/tmp", "")
  240. env := env.NewFromMap(map[string]string{
  241. "VERTEXAI_PROJECT": "test-project",
  242. "VERTEXAI_LOCATION": "us-central1",
  243. })
  244. resolver := NewEnvironmentVariableResolver(env)
  245. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  246. require.NoError(t, err)
  247. require.Equal(t, cfg.Providers.Len(), 1)
  248. vertexProvider, ok := cfg.Providers.Get("vertexai")
  249. require.True(t, ok, "VertexAI provider should be present")
  250. require.Len(t, vertexProvider.Models, 1)
  251. require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
  252. require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
  253. require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
  254. }
  255. func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
  256. knownProviders := []catwalk.Provider{
  257. {
  258. ID: catwalk.InferenceProviderVertexAI,
  259. APIKey: "",
  260. APIEndpoint: "",
  261. Models: []catwalk.Model{{
  262. ID: "gemini-pro",
  263. }},
  264. },
  265. }
  266. cfg := &Config{}
  267. cfg.setDefaults("/tmp", "")
  268. env := env.NewFromMap(map[string]string{
  269. "GOOGLE_GENAI_USE_VERTEXAI": "false",
  270. "GOOGLE_CLOUD_PROJECT": "test-project",
  271. "GOOGLE_CLOUD_LOCATION": "us-central1",
  272. })
  273. resolver := NewEnvironmentVariableResolver(env)
  274. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  275. require.NoError(t, err)
  276. // Provider should not be configured without proper credentials
  277. require.Equal(t, cfg.Providers.Len(), 0)
  278. }
  279. func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
  280. knownProviders := []catwalk.Provider{
  281. {
  282. ID: catwalk.InferenceProviderVertexAI,
  283. APIKey: "",
  284. APIEndpoint: "",
  285. Models: []catwalk.Model{{
  286. ID: "gemini-pro",
  287. }},
  288. },
  289. }
  290. cfg := &Config{}
  291. cfg.setDefaults("/tmp", "")
  292. env := env.NewFromMap(map[string]string{
  293. "GOOGLE_GENAI_USE_VERTEXAI": "true",
  294. "GOOGLE_CLOUD_LOCATION": "us-central1",
  295. })
  296. resolver := NewEnvironmentVariableResolver(env)
  297. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  298. require.NoError(t, err)
  299. // Provider should not be configured without project
  300. require.Equal(t, cfg.Providers.Len(), 0)
  301. }
  302. func TestConfig_configureProvidersSetProviderID(t *testing.T) {
  303. knownProviders := []catwalk.Provider{
  304. {
  305. ID: "openai",
  306. APIKey: "$OPENAI_API_KEY",
  307. APIEndpoint: "https://api.openai.com/v1",
  308. Models: []catwalk.Model{{
  309. ID: "test-model",
  310. }},
  311. },
  312. }
  313. cfg := &Config{}
  314. cfg.setDefaults("/tmp", "")
  315. env := env.NewFromMap(map[string]string{
  316. "OPENAI_API_KEY": "test-key",
  317. })
  318. resolver := NewEnvironmentVariableResolver(env)
  319. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  320. require.NoError(t, err)
  321. require.Equal(t, cfg.Providers.Len(), 1)
  322. // Provider ID should be set
  323. pc, _ := cfg.Providers.Get("openai")
  324. require.Equal(t, "openai", pc.ID)
  325. }
  326. func TestConfig_EnabledProviders(t *testing.T) {
  327. t.Run("all providers enabled", func(t *testing.T) {
  328. cfg := &Config{
  329. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  330. "openai": {
  331. ID: "openai",
  332. APIKey: "key1",
  333. Disable: false,
  334. },
  335. "anthropic": {
  336. ID: "anthropic",
  337. APIKey: "key2",
  338. Disable: false,
  339. },
  340. }),
  341. }
  342. enabled := cfg.EnabledProviders()
  343. require.Len(t, enabled, 2)
  344. })
  345. t.Run("some providers disabled", func(t *testing.T) {
  346. cfg := &Config{
  347. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  348. "openai": {
  349. ID: "openai",
  350. APIKey: "key1",
  351. Disable: false,
  352. },
  353. "anthropic": {
  354. ID: "anthropic",
  355. APIKey: "key2",
  356. Disable: true,
  357. },
  358. }),
  359. }
  360. enabled := cfg.EnabledProviders()
  361. require.Len(t, enabled, 1)
  362. require.Equal(t, "openai", enabled[0].ID)
  363. })
  364. t.Run("empty providers map", func(t *testing.T) {
  365. cfg := &Config{
  366. Providers: csync.NewMap[string, ProviderConfig](),
  367. }
  368. enabled := cfg.EnabledProviders()
  369. require.Len(t, enabled, 0)
  370. })
  371. }
  372. func TestConfig_IsConfigured(t *testing.T) {
  373. t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
  374. cfg := &Config{
  375. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  376. "openai": {
  377. ID: "openai",
  378. APIKey: "key1",
  379. Disable: false,
  380. },
  381. }),
  382. }
  383. require.True(t, cfg.IsConfigured())
  384. })
  385. t.Run("returns false when no providers are configured", func(t *testing.T) {
  386. cfg := &Config{
  387. Providers: csync.NewMap[string, ProviderConfig](),
  388. }
  389. require.False(t, cfg.IsConfigured())
  390. })
  391. t.Run("returns false when all providers are disabled", func(t *testing.T) {
  392. cfg := &Config{
  393. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  394. "openai": {
  395. ID: "openai",
  396. APIKey: "key1",
  397. Disable: true,
  398. },
  399. "anthropic": {
  400. ID: "anthropic",
  401. APIKey: "key2",
  402. Disable: true,
  403. },
  404. }),
  405. }
  406. require.False(t, cfg.IsConfigured())
  407. })
  408. }
  409. func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
  410. cfg := &Config{
  411. Options: &Options{
  412. DisabledTools: []string{},
  413. },
  414. }
  415. cfg.SetupAgents()
  416. coderAgent, ok := cfg.Agents[AgentCoder]
  417. require.True(t, ok)
  418. assert.Equal(t, allToolNames(), coderAgent.AllowedTools)
  419. taskAgent, ok := cfg.Agents[AgentTask]
  420. require.True(t, ok)
  421. assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
  422. }
  423. func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
  424. cfg := &Config{
  425. Options: &Options{
  426. DisabledTools: []string{
  427. "edit",
  428. "download",
  429. "grep",
  430. },
  431. },
  432. }
  433. cfg.SetupAgents()
  434. coderAgent, ok := cfg.Agents[AgentCoder]
  435. require.True(t, ok)
  436. assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "todos", "view", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)
  437. taskAgent, ok := cfg.Agents[AgentTask]
  438. require.True(t, ok)
  439. assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
  440. }
  441. func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
  442. cfg := &Config{
  443. Options: &Options{
  444. DisabledTools: []string{
  445. "glob",
  446. "grep",
  447. "ls",
  448. "sourcegraph",
  449. "view",
  450. },
  451. },
  452. }
  453. cfg.SetupAgents()
  454. coderAgent, ok := cfg.Agents[AgentCoder]
  455. require.True(t, ok)
  456. assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "lsp_restart", "fetch", "agentic_fetch", "todos", "write", "list_mcp_resources", "read_mcp_resource"}, coderAgent.AllowedTools)
  457. taskAgent, ok := cfg.Agents[AgentTask]
  458. require.True(t, ok)
  459. assert.Len(t, taskAgent.AllowedTools, 0)
  460. }
  461. func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
  462. knownProviders := []catwalk.Provider{
  463. {
  464. ID: "openai",
  465. APIKey: "$OPENAI_API_KEY",
  466. APIEndpoint: "https://api.openai.com/v1",
  467. Models: []catwalk.Model{{
  468. ID: "test-model",
  469. }},
  470. },
  471. }
  472. cfg := &Config{
  473. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  474. "openai": {
  475. Disable: true,
  476. },
  477. }),
  478. }
  479. cfg.setDefaults("/tmp", "")
  480. env := env.NewFromMap(map[string]string{
  481. "OPENAI_API_KEY": "test-key",
  482. })
  483. resolver := NewEnvironmentVariableResolver(env)
  484. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  485. require.NoError(t, err)
  486. require.Equal(t, cfg.Providers.Len(), 1)
  487. prov, exists := cfg.Providers.Get("openai")
  488. require.True(t, exists)
  489. require.True(t, prov.Disable)
  490. }
  491. func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
  492. t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
  493. cfg := &Config{
  494. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  495. "custom": {
  496. BaseURL: "https://api.custom.com/v1",
  497. Models: []catwalk.Model{{
  498. ID: "test-model",
  499. }},
  500. },
  501. "openai": {
  502. APIKey: "$MISSING",
  503. },
  504. }),
  505. }
  506. cfg.setDefaults("/tmp", "")
  507. env := env.NewFromMap(map[string]string{})
  508. resolver := NewEnvironmentVariableResolver(env)
  509. err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
  510. require.NoError(t, err)
  511. require.Equal(t, cfg.Providers.Len(), 1)
  512. _, exists := cfg.Providers.Get("custom")
  513. require.True(t, exists)
  514. })
  515. t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
  516. cfg := &Config{
  517. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  518. "custom": {
  519. APIKey: "test-key",
  520. Models: []catwalk.Model{{
  521. ID: "test-model",
  522. }},
  523. },
  524. }),
  525. }
  526. cfg.setDefaults("/tmp", "")
  527. env := env.NewFromMap(map[string]string{})
  528. resolver := NewEnvironmentVariableResolver(env)
  529. err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
  530. require.NoError(t, err)
  531. require.Equal(t, cfg.Providers.Len(), 0)
  532. _, exists := cfg.Providers.Get("custom")
  533. require.False(t, exists)
  534. })
  535. t.Run("custom provider with no models is removed", func(t *testing.T) {
  536. cfg := &Config{
  537. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  538. "custom": {
  539. APIKey: "test-key",
  540. BaseURL: "https://api.custom.com/v1",
  541. Models: []catwalk.Model{},
  542. },
  543. }),
  544. }
  545. cfg.setDefaults("/tmp", "")
  546. env := env.NewFromMap(map[string]string{})
  547. resolver := NewEnvironmentVariableResolver(env)
  548. err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
  549. require.NoError(t, err)
  550. require.Equal(t, cfg.Providers.Len(), 0)
  551. _, exists := cfg.Providers.Get("custom")
  552. require.False(t, exists)
  553. })
  554. t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
  555. cfg := &Config{
  556. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  557. "custom": {
  558. APIKey: "test-key",
  559. BaseURL: "https://api.custom.com/v1",
  560. Type: "unsupported",
  561. Models: []catwalk.Model{{
  562. ID: "test-model",
  563. }},
  564. },
  565. }),
  566. }
  567. cfg.setDefaults("/tmp", "")
  568. env := env.NewFromMap(map[string]string{})
  569. resolver := NewEnvironmentVariableResolver(env)
  570. err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
  571. require.NoError(t, err)
  572. require.Equal(t, cfg.Providers.Len(), 0)
  573. _, exists := cfg.Providers.Get("custom")
  574. require.False(t, exists)
  575. })
  576. t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
  577. cfg := &Config{
  578. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  579. "custom": {
  580. APIKey: "test-key",
  581. BaseURL: "https://api.custom.com/v1",
  582. Type: catwalk.TypeOpenAI,
  583. Models: []catwalk.Model{{
  584. ID: "test-model",
  585. }},
  586. },
  587. }),
  588. }
  589. cfg.setDefaults("/tmp", "")
  590. env := env.NewFromMap(map[string]string{})
  591. resolver := NewEnvironmentVariableResolver(env)
  592. err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
  593. require.NoError(t, err)
  594. require.Equal(t, cfg.Providers.Len(), 1)
  595. customProvider, exists := cfg.Providers.Get("custom")
  596. require.True(t, exists)
  597. require.Equal(t, "custom", customProvider.ID)
  598. require.Equal(t, "test-key", customProvider.APIKey)
  599. require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
  600. })
  601. t.Run("custom anthropic provider is supported", func(t *testing.T) {
  602. cfg := &Config{
  603. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  604. "custom-anthropic": {
  605. APIKey: "test-key",
  606. BaseURL: "https://api.anthropic.com/v1",
  607. Type: catwalk.TypeAnthropic,
  608. Models: []catwalk.Model{{
  609. ID: "claude-3-sonnet",
  610. }},
  611. },
  612. }),
  613. }
  614. cfg.setDefaults("/tmp", "")
  615. env := env.NewFromMap(map[string]string{})
  616. resolver := NewEnvironmentVariableResolver(env)
  617. err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
  618. require.NoError(t, err)
  619. require.Equal(t, cfg.Providers.Len(), 1)
  620. customProvider, exists := cfg.Providers.Get("custom-anthropic")
  621. require.True(t, exists)
  622. require.Equal(t, "custom-anthropic", customProvider.ID)
  623. require.Equal(t, "test-key", customProvider.APIKey)
  624. require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
  625. require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
  626. })
  627. t.Run("disabled custom provider is removed", func(t *testing.T) {
  628. cfg := &Config{
  629. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  630. "custom": {
  631. APIKey: "test-key",
  632. BaseURL: "https://api.custom.com/v1",
  633. Type: catwalk.TypeOpenAI,
  634. Disable: true,
  635. Models: []catwalk.Model{{
  636. ID: "test-model",
  637. }},
  638. },
  639. }),
  640. }
  641. cfg.setDefaults("/tmp", "")
  642. env := env.NewFromMap(map[string]string{})
  643. resolver := NewEnvironmentVariableResolver(env)
  644. err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
  645. require.NoError(t, err)
  646. require.Equal(t, cfg.Providers.Len(), 0)
  647. _, exists := cfg.Providers.Get("custom")
  648. require.False(t, exists)
  649. })
  650. }
  651. func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
  652. t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
  653. knownProviders := []catwalk.Provider{
  654. {
  655. ID: catwalk.InferenceProviderVertexAI,
  656. APIKey: "",
  657. APIEndpoint: "",
  658. Models: []catwalk.Model{{
  659. ID: "gemini-pro",
  660. }},
  661. },
  662. }
  663. cfg := &Config{
  664. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  665. "vertexai": {
  666. BaseURL: "custom-url",
  667. },
  668. }),
  669. }
  670. cfg.setDefaults("/tmp", "")
  671. env := env.NewFromMap(map[string]string{
  672. "GOOGLE_GENAI_USE_VERTEXAI": "false",
  673. })
  674. resolver := NewEnvironmentVariableResolver(env)
  675. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  676. require.NoError(t, err)
  677. require.Equal(t, cfg.Providers.Len(), 0)
  678. _, exists := cfg.Providers.Get("vertexai")
  679. require.False(t, exists)
  680. })
  681. t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
  682. knownProviders := []catwalk.Provider{
  683. {
  684. ID: catwalk.InferenceProviderBedrock,
  685. APIKey: "",
  686. APIEndpoint: "",
  687. Models: []catwalk.Model{{
  688. ID: "anthropic.claude-sonnet-4-20250514-v1:0",
  689. }},
  690. },
  691. }
  692. cfg := &Config{
  693. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  694. "bedrock": {
  695. BaseURL: "custom-url",
  696. },
  697. }),
  698. }
  699. cfg.setDefaults("/tmp", "")
  700. env := env.NewFromMap(map[string]string{})
  701. resolver := NewEnvironmentVariableResolver(env)
  702. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  703. require.NoError(t, err)
  704. require.Equal(t, cfg.Providers.Len(), 0)
  705. _, exists := cfg.Providers.Get("bedrock")
  706. require.False(t, exists)
  707. })
  708. t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
  709. knownProviders := []catwalk.Provider{
  710. {
  711. ID: "openai",
  712. APIKey: "$MISSING_API_KEY",
  713. APIEndpoint: "https://api.openai.com/v1",
  714. Models: []catwalk.Model{{
  715. ID: "test-model",
  716. }},
  717. },
  718. }
  719. cfg := &Config{
  720. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  721. "openai": {
  722. BaseURL: "custom-url",
  723. },
  724. }),
  725. }
  726. cfg.setDefaults("/tmp", "")
  727. env := env.NewFromMap(map[string]string{})
  728. resolver := NewEnvironmentVariableResolver(env)
  729. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  730. require.NoError(t, err)
  731. require.Equal(t, cfg.Providers.Len(), 0)
  732. _, exists := cfg.Providers.Get("openai")
  733. require.False(t, exists)
  734. })
  735. t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
  736. knownProviders := []catwalk.Provider{
  737. {
  738. ID: "openai",
  739. APIKey: "$OPENAI_API_KEY",
  740. APIEndpoint: "$MISSING_ENDPOINT",
  741. Models: []catwalk.Model{{
  742. ID: "test-model",
  743. }},
  744. },
  745. }
  746. cfg := &Config{
  747. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  748. "openai": {
  749. APIKey: "test-key",
  750. },
  751. }),
  752. }
  753. cfg.setDefaults("/tmp", "")
  754. env := env.NewFromMap(map[string]string{
  755. "OPENAI_API_KEY": "test-key",
  756. })
  757. resolver := NewEnvironmentVariableResolver(env)
  758. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  759. require.NoError(t, err)
  760. require.Equal(t, cfg.Providers.Len(), 1)
  761. _, exists := cfg.Providers.Get("openai")
  762. require.True(t, exists)
  763. })
  764. }
  765. func TestConfig_defaultModelSelection(t *testing.T) {
  766. t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
  767. knownProviders := []catwalk.Provider{
  768. {
  769. ID: "openai",
  770. APIKey: "abc",
  771. DefaultLargeModelID: "large-model",
  772. DefaultSmallModelID: "small-model",
  773. Models: []catwalk.Model{
  774. {
  775. ID: "large-model",
  776. DefaultMaxTokens: 1000,
  777. },
  778. {
  779. ID: "small-model",
  780. DefaultMaxTokens: 500,
  781. },
  782. },
  783. },
  784. }
  785. cfg := &Config{}
  786. cfg.setDefaults("/tmp", "")
  787. env := env.NewFromMap(map[string]string{})
  788. resolver := NewEnvironmentVariableResolver(env)
  789. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  790. require.NoError(t, err)
  791. large, small, err := cfg.defaultModelSelection(knownProviders)
  792. require.NoError(t, err)
  793. require.Equal(t, "large-model", large.Model)
  794. require.Equal(t, "openai", large.Provider)
  795. require.Equal(t, int64(1000), large.MaxTokens)
  796. require.Equal(t, "small-model", small.Model)
  797. require.Equal(t, "openai", small.Provider)
  798. require.Equal(t, int64(500), small.MaxTokens)
  799. })
  800. t.Run("should error if no providers configured", func(t *testing.T) {
  801. knownProviders := []catwalk.Provider{
  802. {
  803. ID: "openai",
  804. APIKey: "$MISSING_KEY",
  805. DefaultLargeModelID: "large-model",
  806. DefaultSmallModelID: "small-model",
  807. Models: []catwalk.Model{
  808. {
  809. ID: "large-model",
  810. DefaultMaxTokens: 1000,
  811. },
  812. {
  813. ID: "small-model",
  814. DefaultMaxTokens: 500,
  815. },
  816. },
  817. },
  818. }
  819. cfg := &Config{}
  820. cfg.setDefaults("/tmp", "")
  821. env := env.NewFromMap(map[string]string{})
  822. resolver := NewEnvironmentVariableResolver(env)
  823. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  824. require.NoError(t, err)
  825. _, _, err = cfg.defaultModelSelection(knownProviders)
  826. require.Error(t, err)
  827. })
  828. t.Run("should error if model is missing", func(t *testing.T) {
  829. knownProviders := []catwalk.Provider{
  830. {
  831. ID: "openai",
  832. APIKey: "abc",
  833. DefaultLargeModelID: "large-model",
  834. DefaultSmallModelID: "small-model",
  835. Models: []catwalk.Model{
  836. {
  837. ID: "not-large-model",
  838. DefaultMaxTokens: 1000,
  839. },
  840. {
  841. ID: "small-model",
  842. DefaultMaxTokens: 500,
  843. },
  844. },
  845. },
  846. }
  847. cfg := &Config{}
  848. cfg.setDefaults("/tmp", "")
  849. env := env.NewFromMap(map[string]string{})
  850. resolver := NewEnvironmentVariableResolver(env)
  851. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  852. require.NoError(t, err)
  853. _, _, err = cfg.defaultModelSelection(knownProviders)
  854. require.Error(t, err)
  855. })
  856. t.Run("should configure the default models with a custom provider", func(t *testing.T) {
  857. knownProviders := []catwalk.Provider{
  858. {
  859. ID: "openai",
  860. APIKey: "$MISSING", // will not be included in the config
  861. DefaultLargeModelID: "large-model",
  862. DefaultSmallModelID: "small-model",
  863. Models: []catwalk.Model{
  864. {
  865. ID: "not-large-model",
  866. DefaultMaxTokens: 1000,
  867. },
  868. {
  869. ID: "small-model",
  870. DefaultMaxTokens: 500,
  871. },
  872. },
  873. },
  874. }
  875. cfg := &Config{
  876. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  877. "custom": {
  878. APIKey: "test-key",
  879. BaseURL: "https://api.custom.com/v1",
  880. Models: []catwalk.Model{
  881. {
  882. ID: "model",
  883. DefaultMaxTokens: 600,
  884. },
  885. },
  886. },
  887. }),
  888. }
  889. cfg.setDefaults("/tmp", "")
  890. env := env.NewFromMap(map[string]string{})
  891. resolver := NewEnvironmentVariableResolver(env)
  892. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  893. require.NoError(t, err)
  894. large, small, err := cfg.defaultModelSelection(knownProviders)
  895. require.NoError(t, err)
  896. require.Equal(t, "model", large.Model)
  897. require.Equal(t, "custom", large.Provider)
  898. require.Equal(t, int64(600), large.MaxTokens)
  899. require.Equal(t, "model", small.Model)
  900. require.Equal(t, "custom", small.Provider)
  901. require.Equal(t, int64(600), small.MaxTokens)
  902. })
  903. t.Run("should fail if no model configured", func(t *testing.T) {
  904. knownProviders := []catwalk.Provider{
  905. {
  906. ID: "openai",
  907. APIKey: "$MISSING", // will not be included in the config
  908. DefaultLargeModelID: "large-model",
  909. DefaultSmallModelID: "small-model",
  910. Models: []catwalk.Model{
  911. {
  912. ID: "not-large-model",
  913. DefaultMaxTokens: 1000,
  914. },
  915. {
  916. ID: "small-model",
  917. DefaultMaxTokens: 500,
  918. },
  919. },
  920. },
  921. }
  922. cfg := &Config{
  923. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  924. "custom": {
  925. APIKey: "test-key",
  926. BaseURL: "https://api.custom.com/v1",
  927. Models: []catwalk.Model{},
  928. },
  929. }),
  930. }
  931. cfg.setDefaults("/tmp", "")
  932. env := env.NewFromMap(map[string]string{})
  933. resolver := NewEnvironmentVariableResolver(env)
  934. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  935. require.NoError(t, err)
  936. _, _, err = cfg.defaultModelSelection(knownProviders)
  937. require.Error(t, err)
  938. })
  939. t.Run("should use the default provider first", func(t *testing.T) {
  940. knownProviders := []catwalk.Provider{
  941. {
  942. ID: "openai",
  943. APIKey: "set",
  944. DefaultLargeModelID: "large-model",
  945. DefaultSmallModelID: "small-model",
  946. Models: []catwalk.Model{
  947. {
  948. ID: "large-model",
  949. DefaultMaxTokens: 1000,
  950. },
  951. {
  952. ID: "small-model",
  953. DefaultMaxTokens: 500,
  954. },
  955. },
  956. },
  957. }
  958. cfg := &Config{
  959. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  960. "custom": {
  961. APIKey: "test-key",
  962. BaseURL: "https://api.custom.com/v1",
  963. Models: []catwalk.Model{
  964. {
  965. ID: "large-model",
  966. DefaultMaxTokens: 1000,
  967. },
  968. },
  969. },
  970. }),
  971. }
  972. cfg.setDefaults("/tmp", "")
  973. env := env.NewFromMap(map[string]string{})
  974. resolver := NewEnvironmentVariableResolver(env)
  975. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  976. require.NoError(t, err)
  977. large, small, err := cfg.defaultModelSelection(knownProviders)
  978. require.NoError(t, err)
  979. require.Equal(t, "large-model", large.Model)
  980. require.Equal(t, "openai", large.Provider)
  981. require.Equal(t, int64(1000), large.MaxTokens)
  982. require.Equal(t, "small-model", small.Model)
  983. require.Equal(t, "openai", small.Provider)
  984. require.Equal(t, int64(500), small.MaxTokens)
  985. })
  986. }
  987. func TestConfig_configureProvidersDisableDefaultProviders(t *testing.T) {
  988. t.Run("when enabled, ignores all default providers and requires full specification", func(t *testing.T) {
  989. knownProviders := []catwalk.Provider{
  990. {
  991. ID: "openai",
  992. APIKey: "$OPENAI_API_KEY",
  993. APIEndpoint: "https://api.openai.com/v1",
  994. Models: []catwalk.Model{{
  995. ID: "gpt-4",
  996. }},
  997. },
  998. }
  999. // User references openai but doesn't fully specify it (no base_url, no
  1000. // models). This should be rejected because disable_default_providers
  1001. // treats all providers as custom.
  1002. cfg := &Config{
  1003. Options: &Options{
  1004. DisableDefaultProviders: true,
  1005. },
  1006. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  1007. "openai": {
  1008. APIKey: "$OPENAI_API_KEY",
  1009. },
  1010. }),
  1011. }
  1012. cfg.setDefaults("/tmp", "")
  1013. env := env.NewFromMap(map[string]string{
  1014. "OPENAI_API_KEY": "test-key",
  1015. })
  1016. resolver := NewEnvironmentVariableResolver(env)
  1017. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  1018. require.ErrorContains(t, err, "no custom providers")
  1019. // openai should NOT be present because it lacks base_url and models.
  1020. require.Equal(t, 0, cfg.Providers.Len())
  1021. _, exists := cfg.Providers.Get("openai")
  1022. require.False(t, exists, "openai should not be present without full specification")
  1023. })
  1024. t.Run("when enabled, fully specified providers work", func(t *testing.T) {
  1025. knownProviders := []catwalk.Provider{
  1026. {
  1027. ID: "openai",
  1028. APIKey: "$OPENAI_API_KEY",
  1029. APIEndpoint: "https://api.openai.com/v1",
  1030. Models: []catwalk.Model{{
  1031. ID: "gpt-4",
  1032. }},
  1033. },
  1034. }
  1035. // User fully specifies their provider.
  1036. cfg := &Config{
  1037. Options: &Options{
  1038. DisableDefaultProviders: true,
  1039. },
  1040. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  1041. "my-llm": {
  1042. APIKey: "$MY_API_KEY",
  1043. BaseURL: "https://my-llm.example.com/v1",
  1044. Models: []catwalk.Model{{
  1045. ID: "my-model",
  1046. }},
  1047. },
  1048. }),
  1049. }
  1050. cfg.setDefaults("/tmp", "")
  1051. env := env.NewFromMap(map[string]string{
  1052. "MY_API_KEY": "test-key",
  1053. "OPENAI_API_KEY": "test-key",
  1054. })
  1055. resolver := NewEnvironmentVariableResolver(env)
  1056. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  1057. require.NoError(t, err)
  1058. // Only fully specified provider should be present.
  1059. require.Equal(t, 1, cfg.Providers.Len())
  1060. provider, exists := cfg.Providers.Get("my-llm")
  1061. require.True(t, exists, "my-llm should be present")
  1062. require.Equal(t, "https://my-llm.example.com/v1", provider.BaseURL)
  1063. require.Len(t, provider.Models, 1)
  1064. // Default openai should NOT be present.
  1065. _, exists = cfg.Providers.Get("openai")
  1066. require.False(t, exists, "openai should not be present")
  1067. })
  1068. t.Run("when disabled, includes all known providers with valid credentials", func(t *testing.T) {
  1069. knownProviders := []catwalk.Provider{
  1070. {
  1071. ID: "openai",
  1072. APIKey: "$OPENAI_API_KEY",
  1073. APIEndpoint: "https://api.openai.com/v1",
  1074. Models: []catwalk.Model{{
  1075. ID: "gpt-4",
  1076. }},
  1077. },
  1078. {
  1079. ID: "anthropic",
  1080. APIKey: "$ANTHROPIC_API_KEY",
  1081. APIEndpoint: "https://api.anthropic.com/v1",
  1082. Models: []catwalk.Model{{
  1083. ID: "claude-3",
  1084. }},
  1085. },
  1086. }
  1087. // User only configures openai, both API keys are available, but option
  1088. // is disabled.
  1089. cfg := &Config{
  1090. Options: &Options{
  1091. DisableDefaultProviders: false,
  1092. },
  1093. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  1094. "openai": {
  1095. APIKey: "$OPENAI_API_KEY",
  1096. },
  1097. }),
  1098. }
  1099. cfg.setDefaults("/tmp", "")
  1100. env := env.NewFromMap(map[string]string{
  1101. "OPENAI_API_KEY": "test-key",
  1102. "ANTHROPIC_API_KEY": "test-key",
  1103. })
  1104. resolver := NewEnvironmentVariableResolver(env)
  1105. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  1106. require.NoError(t, err)
  1107. // Both providers should be present.
  1108. require.Equal(t, 2, cfg.Providers.Len())
  1109. _, exists := cfg.Providers.Get("openai")
  1110. require.True(t, exists, "openai should be present")
  1111. _, exists = cfg.Providers.Get("anthropic")
  1112. require.True(t, exists, "anthropic should be present")
  1113. })
  1114. t.Run("when enabled, provider missing models is rejected", func(t *testing.T) {
  1115. cfg := &Config{
  1116. Options: &Options{
  1117. DisableDefaultProviders: true,
  1118. },
  1119. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  1120. "my-llm": {
  1121. APIKey: "test-key",
  1122. BaseURL: "https://my-llm.example.com/v1",
  1123. Models: []catwalk.Model{}, // No models.
  1124. },
  1125. }),
  1126. }
  1127. cfg.setDefaults("/tmp", "")
  1128. env := env.NewFromMap(map[string]string{})
  1129. resolver := NewEnvironmentVariableResolver(env)
  1130. err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
  1131. require.ErrorContains(t, err, "no custom providers")
  1132. // Provider should be rejected for missing models.
  1133. require.Equal(t, 0, cfg.Providers.Len())
  1134. })
  1135. t.Run("when enabled, provider missing base_url is rejected", func(t *testing.T) {
  1136. cfg := &Config{
  1137. Options: &Options{
  1138. DisableDefaultProviders: true,
  1139. },
  1140. Providers: csync.NewMapFrom(map[string]ProviderConfig{
  1141. "my-llm": {
  1142. APIKey: "test-key",
  1143. Models: []catwalk.Model{{ID: "model"}},
  1144. // No BaseURL.
  1145. },
  1146. }),
  1147. }
  1148. cfg.setDefaults("/tmp", "")
  1149. env := env.NewFromMap(map[string]string{})
  1150. resolver := NewEnvironmentVariableResolver(env)
  1151. err := cfg.configureProviders(testStore(cfg), env, resolver, []catwalk.Provider{})
  1152. require.ErrorContains(t, err, "no custom providers")
  1153. // Provider should be rejected for missing base_url.
  1154. require.Equal(t, 0, cfg.Providers.Len())
  1155. })
  1156. }
  1157. func TestConfig_setDefaultsDisableDefaultProvidersEnvVar(t *testing.T) {
  1158. t.Run("sets option from environment variable", func(t *testing.T) {
  1159. t.Setenv("CRUSH_DISABLE_DEFAULT_PROVIDERS", "true")
  1160. cfg := &Config{}
  1161. cfg.setDefaults("/tmp", "")
  1162. require.True(t, cfg.Options.DisableDefaultProviders)
  1163. })
  1164. t.Run("does not override when env var is not set", func(t *testing.T) {
  1165. cfg := &Config{
  1166. Options: &Options{
  1167. DisableDefaultProviders: true,
  1168. },
  1169. }
  1170. cfg.setDefaults("/tmp", "")
  1171. require.True(t, cfg.Options.DisableDefaultProviders)
  1172. })
  1173. }
  1174. func TestConfig_configureSelectedModels(t *testing.T) {
  1175. t.Run("should override defaults", func(t *testing.T) {
  1176. knownProviders := []catwalk.Provider{
  1177. {
  1178. ID: "openai",
  1179. APIKey: "abc",
  1180. DefaultLargeModelID: "large-model",
  1181. DefaultSmallModelID: "small-model",
  1182. Models: []catwalk.Model{
  1183. {
  1184. ID: "larger-model",
  1185. DefaultMaxTokens: 2000,
  1186. },
  1187. {
  1188. ID: "large-model",
  1189. DefaultMaxTokens: 1000,
  1190. },
  1191. {
  1192. ID: "small-model",
  1193. DefaultMaxTokens: 500,
  1194. },
  1195. },
  1196. },
  1197. }
  1198. cfg := &Config{
  1199. Models: map[SelectedModelType]SelectedModel{
  1200. "large": {
  1201. Model: "larger-model",
  1202. },
  1203. },
  1204. }
  1205. cfg.setDefaults("/tmp", "")
  1206. env := env.NewFromMap(map[string]string{})
  1207. resolver := NewEnvironmentVariableResolver(env)
  1208. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  1209. require.NoError(t, err)
  1210. err = configureSelectedModels(testStore(cfg), knownProviders)
  1211. require.NoError(t, err)
  1212. large := cfg.Models[SelectedModelTypeLarge]
  1213. small := cfg.Models[SelectedModelTypeSmall]
  1214. require.Equal(t, "larger-model", large.Model)
  1215. require.Equal(t, "openai", large.Provider)
  1216. require.Equal(t, int64(2000), large.MaxTokens)
  1217. require.Equal(t, "small-model", small.Model)
  1218. require.Equal(t, "openai", small.Provider)
  1219. require.Equal(t, int64(500), small.MaxTokens)
  1220. })
  1221. t.Run("should be possible to use multiple providers", func(t *testing.T) {
  1222. knownProviders := []catwalk.Provider{
  1223. {
  1224. ID: "openai",
  1225. APIKey: "abc",
  1226. DefaultLargeModelID: "large-model",
  1227. DefaultSmallModelID: "small-model",
  1228. Models: []catwalk.Model{
  1229. {
  1230. ID: "large-model",
  1231. DefaultMaxTokens: 1000,
  1232. },
  1233. {
  1234. ID: "small-model",
  1235. DefaultMaxTokens: 500,
  1236. },
  1237. },
  1238. },
  1239. {
  1240. ID: "anthropic",
  1241. APIKey: "abc",
  1242. DefaultLargeModelID: "a-large-model",
  1243. DefaultSmallModelID: "a-small-model",
  1244. Models: []catwalk.Model{
  1245. {
  1246. ID: "a-large-model",
  1247. DefaultMaxTokens: 1000,
  1248. },
  1249. {
  1250. ID: "a-small-model",
  1251. DefaultMaxTokens: 200,
  1252. },
  1253. },
  1254. },
  1255. }
  1256. cfg := &Config{
  1257. Models: map[SelectedModelType]SelectedModel{
  1258. "small": {
  1259. Model: "a-small-model",
  1260. Provider: "anthropic",
  1261. MaxTokens: 300,
  1262. },
  1263. },
  1264. }
  1265. cfg.setDefaults("/tmp", "")
  1266. env := env.NewFromMap(map[string]string{})
  1267. resolver := NewEnvironmentVariableResolver(env)
  1268. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  1269. require.NoError(t, err)
  1270. err = configureSelectedModels(testStore(cfg), knownProviders)
  1271. require.NoError(t, err)
  1272. large := cfg.Models[SelectedModelTypeLarge]
  1273. small := cfg.Models[SelectedModelTypeSmall]
  1274. require.Equal(t, "large-model", large.Model)
  1275. require.Equal(t, "openai", large.Provider)
  1276. require.Equal(t, int64(1000), large.MaxTokens)
  1277. require.Equal(t, "a-small-model", small.Model)
  1278. require.Equal(t, "anthropic", small.Provider)
  1279. require.Equal(t, int64(300), small.MaxTokens)
  1280. })
  1281. t.Run("should override the max tokens only", func(t *testing.T) {
  1282. knownProviders := []catwalk.Provider{
  1283. {
  1284. ID: "openai",
  1285. APIKey: "abc",
  1286. DefaultLargeModelID: "large-model",
  1287. DefaultSmallModelID: "small-model",
  1288. Models: []catwalk.Model{
  1289. {
  1290. ID: "large-model",
  1291. DefaultMaxTokens: 1000,
  1292. },
  1293. {
  1294. ID: "small-model",
  1295. DefaultMaxTokens: 500,
  1296. },
  1297. },
  1298. },
  1299. }
  1300. cfg := &Config{
  1301. Models: map[SelectedModelType]SelectedModel{
  1302. "large": {
  1303. MaxTokens: 100,
  1304. },
  1305. },
  1306. }
  1307. cfg.setDefaults("/tmp", "")
  1308. env := env.NewFromMap(map[string]string{})
  1309. resolver := NewEnvironmentVariableResolver(env)
  1310. err := cfg.configureProviders(testStore(cfg), env, resolver, knownProviders)
  1311. require.NoError(t, err)
  1312. err = configureSelectedModels(testStore(cfg), knownProviders)
  1313. require.NoError(t, err)
  1314. large := cfg.Models[SelectedModelTypeLarge]
  1315. require.Equal(t, "large-model", large.Model)
  1316. require.Equal(t, "openai", large.Provider)
  1317. require.Equal(t, int64(100), large.MaxTokens)
  1318. })
  1319. }