package config

import (
	"io"
	"log/slog"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/env"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestMain routes the default slog logger to io.Discard for the whole
// package run, then exits with the status reported by the test runner.
func TestMain(m *testing.M) {
	slog.SetDefault(slog.New(slog.NewTextHandler(io.Discard, nil)))
	os.Exit(m.Run())
}

// configureProvidersWithEnv applies the defaults for the "/tmp" working
// directory to cfg and then runs cfg.configureProviders against an
// environment built from vars, using a resolver backed by that same
// environment. It returns the error from configureProviders unchanged so each
// test can decide whether a failure is expected.
func configureProvidersWithEnv(t *testing.T, cfg *Config, vars map[string]string, known []catwalk.Provider) error {
	t.Helper()
	cfg.setDefaults("/tmp", "")
	testEnv := env.NewFromMap(vars)
	return cfg.configureProviders(testEnv, NewEnvironmentVariableResolver(testEnv), known)
}

// TestConfig_LoadFromReaders checks that several readers are merged into a
// single provider entry: the values from the second reader are the ones that
// end up in the result, and the empty provider object in the third reader
// does not clear them.
func TestConfig_LoadFromReaders(t *testing.T) {
	data1 := strings.NewReader(`{"providers": {"openai": {"api_key": "key1", "base_url": "https://api.openai.com/v1"}}}`)
	data2 := strings.NewReader(`{"providers": {"openai": {"api_key": "key2", "base_url": "https://api.openai.com/v2"}}}`)
	data3 := strings.NewReader(`{"providers": {"openai": {}}}`)

	loadedConfig, err := loadFromReaders([]io.Reader{data1, data2, data3})

	require.NoError(t, err)
	require.NotNil(t, loadedConfig)
	require.Equal(t, 1, loadedConfig.Providers.Len())
	pc, ok := loadedConfig.Providers.Get("openai")
	require.True(t, ok, "openai provider should be present")
	require.Equal(t, "key2", pc.APIKey)
	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
}

// TestConfig_setDefaults checks that setDefaults on an empty Config
// initializes every nested section, derives the data directory from the
// working directory, and includes all default context paths.
func TestConfig_setDefaults(t *testing.T) {
	cfg := &Config{}

	cfg.setDefaults("/tmp", "")

	require.NotNil(t, cfg.Options)
	require.NotNil(t, cfg.Options.TUI)
	require.NotNil(t, cfg.Options.ContextPaths)
	require.NotNil(t, cfg.Providers)
	require.NotNil(t, cfg.Models)
	require.NotNil(t, cfg.LSP)
	require.NotNil(t, cfg.MCP)
	require.Equal(t, filepath.Join("/tmp", ".crush"), cfg.Options.DataDirectory)
	for _, path := range defaultContextPaths {
		require.Contains(t, cfg.Options.ContextPaths, path)
	}
	require.Equal(t, "/tmp", cfg.workingDir)
}

// TestConfig_configureProviders checks that a known provider whose API key
// variable is set in the environment is added to the config.
func TestConfig_configureProviders(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{}
	err := configureProvidersWithEnv(t, cfg, map[string]string{
		"OPENAI_API_KEY": "test-key",
	}, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	// We want to make sure that we keep the configured API key as a placeholder
	pc, ok := cfg.Providers.Get("openai")
	require.True(t, ok, "openai provider should be present")
	require.Equal(t, "$OPENAI_API_KEY", pc.APIKey)
}

// TestConfig_configureProvidersWithOverride checks that values the user
// configured for a known provider (API key, base URL, models) are the ones
// present after configureProviders runs.
func TestConfig_configureProvidersWithOverride(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMap[string, ProviderConfig](),
	}
	cfg.Providers.Set("openai", ProviderConfig{
		APIKey:  "xyz",
		BaseURL: "https://api.openai.com/v2",
		Models: []catwalk.Model{
			{
				ID:   "test-model",
				Name: "Updated",
			},
			{
				ID: "another-model",
			},
		},
	})

	err := configureProvidersWithEnv(t, cfg, map[string]string{
		"OPENAI_API_KEY": "test-key",
	}, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	// The user-configured values must be kept instead of the known
	// provider's defaults.
	pc, ok := cfg.Providers.Get("openai")
	require.True(t, ok, "openai provider should be present")
	require.Equal(t, "xyz", pc.APIKey)
	require.Equal(t, "https://api.openai.com/v2", pc.BaseURL)
	require.Len(t, pc.Models, 2)
	require.Equal(t, "Updated", pc.Models[0].Name)
}

// TestConfig_configureProvidersWithNewProvider checks that a user-defined
// provider is kept next to a known provider that is enabled through the
// environment, and that its ID is filled in from its map key.
func TestConfig_configureProvidersWithNewProvider(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMapFrom(map[string]ProviderConfig{
			"custom": {
				APIKey:  "xyz",
				BaseURL: "https://api.someendpoint.com/v2",
				Models: []catwalk.Model{
					{
						ID: "test-model",
					},
				},
			},
		}),
	}

	err := configureProvidersWithEnv(t, cfg, map[string]string{
		"OPENAI_API_KEY": "test-key",
	}, knownProviders)
	require.NoError(t, err)
	// Should be two: the custom provider plus openai, which is configured
	// because of the env variable.
	require.Equal(t, 2, cfg.Providers.Len())

	// The custom provider must keep its configured values.
	pc, ok := cfg.Providers.Get("custom")
	require.True(t, ok, "custom provider should be present")
	require.Equal(t, "xyz", pc.APIKey)
	// Make sure we set the ID correctly
	require.Equal(t, "custom", pc.ID)
	require.Equal(t, "https://api.someendpoint.com/v2", pc.BaseURL)
	require.Len(t, pc.Models, 1)

	_, ok = cfg.Providers.Get("openai")
	require.True(t, ok, "OpenAI provider should still be present")
}

// TestConfig_configureProvidersBedrockWithCredentials checks that the Bedrock
// provider is configured when AWS credentials are present in the environment.
func TestConfig_configureProvidersBedrockWithCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
			}},
		},
	}

	cfg := &Config{}
	err := configureProvidersWithEnv(t, cfg, map[string]string{
		"AWS_ACCESS_KEY_ID":     "test-key-id",
		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
	}, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	bedrockProvider, ok := cfg.Providers.Get("bedrock")
	require.True(t, ok, "Bedrock provider should be present")
	require.Len(t, bedrockProvider.Models, 1)
	require.Equal(t, "anthropic.claude-sonnet-4-20250514-v1:0", bedrockProvider.Models[0].ID)
}

