package config import ( "context" "encoding/json" "fmt" "io" "log/slog" "maps" "os" "os/exec" "path/filepath" "runtime" "slices" "strconv" "strings" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/home" "github.com/charmbracelet/crush/internal/log" powernapConfig "github.com/charmbracelet/x/powernap/pkg/config" ) const defaultCatwalkURL = "https://catwalk.charm.sh" // LoadReader config via io.Reader. func LoadReader(fd io.Reader) (*Config, error) { data, err := io.ReadAll(fd) if err != nil { return nil, err } var config Config err = json.Unmarshal(data, &config) if err != nil { return nil, err } return &config, err } // Load loads the configuration from the default paths. func Load(workingDir, dataDir string, debug bool) (*Config, error) { configPaths := lookupConfigs(workingDir) cfg, err := loadFromConfigPaths(configPaths) if err != nil { return nil, fmt.Errorf("failed to load config from paths %v: %w", configPaths, err) } cfg.dataConfigDir = GlobalConfigData() cfg.setDefaults(workingDir, dataDir) if debug { cfg.Options.Debug = true } // Setup logs log.Setup( filepath.Join(cfg.Options.DataDirectory, "logs", fmt.Sprintf("%s.log", appName)), cfg.Options.Debug, ) if !isInsideWorktree() { const depth = 2 const items = 100 slog.Warn("No git repository detected in working directory, will limit file walk operations", "depth", depth, "items", items) assignIfNil(&cfg.Tools.Ls.MaxDepth, depth) assignIfNil(&cfg.Tools.Ls.MaxItems, items) assignIfNil(&cfg.Options.TUI.Completions.MaxDepth, depth) assignIfNil(&cfg.Options.TUI.Completions.MaxItems, items) } // Load known providers, this loads the config from catwalk providers, err := Providers(cfg) if err != nil { return nil, err } cfg.knownProviders = providers env := env.New() // Configure providers valueResolver := NewShellVariableResolver(env) cfg.resolver = valueResolver if err := cfg.configureProviders(env, valueResolver, cfg.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure providers: %w", err) } if !cfg.IsConfigured() { slog.Warn("No providers configured") return cfg, nil } if err := cfg.configureSelectedModels(cfg.knownProviders); err != nil { return nil, fmt.Errorf("failed to configure selected models: %w", err) } cfg.SetupAgents() return cfg, nil } func PushPopCrushEnv() func() { found := []string{} for _, ev := range os.Environ() { if strings.HasPrefix(ev, "CRUSH_") { pair := strings.SplitN(ev, "=", 2) if len(pair) != 2 { continue } found = append(found, strings.TrimPrefix(pair[0], "CRUSH_")) } } backups := make(map[string]string) for _, ev := range found { backups[ev] = os.Getenv(ev) } for _, ev := range found { os.Setenv(ev, os.Getenv("CRUSH_"+ev)) } restore := func() { for k, v := range backups { os.Setenv(k, v) } } return restore } func (c *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error { knownProviderNames := make(map[string]bool) restore := PushPopCrushEnv() defer restore() for _, p := range knownProviders { knownProviderNames[string(p.ID)] = true config, configExists := c.Providers.Get(string(p.ID)) // if the user configured a known provider we need to allow it to override a couple of parameters if configExists { if config.BaseURL != "" { p.APIEndpoint = config.BaseURL } if config.APIKey != "" { p.APIKey = config.APIKey } if len(config.Models) > 0 { models := []catwalk.Model{} seen := make(map[string]bool) for _, model := range config.Models { if seen[model.ID] { continue } seen[model.ID] = true if model.Name == "" { model.Name = model.ID } models = append(models, model) } for _, model := range p.Models { if seen[model.ID] { continue } seen[model.ID] = true if model.Name == "" { model.Name = model.ID } models = append(models, model) } p.Models = models } } headers := map[string]string{} if len(p.DefaultHeaders) > 0 { maps.Copy(headers, p.DefaultHeaders) } if len(config.ExtraHeaders) > 0 { maps.Copy(headers, config.ExtraHeaders) } prepared := ProviderConfig{ ID: string(p.ID), Name: p.Name, BaseURL: p.APIEndpoint, APIKey: p.APIKey, Type: p.Type, Disable: config.Disable, SystemPromptPrefix: config.SystemPromptPrefix, ExtraHeaders: headers, ExtraBody: config.ExtraBody, ExtraParams: make(map[string]string), Models: p.Models, } switch p.ID { // Handle specific providers that require additional configuration case catwalk.InferenceProviderVertexAI: if !hasVertexCredentials(env) { if configExists { slog.Warn("Skipping Vertex AI provider due to missing credentials") c.Providers.Del(string(p.ID)) } continue } prepared.ExtraParams["project"] = env.Get("VERTEXAI_PROJECT") prepared.ExtraParams["location"] = env.Get("VERTEXAI_LOCATION") case catwalk.InferenceProviderAzure: endpoint, err := resolver.ResolveValue(p.APIEndpoint) if err != nil || endpoint == "" { if configExists { slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err) c.Providers.Del(string(p.ID)) } continue } prepared.BaseURL = endpoint prepared.ExtraParams["apiVersion"] = env.Get("AZURE_OPENAI_API_VERSION") case catwalk.InferenceProviderBedrock: if !hasAWSCredentials(env) { if configExists { slog.Warn("Skipping Bedrock provider due to missing AWS credentials") c.Providers.Del(string(p.ID)) } continue } prepared.ExtraParams["region"] = env.Get("AWS_REGION") if prepared.ExtraParams["region"] == "" { prepared.ExtraParams["region"] = env.Get("AWS_DEFAULT_REGION") } for _, model := range p.Models { if !strings.HasPrefix(model.ID, "anthropic.") { return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID) } } default: // if the provider api or endpoint are missing we skip them v, err := resolver.ResolveValue(p.APIKey) if v == "" || err != nil { if configExists { slog.Warn("Skipping provider due to missing API key", "provider", p.ID) c.Providers.Del(string(p.ID)) } continue } } c.Providers.Set(string(p.ID), prepared) } // validate the custom providers for id, providerConfig := range c.Providers.Seq2() { if knownProviderNames[id] { continue } // Make sure the provider ID is set providerConfig.ID = id if providerConfig.Name == "" { providerConfig.Name = id // Use ID as name if not set } // default to OpenAI if not set if providerConfig.Type == "" { providerConfig.Type = catwalk.TypeOpenAICompat } if !slices.Contains(catwalk.KnownProviderTypes(), providerConfig.Type) { slog.Warn("Skipping custom provider due to unsupported provider type", "provider", id) c.Providers.Del(id) continue } if providerConfig.Disable { slog.Debug("Skipping custom provider due to disable flag", "provider", id) c.Providers.Del(id) continue } if providerConfig.APIKey == "" { slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id) } if providerConfig.BaseURL == "" { slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id) c.Providers.Del(id) continue } if len(providerConfig.Models) == 0 { slog.Warn("Skipping custom provider because the provider has no models", "provider", id) c.Providers.Del(id) continue } apiKey, err := resolver.ResolveValue(providerConfig.APIKey) if apiKey == "" || err != nil { slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id) } baseURL, err := resolver.ResolveValue(providerConfig.BaseURL) if baseURL == "" || err != nil { slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err) c.Providers.Del(id) continue } c.Providers.Set(id, providerConfig) } return nil } func (c *Config) setDefaults(workingDir, dataDir string) { c.workingDir = workingDir if c.Options == nil { c.Options = &Options{} } if c.Options.TUI == nil { c.Options.TUI = &TUIOptions{} } if c.Options.ContextPaths == nil { c.Options.ContextPaths = []string{} } if dataDir != "" { c.Options.DataDirectory = dataDir } else if c.Options.DataDirectory == "" { if path, ok := fsext.LookupClosest(workingDir, defaultDataDirectory); ok { c.Options.DataDirectory = path } else { c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory) } } if c.Providers == nil { c.Providers = csync.NewMap[string, ProviderConfig]() } if c.Models == nil { c.Models = make(map[SelectedModelType]SelectedModel) } if c.MCP == nil { c.MCP = make(map[string]MCPConfig) } if c.LSP == nil { c.LSP = make(map[string]LSPConfig) } // Apply defaults to LSP configurations c.applyLSPDefaults() // Add the default context paths if they are not already present c.Options.ContextPaths = append(defaultContextPaths, c.Options.ContextPaths...) slices.Sort(c.Options.ContextPaths) c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths) if str, ok := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE"); ok { c.Options.DisableProviderAutoUpdate, _ = strconv.ParseBool(str) } if c.Options.Attribution == nil { c.Options.Attribution = &Attribution{ CoAuthoredBy: true, GeneratedWith: true, } } } // applyLSPDefaults applies default values from powernap to LSP configurations func (c *Config) applyLSPDefaults() { // Get powernap's default configuration configManager := powernapConfig.NewManager() configManager.LoadDefaults() // Apply defaults to each LSP configuration for name, cfg := range c.LSP { // Try to get defaults from powernap based on name or command name. base, ok := configManager.GetServer(name) if !ok { base, ok = configManager.GetServer(cfg.Command) if !ok { continue } } if cfg.Options == nil { cfg.Options = base.Settings } if cfg.InitOptions == nil { cfg.InitOptions = base.InitOptions } if len(cfg.FileTypes) == 0 { cfg.FileTypes = base.FileTypes } if len(cfg.RootMarkers) == 0 { cfg.RootMarkers = base.RootMarkers } if cfg.Command == "" { cfg.Command = base.Command } if len(cfg.Args) == 0 { cfg.Args = base.Args } if len(cfg.Env) == 0 { cfg.Env = base.Environment } // Update the config in the map c.LSP[name] = cfg } } func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) { if len(knownProviders) == 0 && c.Providers.Len() == 0 { err = fmt.Errorf("no providers configured, please configure at least one provider") return largeModel, smallModel, err } // Use the first provider enabled based on the known providers order // if no provider found that is known use the first provider configured for _, p := range knownProviders { providerConfig, ok := c.Providers.Get(string(p.ID)) if !ok || providerConfig.Disable { continue } defaultLargeModel := c.GetModel(string(p.ID), p.DefaultLargeModelID) if defaultLargeModel == nil { err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID) return largeModel, smallModel, err } largeModel = SelectedModel{ Provider: string(p.ID), Model: defaultLargeModel.ID, MaxTokens: defaultLargeModel.DefaultMaxTokens, ReasoningEffort: defaultLargeModel.DefaultReasoningEffort, } defaultSmallModel := c.GetModel(string(p.ID), p.DefaultSmallModelID) if defaultSmallModel == nil { err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID) return largeModel, smallModel, err } smallModel = SelectedModel{ Provider: string(p.ID), Model: defaultSmallModel.ID, MaxTokens: defaultSmallModel.DefaultMaxTokens, ReasoningEffort: defaultSmallModel.DefaultReasoningEffort, } return largeModel, smallModel, err } enabledProviders := c.EnabledProviders() slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int { return strings.Compare(a.ID, b.ID) }) if len(enabledProviders) == 0 { err = fmt.Errorf("no providers configured, please configure at least one provider") return largeModel, smallModel, err } providerConfig := enabledProviders[0] if len(providerConfig.Models) == 0 { err = fmt.Errorf("provider %s has no models configured", providerConfig.ID) return largeModel, smallModel, err } defaultLargeModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID) largeModel = SelectedModel{ Provider: providerConfig.ID, Model: defaultLargeModel.ID, MaxTokens: defaultLargeModel.DefaultMaxTokens, } defaultSmallModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID) smallModel = SelectedModel{ Provider: providerConfig.ID, Model: defaultSmallModel.ID, MaxTokens: defaultSmallModel.DefaultMaxTokens, } return largeModel, smallModel, err } func (c *Config) configureSelectedModels(knownProviders []catwalk.Provider) error { defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders) if err != nil { return fmt.Errorf("failed to select default models: %w", err) } large, small := defaultLarge, defaultSmall largeModelSelected, largeModelConfigured := c.Models[SelectedModelTypeLarge] if largeModelConfigured { if largeModelSelected.Model != "" { large.Model = largeModelSelected.Model } if largeModelSelected.Provider != "" { large.Provider = largeModelSelected.Provider } model := c.GetModel(large.Provider, large.Model) if model == nil { large = defaultLarge // override the model type to large err := c.UpdatePreferredModel(SelectedModelTypeLarge, large) if err != nil { return fmt.Errorf("failed to update preferred large model: %w", err) } } else { if largeModelSelected.MaxTokens > 0 { large.MaxTokens = largeModelSelected.MaxTokens } else { large.MaxTokens = model.DefaultMaxTokens } if largeModelSelected.ReasoningEffort != "" { large.ReasoningEffort = largeModelSelected.ReasoningEffort } large.Think = largeModelSelected.Think if largeModelSelected.Temperature != nil { large.Temperature = largeModelSelected.Temperature } if largeModelSelected.TopP != nil { large.TopP = largeModelSelected.TopP } if largeModelSelected.TopK != nil { large.TopK = largeModelSelected.TopK } if largeModelSelected.FrequencyPenalty != nil { large.FrequencyPenalty = largeModelSelected.FrequencyPenalty } if largeModelSelected.PresencePenalty != nil { large.PresencePenalty = largeModelSelected.PresencePenalty } } } smallModelSelected, smallModelConfigured := c.Models[SelectedModelTypeSmall] if smallModelConfigured { if smallModelSelected.Model != "" { small.Model = smallModelSelected.Model } if smallModelSelected.Provider != "" { small.Provider = smallModelSelected.Provider } model := c.GetModel(small.Provider, small.Model) if model == nil { small = defaultSmall // override the model type to small err := c.UpdatePreferredModel(SelectedModelTypeSmall, small) if err != nil { return fmt.Errorf("failed to update preferred small model: %w", err) } } else { if smallModelSelected.MaxTokens > 0 { small.MaxTokens = smallModelSelected.MaxTokens } else { small.MaxTokens = model.DefaultMaxTokens } if smallModelSelected.ReasoningEffort != "" { small.ReasoningEffort = smallModelSelected.ReasoningEffort } if smallModelSelected.Temperature != nil { small.Temperature = smallModelSelected.Temperature } if smallModelSelected.TopP != nil { small.TopP = smallModelSelected.TopP } if smallModelSelected.TopK != nil { small.TopK = smallModelSelected.TopK } if smallModelSelected.FrequencyPenalty != nil { small.FrequencyPenalty = smallModelSelected.FrequencyPenalty } if smallModelSelected.PresencePenalty != nil { small.PresencePenalty = smallModelSelected.PresencePenalty } small.Think = smallModelSelected.Think } } c.Models[SelectedModelTypeLarge] = large c.Models[SelectedModelTypeSmall] = small return nil } // lookupConfigs searches config files recursively from CWD up to FS root func lookupConfigs(cwd string) []string { // prepend default config paths configPaths := []string{ GlobalConfig(), GlobalConfigData(), } configNames := []string{appName + ".json", "." + appName + ".json"} foundConfigs, err := fsext.Lookup(cwd, configNames...) if err != nil { // returns at least default configs return configPaths } // reverse order so last config has more priority slices.Reverse(foundConfigs) return append(configPaths, foundConfigs...) } func loadFromConfigPaths(configPaths []string) (*Config, error) { var configs []io.Reader for _, path := range configPaths { fd, err := os.Open(path) if err != nil { if os.IsNotExist(err) { continue } return nil, fmt.Errorf("failed to open config file %s: %w", path, err) } defer fd.Close() configs = append(configs, fd) } return loadFromReaders(configs) } func loadFromReaders(readers []io.Reader) (*Config, error) { if len(readers) == 0 { return &Config{}, nil } merged, err := Merge(readers) if err != nil { return nil, fmt.Errorf("failed to merge configuration readers: %w", err) } return LoadReader(merged) } func hasVertexCredentials(env env.Env) bool { hasProject := env.Get("VERTEXAI_PROJECT") != "" hasLocation := env.Get("VERTEXAI_LOCATION") != "" return hasProject && hasLocation } func hasAWSCredentials(env env.Env) bool { if env.Get("AWS_BEARER_TOKEN_BEDROCK") != "" { return true } if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" { return true } if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" { return true } if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" { return true } if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" || env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" { return true } if _, err := os.Stat(filepath.Join(home.Dir(), ".aws/credentials")); err == nil { return true } return false } // GlobalConfig returns the global configuration file path for the application. func GlobalConfig() string { xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") if xdgConfigHome != "" { return filepath.Join(xdgConfigHome, appName, fmt.Sprintf("%s.json", appName)) } return filepath.Join(home.Dir(), ".config", appName, fmt.Sprintf("%s.json", appName)) } // GlobalConfigData returns the path to the main data directory for the application. // this config is used when the app overrides configurations instead of updating the global config. func GlobalConfigData() string { xdgDataHome := os.Getenv("XDG_DATA_HOME") if xdgDataHome != "" { return filepath.Join(xdgDataHome, appName, fmt.Sprintf("%s.json", appName)) } // return the path to the main data directory // for windows, it should be in `%LOCALAPPDATA%/crush/` // for linux and macOS, it should be in `$HOME/.local/share/crush/` if runtime.GOOS == "windows" { localAppData := os.Getenv("LOCALAPPDATA") if localAppData == "" { localAppData = filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local") } return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName)) } return filepath.Join(home.Dir(), ".local", "share", appName, fmt.Sprintf("%s.json", appName)) } func assignIfNil[T any](ptr **T, val T) { if *ptr == nil { *ptr = &val } } func isInsideWorktree() bool { bts, err := exec.CommandContext( context.Background(), "git", "rev-parse", "--is-inside-work-tree", ).CombinedOutput() return err == nil && strings.TrimSpace(string(bts)) == "true" }