// TestConfig_configureProvidersBedrockWithoutCredentials checks that the
// Bedrock provider is skipped when the environment is empty.
func TestConfig_configureProvidersBedrockWithoutCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "anthropic.claude-sonnet-4-20250514-v1:0",
			}},
		},
	}

	cfg := &Config{}
	err := configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders)
	require.NoError(t, err)

	// Provider should not be configured without credentials
	require.Equal(t, 0, cfg.Providers.Len())
}

// TestConfig_configureProvidersBedrockWithoutUnsupportedModel checks that
// configureProviders returns an error for a Bedrock provider whose only model
// is "some-random-model", even though AWS credentials are present.
func TestConfig_configureProvidersBedrockWithoutUnsupportedModel(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderBedrock,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "some-random-model",
			}},
		},
	}

	cfg := &Config{}
	err := configureProvidersWithEnv(t, cfg, map[string]string{
		"AWS_ACCESS_KEY_ID":     "test-key-id",
		"AWS_SECRET_ACCESS_KEY": "test-secret-key",
	}, knownProviders)
	require.Error(t, err)
}

// TestConfig_configureProvidersVertexAIWithCredentials checks that the
// VertexAI provider is configured from VERTEXAI_PROJECT and VERTEXAI_LOCATION
// and that both values are exposed through ExtraParams.
func TestConfig_configureProvidersVertexAIWithCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	err := configureProvidersWithEnv(t, cfg, map[string]string{
		"VERTEXAI_PROJECT":  "test-project",
		"VERTEXAI_LOCATION": "us-central1",
	}, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	vertexProvider, ok := cfg.Providers.Get("vertexai")
	require.True(t, ok, "VertexAI provider should be present")
	require.Len(t, vertexProvider.Models, 1)
	require.Equal(t, "gemini-pro", vertexProvider.Models[0].ID)
	require.Equal(t, "test-project", vertexProvider.ExtraParams["project"])
	require.Equal(t, "us-central1", vertexProvider.ExtraParams["location"])
}

// TestConfig_configureProvidersVertexAIWithoutCredentials checks that the
// VertexAI provider is skipped when only the GOOGLE_* variables are set and
// GOOGLE_GENAI_USE_VERTEXAI is "false".
func TestConfig_configureProvidersVertexAIWithoutCredentials(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	err := configureProvidersWithEnv(t, cfg, map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "false",
		"GOOGLE_CLOUD_PROJECT":      "test-project",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	}, knownProviders)
	require.NoError(t, err)

	// Provider should not be configured without proper credentials
	require.Equal(t, 0, cfg.Providers.Len())
}

// TestConfig_configureProvidersVertexAIMissingProject checks that the
// VertexAI provider is skipped when no project variable is set.
func TestConfig_configureProvidersVertexAIMissingProject(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          catwalk.InferenceProviderVertexAI,
			APIKey:      "",
			APIEndpoint: "",
			Models: []catwalk.Model{{
				ID: "gemini-pro",
			}},
		},
	}

	cfg := &Config{}
	err := configureProvidersWithEnv(t, cfg, map[string]string{
		"GOOGLE_GENAI_USE_VERTEXAI": "true",
		"GOOGLE_CLOUD_LOCATION":     "us-central1",
	}, knownProviders)
	require.NoError(t, err)

	// Provider should not be configured without project
	require.Equal(t, 0, cfg.Providers.Len())
}

// TestConfig_configureProvidersSetProviderID checks that a provider added
// from the known-provider list carries its ID in the stored config.
func TestConfig_configureProvidersSetProviderID(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{}
	err := configureProvidersWithEnv(t, cfg, map[string]string{
		"OPENAI_API_KEY": "test-key",
	}, knownProviders)
	require.NoError(t, err)
	require.Equal(t, 1, cfg.Providers.Len())

	// Provider ID should be set
	pc, ok := cfg.Providers.Get("openai")
	require.True(t, ok, "openai provider should be present")
	require.Equal(t, "openai", pc.ID)
}

// TestConfig_EnabledProviders checks that EnabledProviders returns only the
// providers whose Disable flag is false.
func TestConfig_EnabledProviders(t *testing.T) {
	t.Run("all providers enabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: false,
				},
			}),
		}

		enabled := cfg.EnabledProviders()
		require.Len(t, enabled, 2)
	})

	t.Run("some providers disabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: true,
				},
			}),
		}

		enabled := cfg.EnabledProviders()
		require.Len(t, enabled, 1)
		require.Equal(t, "openai", enabled[0].ID)
	})

	t.Run("empty providers map", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMap[string, ProviderConfig](),
		}

		enabled := cfg.EnabledProviders()
		require.Empty(t, enabled)
	})
}

// TestConfig_IsConfigured checks that IsConfigured reports true only when at
// least one provider is present and not disabled.
func TestConfig_IsConfigured(t *testing.T) {
	t.Run("returns true when at least one provider is enabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: false,
				},
			}),
		}

		require.True(t, cfg.IsConfigured())
	})

	t.Run("returns false when no providers are configured", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMap[string, ProviderConfig](),
		}

		require.False(t, cfg.IsConfigured())
	})

	t.Run("returns false when all providers are disabled", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					ID:      "openai",
					APIKey:  "key1",
					Disable: true,
				},
				"anthropic": {
					ID:      "anthropic",
					APIKey:  "key2",
					Disable: true,
				},
			}),
		}

		require.False(t, cfg.IsConfigured())
	})
}

// TestConfig_setupAgentsWithNoDisabledTools checks the allowed tools of the
// coder and task agents when no tool is disabled.
func TestConfig_setupAgentsWithNoDisabledTools(t *testing.T) {
	cfg := &Config{
		Options: &Options{
			DisabledTools: []string{},
		},
	}

	cfg.SetupAgents()

	coderAgent, ok := cfg.Agents[AgentCoder]
	require.True(t, ok)
	assert.Equal(t, allToolNames(), coderAgent.AllowedTools)

	taskAgent, ok := cfg.Agents[AgentTask]
	require.True(t, ok)
	assert.Equal(t, []string{"glob", "grep", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
}

// TestConfig_setupAgentsWithDisabledTools checks that disabled tools are
// removed from the allowed tools of both the coder and the task agent.
func TestConfig_setupAgentsWithDisabledTools(t *testing.T) {
	cfg := &Config{
		Options: &Options{
			DisabledTools: []string{
				"edit",
				"download",
				"grep",
			},
		},
	}

	cfg.SetupAgents()

	coderAgent, ok := cfg.Agents[AgentCoder]
	require.True(t, ok)
	assert.Equal(t, []string{"agent", "bash", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "view", "write"}, coderAgent.AllowedTools)

	taskAgent, ok := cfg.Agents[AgentTask]
	require.True(t, ok)
	assert.Equal(t, []string{"glob", "ls", "sourcegraph", "view"}, taskAgent.AllowedTools)
}

// TestConfig_setupAgentsWithEveryReadOnlyToolDisabled checks that the task
// agent ends up with an empty (non-nil) tool list when every tool it would
// otherwise be allowed is disabled.
func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) {
	cfg := &Config{
		Options: &Options{
			DisabledTools: []string{
				"glob",
				"grep",
				"ls",
				"sourcegraph",
				"view",
			},
		},
	}

	cfg.SetupAgents()

	coderAgent, ok := cfg.Agents[AgentCoder]
	require.True(t, ok)
	assert.Equal(t, []string{"agent", "bash", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "write"}, coderAgent.AllowedTools)

	taskAgent, ok := cfg.Agents[AgentTask]
	require.True(t, ok)
	// Deliberately compared against an empty, non-nil slice.
	assert.Equal(t, []string{}, taskAgent.AllowedTools)
}

// TestConfig_configureProvidersWithDisabledProvider checks that a known
// provider the user disabled stays in the config and stays disabled, even
// though its API key is available in the environment.
func TestConfig_configureProvidersWithDisabledProvider(t *testing.T) {
	knownProviders := []catwalk.Provider{
		{
			ID:          "openai",
			APIKey:      "$OPENAI_API_KEY",
			APIEndpoint: "https://api.openai.com/v1",
			Models: []catwalk.Model{{
				ID: "test-model",
			}},
		},
	}

	cfg := &Config{
		Providers: csync.NewMapFrom(map[string]ProviderConfig{
			"openai": {
				Disable: true,
			},
		}),
	}

	err := configureProvidersWithEnv(t, cfg, map[string]string{
		"OPENAI_API_KEY": "test-key",
	}, knownProviders)
	require.NoError(t, err)

	require.Equal(t, 1, cfg.Providers.Len())
	prov, exists := cfg.Providers.Get("openai")
	require.True(t, exists)
	require.True(t, prov.Disable)
}

// TestConfig_configureProvidersCustomProviderValidation checks which
// user-defined providers survive configureProviders when no known providers
// are supplied.
func TestConfig_configureProvidersCustomProviderValidation(t *testing.T) {
	t.Run("custom provider with missing API key is allowed, but not known providers", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					BaseURL: "https://api.custom.com/v1",
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
				"openai": {
					APIKey: "$MISSING",
				},
			}),
		}

		err := configureProvidersWithEnv(t, cfg, map[string]string{}, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.True(t, exists)
	})

	t.Run("custom provider with missing BaseURL is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey: "test-key",
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}

		err := configureProvidersWithEnv(t, cfg, map[string]string{}, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})

	t.Run("custom provider with no models is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models:  []catwalk.Model{},
				},
			}),
		}

		err := configureProvidersWithEnv(t, cfg, map[string]string{}, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})

	t.Run("custom provider with unsupported type is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    "unsupported",
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}

		err := configureProvidersWithEnv(t, cfg, map[string]string{}, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})

	t.Run("valid custom provider is kept and ID is set", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    catwalk.TypeOpenAI,
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}

		err := configureProvidersWithEnv(t, cfg, map[string]string{}, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		customProvider, exists := cfg.Providers.Get("custom")
		require.True(t, exists)
		require.Equal(t, "custom", customProvider.ID)
		require.Equal(t, "test-key", customProvider.APIKey)
		require.Equal(t, "https://api.custom.com/v1", customProvider.BaseURL)
	})

	t.Run("custom anthropic provider is supported", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom-anthropic": {
					APIKey:  "test-key",
					BaseURL: "https://api.anthropic.com/v1",
					Type:    catwalk.TypeAnthropic,
					Models: []catwalk.Model{{
						ID: "claude-3-sonnet",
					}},
				},
			}),
		}

		err := configureProvidersWithEnv(t, cfg, map[string]string{}, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		customProvider, exists := cfg.Providers.Get("custom-anthropic")
		require.True(t, exists)
		require.Equal(t, "custom-anthropic", customProvider.ID)
		require.Equal(t, "test-key", customProvider.APIKey)
		require.Equal(t, "https://api.anthropic.com/v1", customProvider.BaseURL)
		require.Equal(t, catwalk.TypeAnthropic, customProvider.Type)
	})

	t.Run("disabled custom provider is removed", func(t *testing.T) {
		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Type:    catwalk.TypeOpenAI,
					Disable: true,
					Models: []catwalk.Model{{
						ID: "test-model",
					}},
				},
			}),
		}

		err := configureProvidersWithEnv(t, cfg, map[string]string{}, []catwalk.Provider{})
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("custom")
		require.False(t, exists)
	})
}

// TestConfig_configureProvidersEnhancedCredentialValidation checks that known
// providers with an existing user config entry are dropped when their
// credentials cannot be resolved, and kept when only the endpoint is missing.
func TestConfig_configureProvidersEnhancedCredentialValidation(t *testing.T) {
	t.Run("VertexAI provider removed when credentials missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          catwalk.InferenceProviderVertexAI,
				APIKey:      "",
				APIEndpoint: "",
				Models: []catwalk.Model{{
					ID: "gemini-pro",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"vertexai": {
					BaseURL: "custom-url",
				},
			}),
		}

		err := configureProvidersWithEnv(t, cfg, map[string]string{
			"GOOGLE_GENAI_USE_VERTEXAI": "false",
		}, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("vertexai")
		require.False(t, exists)
	})

	t.Run("Bedrock provider removed when AWS credentials missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          catwalk.InferenceProviderBedrock,
				APIKey:      "",
				APIEndpoint: "",
				Models: []catwalk.Model{{
					ID: "anthropic.claude-sonnet-4-20250514-v1:0",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"bedrock": {
					BaseURL: "custom-url",
				},
			}),
		}

		err := configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("bedrock")
		require.False(t, exists)
	})

	t.Run("provider removed when API key missing with existing config", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          "openai",
				APIKey:      "$MISSING_API_KEY",
				APIEndpoint: "https://api.openai.com/v1",
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					BaseURL: "custom-url",
				},
			}),
		}

		err := configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 0, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("openai")
		require.False(t, exists)
	})

	t.Run("known provider should still be added if the endpoint is missing the client will use default endpoints", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:          "openai",
				APIKey:      "$OPENAI_API_KEY",
				APIEndpoint: "$MISSING_ENDPOINT",
				Models: []catwalk.Model{{
					ID: "test-model",
				}},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"openai": {
					APIKey: "test-key",
				},
			}),
		}

		err := configureProvidersWithEnv(t, cfg, map[string]string{
			"OPENAI_API_KEY": "test-key",
		}, knownProviders)
		require.NoError(t, err)

		require.Equal(t, 1, cfg.Providers.Len())
		_, exists := cfg.Providers.Get("openai")
		require.True(t, exists)
	})
}

// TestConfig_defaultModelSelection checks how defaultModelSelection picks the
// large and small models from the configured providers, and when it fails.
func TestConfig_defaultModelSelection(t *testing.T) {
	t.Run("default behavior uses the default models for given provider", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		require.NoError(t, configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders))

		large, small, err := cfg.defaultModelSelection(knownProviders)
		require.NoError(t, err)
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(1000), large.MaxTokens)
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})

	t.Run("should error if no providers configured", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING_KEY",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		require.NoError(t, configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders))

		_, _, err := cfg.defaultModelSelection(knownProviders)
		require.Error(t, err)
	})

	t.Run("should error if model is missing", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{}
		require.NoError(t, configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders))

		_, _, err := cfg.defaultModelSelection(knownProviders)
		require.Error(t, err)
	})

	t.Run("should configure the default models with a custom provider", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING", // will not be included in the config
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models: []catwalk.Model{
						{
							ID:               "model",
							DefaultMaxTokens: 600,
						},
					},
				},
			}),
		}
		require.NoError(t, configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders))

		large, small, err := cfg.defaultModelSelection(knownProviders)
		require.NoError(t, err)
		require.Equal(t, "model", large.Model)
		require.Equal(t, "custom", large.Provider)
		require.Equal(t, int64(600), large.MaxTokens)
		require.Equal(t, "model", small.Model)
		require.Equal(t, "custom", small.Provider)
		require.Equal(t, int64(600), small.MaxTokens)
	})

	t.Run("should fail if no model configured", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "$MISSING", // will not be included in the config
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "not-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models:  []catwalk.Model{},
				},
			}),
		}
		require.NoError(t, configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders))

		_, _, err := cfg.defaultModelSelection(knownProviders)
		require.Error(t, err)
	})

	t.Run("should use the default provider first", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "set",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Providers: csync.NewMapFrom(map[string]ProviderConfig{
				"custom": {
					APIKey:  "test-key",
					BaseURL: "https://api.custom.com/v1",
					Models: []catwalk.Model{
						{
							ID:               "large-model",
							DefaultMaxTokens: 1000,
						},
					},
				},
			}),
		}
		require.NoError(t, configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders))

		large, small, err := cfg.defaultModelSelection(knownProviders)
		require.NoError(t, err)
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(1000), large.MaxTokens)
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})
}

// TestConfig_configureSelectedModels checks how user-selected models are
// combined with the provider defaults by configureSelectedModels.
func TestConfig_configureSelectedModels(t *testing.T) {
	t.Run("should override defaults", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "larger-model",
						DefaultMaxTokens: 2000,
					},
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"large": {
					Model: "larger-model",
				},
			},
		}
		require.NoError(t, configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders))
		require.NoError(t, cfg.configureSelectedModels(knownProviders))

		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		require.Equal(t, "larger-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(2000), large.MaxTokens)
		require.Equal(t, "small-model", small.Model)
		require.Equal(t, "openai", small.Provider)
		require.Equal(t, int64(500), small.MaxTokens)
	})

	t.Run("should be possible to use multiple providers", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
			{
				ID:                  "anthropic",
				APIKey:              "abc",
				DefaultLargeModelID: "a-large-model",
				DefaultSmallModelID: "a-small-model",
				Models: []catwalk.Model{
					{
						ID:               "a-large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "a-small-model",
						DefaultMaxTokens: 200,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"small": {
					Model:     "a-small-model",
					Provider:  "anthropic",
					MaxTokens: 300,
				},
			},
		}
		require.NoError(t, configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders))
		require.NoError(t, cfg.configureSelectedModels(knownProviders))

		large := cfg.Models[SelectedModelTypeLarge]
		small := cfg.Models[SelectedModelTypeSmall]
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(1000), large.MaxTokens)
		require.Equal(t, "a-small-model", small.Model)
		require.Equal(t, "anthropic", small.Provider)
		require.Equal(t, int64(300), small.MaxTokens)
	})

	t.Run("should override the max tokens only", func(t *testing.T) {
		knownProviders := []catwalk.Provider{
			{
				ID:                  "openai",
				APIKey:              "abc",
				DefaultLargeModelID: "large-model",
				DefaultSmallModelID: "small-model",
				Models: []catwalk.Model{
					{
						ID:               "large-model",
						DefaultMaxTokens: 1000,
					},
					{
						ID:               "small-model",
						DefaultMaxTokens: 500,
					},
				},
			},
		}

		cfg := &Config{
			Models: map[SelectedModelType]SelectedModel{
				"large": {
					MaxTokens: 100,
				},
			},
		}
		require.NoError(t, configureProvidersWithEnv(t, cfg, map[string]string{}, knownProviders))
		require.NoError(t, cfg.configureSelectedModels(knownProviders))

		large := cfg.Models[SelectedModelTypeLarge]
		require.Equal(t, "large-model", large.Model)
		require.Equal(t, "openai", large.Provider)
		require.Equal(t, int64(100), large.MaxTokens)
	})
}