load.go 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902
  1. package config
  2. import (
  3. "cmp"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "log/slog"
  8. "maps"
  9. "os"
  10. "os/exec"
  11. "path/filepath"
  12. "regexp"
  13. "runtime"
  14. "slices"
  15. "strconv"
  16. "strings"
  17. "testing"
  18. "charm.land/catwalk/pkg/catwalk"
  19. "github.com/charmbracelet/crush/internal/agent/hyper"
  20. "github.com/charmbracelet/crush/internal/csync"
  21. "github.com/charmbracelet/crush/internal/env"
  22. "github.com/charmbracelet/crush/internal/fsext"
  23. "github.com/charmbracelet/crush/internal/home"
  24. powernapConfig "github.com/charmbracelet/x/powernap/pkg/config"
  25. "github.com/qjebbs/go-jsons"
  26. )
// defaultCatwalkURL is the Catwalk service URL. Presumably it is the fallback
// used when no other Catwalk URL is configured — its usage is not visible in
// this part of the file, confirm against the callers.
const defaultCatwalkURL = "https://catwalk.charm.sh"
  28. // Load loads the configuration from the default paths and returns a
  29. // ConfigStore that owns both the pure-data Config and all runtime state.
  30. func Load(workingDir, dataDir string, debug bool) (*ConfigStore, error) {
  31. configPaths := lookupConfigs(workingDir)
  32. cfg, err := loadFromConfigPaths(configPaths)
  33. if err != nil {
  34. return nil, fmt.Errorf("failed to load config from paths %v: %w", configPaths, err)
  35. }
  36. cfg.setDefaults(workingDir, dataDir)
  37. store := &ConfigStore{
  38. config: cfg,
  39. workingDir: workingDir,
  40. globalDataPath: GlobalConfigData(),
  41. workspacePath: filepath.Join(cfg.Options.DataDirectory, fmt.Sprintf("%s.json", appName)),
  42. }
  43. if debug {
  44. cfg.Options.Debug = true
  45. }
  46. // Load workspace config last so it has highest priority.
  47. if wsData, err := os.ReadFile(store.workspacePath); err == nil && len(wsData) > 0 {
  48. merged, mergeErr := loadFromBytes(append([][]byte{mustMarshalConfig(cfg)}, wsData))
  49. if mergeErr == nil {
  50. // Preserve defaults that setDefaults already applied.
  51. dataDir := cfg.Options.DataDirectory
  52. *cfg = *merged
  53. cfg.setDefaults(workingDir, dataDir)
  54. store.config = cfg
  55. }
  56. }
  57. // Validate hooks after all config merging is complete so workspace
  58. // hooks also get their matcher regexes compiled.
  59. if err := cfg.ValidateHooks(); err != nil {
  60. return nil, fmt.Errorf("invalid hook configuration: %w", err)
  61. }
  62. if !isInsideWorktree() {
  63. const depth = 2
  64. const items = 100
  65. slog.Warn("No git repository detected in working directory, will limit file walk operations", "depth", depth, "items", items)
  66. assignIfNil(&cfg.Tools.Ls.MaxDepth, depth)
  67. assignIfNil(&cfg.Tools.Ls.MaxItems, items)
  68. assignIfNil(&cfg.Options.TUI.Completions.MaxDepth, depth)
  69. assignIfNil(&cfg.Options.TUI.Completions.MaxItems, items)
  70. }
  71. if isAppleTerminal() {
  72. slog.Warn("Detected Apple Terminal, enabling transparent mode")
  73. assignIfNil(&cfg.Options.TUI.Transparent, true)
  74. }
  75. // Load known providers, this loads the config from catwalk
  76. providers, err := Providers(cfg)
  77. if err != nil {
  78. return nil, err
  79. }
  80. store.knownProviders = providers
  81. env := env.New()
  82. // Configure providers
  83. valueResolver := NewShellVariableResolver(env)
  84. store.resolver = valueResolver
  85. if err := cfg.configureProviders(store, env, valueResolver, store.knownProviders); err != nil {
  86. return nil, fmt.Errorf("failed to configure providers: %w", err)
  87. }
  88. if !cfg.IsConfigured() {
  89. slog.Warn("No providers configured")
  90. return store, nil
  91. }
  92. if err := configureSelectedModels(store, store.knownProviders); err != nil {
  93. return nil, fmt.Errorf("failed to configure selected models: %w", err)
  94. }
  95. store.SetupAgents()
  96. return store, nil
  97. }
  98. // mustMarshalConfig marshals the config to JSON bytes, returning empty JSON on
  99. // error.
  100. func mustMarshalConfig(cfg *Config) []byte {
  101. data, err := json.Marshal(cfg)
  102. if err != nil {
  103. return []byte("{}")
  104. }
  105. return data
  106. }
  107. func PushPopCrushEnv() func() {
  108. var found []string
  109. for _, ev := range os.Environ() {
  110. if strings.HasPrefix(ev, "CRUSH_") {
  111. pair := strings.SplitN(ev, "=", 2)
  112. if len(pair) != 2 {
  113. continue
  114. }
  115. found = append(found, strings.TrimPrefix(pair[0], "CRUSH_"))
  116. }
  117. }
  118. backups := make(map[string]string)
  119. for _, ev := range found {
  120. backups[ev] = os.Getenv(ev)
  121. }
  122. for _, ev := range found {
  123. os.Setenv(ev, os.Getenv("CRUSH_"+ev))
  124. }
  125. restore := func() {
  126. for k, v := range backups {
  127. os.Setenv(k, v)
  128. }
  129. }
  130. return restore
  131. }
  132. func (c *Config) configureProviders(store *ConfigStore, env env.Env, resolver VariableResolver, knownProviders []catwalk.Provider) error {
  133. knownProviderNames := make(map[string]bool)
  134. restore := PushPopCrushEnv()
  135. defer restore()
  136. // When disable_default_providers is enabled, skip all default/embedded
  137. // providers entirely. Users must fully specify any providers they want.
  138. // We skip to the custom provider validation loop which handles all
  139. // user-configured providers uniformly.
  140. if c.Options.DisableDefaultProviders {
  141. knownProviders = nil
  142. }
  143. for _, p := range knownProviders {
  144. knownProviderNames[string(p.ID)] = true
  145. config, configExists := c.Providers.Get(string(p.ID))
  146. // if the user configured a known provider we need to allow it to override a couple of parameters
  147. if configExists {
  148. if config.BaseURL != "" {
  149. p.APIEndpoint = config.BaseURL
  150. }
  151. if config.APIKey != "" {
  152. p.APIKey = config.APIKey
  153. }
  154. if len(config.Models) > 0 {
  155. models := []catwalk.Model{}
  156. seen := make(map[string]bool)
  157. for _, model := range config.Models {
  158. if seen[model.ID] {
  159. continue
  160. }
  161. seen[model.ID] = true
  162. if model.Name == "" {
  163. model.Name = model.ID
  164. }
  165. models = append(models, model)
  166. }
  167. for _, model := range p.Models {
  168. if seen[model.ID] {
  169. continue
  170. }
  171. seen[model.ID] = true
  172. if model.Name == "" {
  173. model.Name = model.ID
  174. }
  175. models = append(models, model)
  176. }
  177. p.Models = models
  178. }
  179. }
  180. headers := map[string]string{}
  181. if len(p.DefaultHeaders) > 0 {
  182. maps.Copy(headers, p.DefaultHeaders)
  183. }
  184. if len(config.ExtraHeaders) > 0 {
  185. maps.Copy(headers, config.ExtraHeaders)
  186. }
  187. for k, v := range headers {
  188. resolved, err := resolver.ResolveValue(v)
  189. if err != nil {
  190. slog.Error("Could not resolve provider header", "err", err.Error())
  191. continue
  192. }
  193. headers[k] = resolved
  194. }
  195. prepared := ProviderConfig{
  196. ID: string(p.ID),
  197. Name: p.Name,
  198. BaseURL: p.APIEndpoint,
  199. APIKey: p.APIKey,
  200. APIKeyTemplate: p.APIKey, // Store original template for re-resolution
  201. OAuthToken: config.OAuthToken,
  202. Type: p.Type,
  203. Disable: config.Disable,
  204. SystemPromptPrefix: config.SystemPromptPrefix,
  205. ExtraHeaders: headers,
  206. ExtraBody: config.ExtraBody,
  207. ExtraParams: make(map[string]string),
  208. Models: p.Models,
  209. }
  210. switch {
  211. case p.ID == catwalk.InferenceProviderAnthropic && config.OAuthToken != nil:
  212. // Claude Code subscription is not supported anymore. Remove to show onboarding.
  213. store.RemoveConfigField(ScopeGlobal, "providers.anthropic")
  214. c.Providers.Del(string(p.ID))
  215. continue
  216. case p.ID == catwalk.InferenceProviderCopilot && config.OAuthToken != nil:
  217. prepared.SetupGitHubCopilot()
  218. }
  219. switch p.ID {
  220. // Handle specific providers that require additional configuration
  221. case catwalk.InferenceProviderVertexAI:
  222. var (
  223. project = env.Get("VERTEXAI_PROJECT")
  224. location = env.Get("VERTEXAI_LOCATION")
  225. )
  226. if project == "" || location == "" {
  227. if configExists {
  228. slog.Warn("Skipping Vertex AI provider due to missing credentials")
  229. c.Providers.Del(string(p.ID))
  230. }
  231. continue
  232. }
  233. prepared.ExtraParams["project"] = project
  234. prepared.ExtraParams["location"] = location
  235. case catwalk.InferenceProviderAzure:
  236. endpoint, err := resolver.ResolveValue(p.APIEndpoint)
  237. if err != nil || endpoint == "" {
  238. if configExists {
  239. slog.Warn("Skipping Azure provider due to missing API endpoint", "provider", p.ID, "error", err)
  240. c.Providers.Del(string(p.ID))
  241. }
  242. continue
  243. }
  244. prepared.BaseURL = endpoint
  245. prepared.ExtraParams["apiVersion"] = env.Get("AZURE_OPENAI_API_VERSION")
  246. case catwalk.InferenceProviderBedrock:
  247. if !hasAWSCredentials(env) {
  248. if configExists {
  249. slog.Warn("Skipping Bedrock provider due to missing AWS credentials")
  250. c.Providers.Del(string(p.ID))
  251. }
  252. continue
  253. }
  254. prepared.ExtraParams["region"] = env.Get("AWS_REGION")
  255. if prepared.ExtraParams["region"] == "" {
  256. prepared.ExtraParams["region"] = env.Get("AWS_DEFAULT_REGION")
  257. }
  258. for _, model := range p.Models {
  259. if !strings.HasPrefix(model.ID, "anthropic.") {
  260. return fmt.Errorf("bedrock provider only supports anthropic models for now, found: %s", model.ID)
  261. }
  262. }
  263. default:
  264. // if the provider api or endpoint are missing we skip them
  265. v, err := resolver.ResolveValue(p.APIKey)
  266. if v == "" || err != nil {
  267. if configExists {
  268. slog.Warn("Skipping provider due to missing API key", "provider", p.ID)
  269. c.Providers.Del(string(p.ID))
  270. }
  271. continue
  272. }
  273. }
  274. c.Providers.Set(string(p.ID), prepared)
  275. }
  276. // validate the custom providers
  277. for id, providerConfig := range c.Providers.Seq2() {
  278. if knownProviderNames[id] {
  279. continue
  280. }
  281. // Make sure the provider ID is set
  282. providerConfig.ID = id
  283. providerConfig.Name = cmp.Or(providerConfig.Name, id) // Use ID as name if not set
  284. // default to OpenAI if not set
  285. providerConfig.Type = cmp.Or(providerConfig.Type, catwalk.TypeOpenAICompat)
  286. if !slices.Contains(catwalk.KnownProviderTypes(), providerConfig.Type) && providerConfig.Type != hyper.Name {
  287. slog.Warn("Skipping custom provider due to unsupported provider type", "provider", id)
  288. c.Providers.Del(id)
  289. continue
  290. }
  291. if providerConfig.Disable {
  292. slog.Debug("Skipping custom provider due to disable flag", "provider", id)
  293. c.Providers.Del(id)
  294. continue
  295. }
  296. if providerConfig.APIKey == "" {
  297. slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
  298. }
  299. if providerConfig.BaseURL == "" {
  300. slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id)
  301. c.Providers.Del(id)
  302. continue
  303. }
  304. if len(providerConfig.Models) == 0 {
  305. slog.Warn("Skipping custom provider because the provider has no models", "provider", id)
  306. c.Providers.Del(id)
  307. continue
  308. }
  309. apiKey, err := resolver.ResolveValue(providerConfig.APIKey)
  310. if apiKey == "" || err != nil {
  311. slog.Warn("Provider is missing API key, this might be OK for local providers", "provider", id)
  312. }
  313. baseURL, err := resolver.ResolveValue(providerConfig.BaseURL)
  314. if baseURL == "" || err != nil {
  315. slog.Warn("Skipping custom provider due to missing API endpoint", "provider", id, "error", err)
  316. c.Providers.Del(id)
  317. continue
  318. }
  319. for k, v := range providerConfig.ExtraHeaders {
  320. resolved, err := resolver.ResolveValue(v)
  321. if err != nil {
  322. slog.Error("Could not resolve provider header", "err", err.Error())
  323. continue
  324. }
  325. providerConfig.ExtraHeaders[k] = resolved
  326. }
  327. c.Providers.Set(id, providerConfig)
  328. }
  329. if c.Providers.Len() == 0 && c.Options.DisableDefaultProviders {
  330. return fmt.Errorf("default providers are disabled and there are no custom providers are configured")
  331. }
  332. return nil
  333. }
  334. func (c *Config) setDefaults(workingDir, dataDir string) {
  335. if c.Options == nil {
  336. c.Options = &Options{}
  337. }
  338. if c.Options.TUI == nil {
  339. c.Options.TUI = &TUIOptions{}
  340. }
  341. if dataDir != "" {
  342. c.Options.DataDirectory = dataDir
  343. } else if c.Options.DataDirectory == "" {
  344. if path, ok := fsext.LookupClosest(workingDir, defaultDataDirectory); ok {
  345. c.Options.DataDirectory = path
  346. } else {
  347. c.Options.DataDirectory = filepath.Join(workingDir, defaultDataDirectory)
  348. }
  349. }
  350. if c.Providers == nil {
  351. c.Providers = csync.NewMap[string, ProviderConfig]()
  352. }
  353. if c.Models == nil {
  354. c.Models = make(map[SelectedModelType]SelectedModel)
  355. }
  356. if c.RecentModels == nil {
  357. c.RecentModels = make(map[SelectedModelType][]SelectedModel)
  358. }
  359. if c.MCP == nil {
  360. c.MCP = make(map[string]MCPConfig)
  361. }
  362. if c.LSP == nil {
  363. c.LSP = make(map[string]LSPConfig)
  364. }
  365. // Apply defaults to LSP configurations
  366. c.applyLSPDefaults()
  367. // Add the default context paths if they are not already present
  368. c.Options.ContextPaths = append(defaultContextPaths, c.Options.ContextPaths...)
  369. slices.Sort(c.Options.ContextPaths)
  370. c.Options.ContextPaths = slices.Compact(c.Options.ContextPaths)
  371. // Add the default skills directories if not already present.
  372. for _, dir := range GlobalSkillsDirs() {
  373. if !slices.Contains(c.Options.SkillsPaths, dir) {
  374. c.Options.SkillsPaths = append(c.Options.SkillsPaths, dir)
  375. }
  376. }
  377. // Project specific skills dirs.
  378. c.Options.SkillsPaths = append(c.Options.SkillsPaths, ProjectSkillsDir(workingDir)...)
  379. if str, ok := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE"); ok {
  380. c.Options.DisableProviderAutoUpdate, _ = strconv.ParseBool(str)
  381. }
  382. if str, ok := os.LookupEnv("CRUSH_DISABLE_DEFAULT_PROVIDERS"); ok {
  383. c.Options.DisableDefaultProviders, _ = strconv.ParseBool(str)
  384. }
  385. if c.Options.Attribution == nil {
  386. c.Options.Attribution = &Attribution{
  387. TrailerStyle: TrailerStyleAssistedBy,
  388. GeneratedWith: true,
  389. }
  390. } else if c.Options.Attribution.TrailerStyle == "" {
  391. // Migrate deprecated co_authored_by or apply default
  392. if c.Options.Attribution.CoAuthoredBy != nil {
  393. if *c.Options.Attribution.CoAuthoredBy {
  394. c.Options.Attribution.TrailerStyle = TrailerStyleCoAuthoredBy
  395. } else {
  396. c.Options.Attribution.TrailerStyle = TrailerStyleNone
  397. }
  398. } else {
  399. c.Options.Attribution.TrailerStyle = TrailerStyleAssistedBy
  400. }
  401. }
  402. c.Options.InitializeAs = cmp.Or(c.Options.InitializeAs, defaultInitializeAs)
  403. }
  404. // applyLSPDefaults applies default values from powernap to LSP configurations
  405. func (c *Config) applyLSPDefaults() {
  406. // Get powernap's default configuration
  407. configManager := powernapConfig.NewManager()
  408. configManager.LoadDefaults()
  409. // Apply defaults to each LSP configuration
  410. for name, cfg := range c.LSP {
  411. // Try to get defaults from powernap based on name or command name.
  412. base, ok := configManager.GetServer(name)
  413. if !ok {
  414. base, ok = configManager.GetServer(cfg.Command)
  415. if !ok {
  416. continue
  417. }
  418. }
  419. if cfg.Options == nil {
  420. cfg.Options = base.Settings
  421. }
  422. if cfg.InitOptions == nil {
  423. cfg.InitOptions = base.InitOptions
  424. }
  425. if len(cfg.FileTypes) == 0 {
  426. cfg.FileTypes = base.FileTypes
  427. }
  428. if len(cfg.RootMarkers) == 0 {
  429. cfg.RootMarkers = base.RootMarkers
  430. }
  431. cfg.Command = cmp.Or(cfg.Command, base.Command)
  432. if len(cfg.Args) == 0 {
  433. cfg.Args = base.Args
  434. }
  435. if len(cfg.Env) == 0 {
  436. cfg.Env = base.Environment
  437. }
  438. // Update the config in the map
  439. c.LSP[name] = cfg
  440. }
  441. }
  442. func (c *Config) defaultModelSelection(knownProviders []catwalk.Provider) (largeModel SelectedModel, smallModel SelectedModel, err error) {
  443. if len(knownProviders) == 0 && c.Providers.Len() == 0 {
  444. err = fmt.Errorf("no providers configured, please configure at least one provider")
  445. return largeModel, smallModel, err
  446. }
  447. // Use the first provider enabled based on the known providers order
  448. // if no provider found that is known use the first provider configured
  449. for _, p := range knownProviders {
  450. providerConfig, ok := c.Providers.Get(string(p.ID))
  451. if !ok || providerConfig.Disable {
  452. continue
  453. }
  454. defaultLargeModel := c.GetModel(string(p.ID), p.DefaultLargeModelID)
  455. if defaultLargeModel == nil {
  456. err = fmt.Errorf("default large model %s not found for provider %s", p.DefaultLargeModelID, p.ID)
  457. return largeModel, smallModel, err
  458. }
  459. largeModel = SelectedModel{
  460. Provider: string(p.ID),
  461. Model: defaultLargeModel.ID,
  462. MaxTokens: defaultLargeModel.DefaultMaxTokens,
  463. ReasoningEffort: defaultLargeModel.DefaultReasoningEffort,
  464. }
  465. defaultSmallModel := c.GetModel(string(p.ID), p.DefaultSmallModelID)
  466. if defaultSmallModel == nil {
  467. err = fmt.Errorf("default small model %s not found for provider %s", p.DefaultSmallModelID, p.ID)
  468. return largeModel, smallModel, err
  469. }
  470. smallModel = SelectedModel{
  471. Provider: string(p.ID),
  472. Model: defaultSmallModel.ID,
  473. MaxTokens: defaultSmallModel.DefaultMaxTokens,
  474. ReasoningEffort: defaultSmallModel.DefaultReasoningEffort,
  475. }
  476. return largeModel, smallModel, err
  477. }
  478. enabledProviders := c.EnabledProviders()
  479. slices.SortFunc(enabledProviders, func(a, b ProviderConfig) int {
  480. return strings.Compare(a.ID, b.ID)
  481. })
  482. if len(enabledProviders) == 0 {
  483. err = fmt.Errorf("no providers configured, please configure at least one provider")
  484. return largeModel, smallModel, err
  485. }
  486. providerConfig := enabledProviders[0]
  487. if len(providerConfig.Models) == 0 {
  488. err = fmt.Errorf("provider %s has no models configured", providerConfig.ID)
  489. return largeModel, smallModel, err
  490. }
  491. defaultLargeModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
  492. largeModel = SelectedModel{
  493. Provider: providerConfig.ID,
  494. Model: defaultLargeModel.ID,
  495. MaxTokens: defaultLargeModel.DefaultMaxTokens,
  496. }
  497. defaultSmallModel := c.GetModel(providerConfig.ID, providerConfig.Models[0].ID)
  498. smallModel = SelectedModel{
  499. Provider: providerConfig.ID,
  500. Model: defaultSmallModel.ID,
  501. MaxTokens: defaultSmallModel.DefaultMaxTokens,
  502. }
  503. return largeModel, smallModel, err
  504. }
  505. func configureSelectedModels(store *ConfigStore, knownProviders []catwalk.Provider) error {
  506. c := store.config
  507. defaultLarge, defaultSmall, err := c.defaultModelSelection(knownProviders)
  508. if err != nil {
  509. return fmt.Errorf("failed to select default models: %w", err)
  510. }
  511. large, small := defaultLarge, defaultSmall
  512. largeModelSelected, largeModelConfigured := c.Models[SelectedModelTypeLarge]
  513. if largeModelConfigured {
  514. if largeModelSelected.Model != "" {
  515. large.Model = largeModelSelected.Model
  516. }
  517. if largeModelSelected.Provider != "" {
  518. large.Provider = largeModelSelected.Provider
  519. }
  520. model := c.GetModel(large.Provider, large.Model)
  521. if model == nil {
  522. large = defaultLarge
  523. // override the model type to large
  524. err := store.UpdatePreferredModel(ScopeGlobal, SelectedModelTypeLarge, large)
  525. if err != nil {
  526. return fmt.Errorf("failed to update preferred large model: %w", err)
  527. }
  528. } else {
  529. if largeModelSelected.MaxTokens > 0 {
  530. large.MaxTokens = largeModelSelected.MaxTokens
  531. } else {
  532. large.MaxTokens = model.DefaultMaxTokens
  533. }
  534. if largeModelSelected.ReasoningEffort != "" {
  535. large.ReasoningEffort = largeModelSelected.ReasoningEffort
  536. }
  537. large.Think = largeModelSelected.Think
  538. if largeModelSelected.Temperature != nil {
  539. large.Temperature = largeModelSelected.Temperature
  540. }
  541. if largeModelSelected.TopP != nil {
  542. large.TopP = largeModelSelected.TopP
  543. }
  544. if largeModelSelected.TopK != nil {
  545. large.TopK = largeModelSelected.TopK
  546. }
  547. if largeModelSelected.FrequencyPenalty != nil {
  548. large.FrequencyPenalty = largeModelSelected.FrequencyPenalty
  549. }
  550. if largeModelSelected.PresencePenalty != nil {
  551. large.PresencePenalty = largeModelSelected.PresencePenalty
  552. }
  553. }
  554. }
  555. smallModelSelected, smallModelConfigured := c.Models[SelectedModelTypeSmall]
  556. if smallModelConfigured {
  557. if smallModelSelected.Model != "" {
  558. small.Model = smallModelSelected.Model
  559. }
  560. if smallModelSelected.Provider != "" {
  561. small.Provider = smallModelSelected.Provider
  562. }
  563. model := c.GetModel(small.Provider, small.Model)
  564. if model == nil {
  565. small = defaultSmall
  566. // override the model type to small
  567. err := store.UpdatePreferredModel(ScopeGlobal, SelectedModelTypeSmall, small)
  568. if err != nil {
  569. return fmt.Errorf("failed to update preferred small model: %w", err)
  570. }
  571. } else {
  572. if smallModelSelected.MaxTokens > 0 {
  573. small.MaxTokens = smallModelSelected.MaxTokens
  574. } else {
  575. small.MaxTokens = model.DefaultMaxTokens
  576. }
  577. if smallModelSelected.ReasoningEffort != "" {
  578. small.ReasoningEffort = smallModelSelected.ReasoningEffort
  579. }
  580. if smallModelSelected.Temperature != nil {
  581. small.Temperature = smallModelSelected.Temperature
  582. }
  583. if smallModelSelected.TopP != nil {
  584. small.TopP = smallModelSelected.TopP
  585. }
  586. if smallModelSelected.TopK != nil {
  587. small.TopK = smallModelSelected.TopK
  588. }
  589. if smallModelSelected.FrequencyPenalty != nil {
  590. small.FrequencyPenalty = smallModelSelected.FrequencyPenalty
  591. }
  592. if smallModelSelected.PresencePenalty != nil {
  593. small.PresencePenalty = smallModelSelected.PresencePenalty
  594. }
  595. small.Think = smallModelSelected.Think
  596. }
  597. }
  598. c.Models[SelectedModelTypeLarge] = large
  599. c.Models[SelectedModelTypeSmall] = small
  600. return nil
  601. }
  602. // lookupConfigs searches config files recursively from CWD up to FS root
  603. func lookupConfigs(cwd string) []string {
  604. // prepend default config paths
  605. configPaths := []string{
  606. GlobalConfig(),
  607. GlobalConfigData(),
  608. }
  609. configNames := []string{appName + ".json", "." + appName + ".json"}
  610. foundConfigs, err := fsext.Lookup(cwd, configNames...)
  611. if err != nil {
  612. // returns at least default configs
  613. return configPaths
  614. }
  615. // reverse order so last config has more priority
  616. slices.Reverse(foundConfigs)
  617. return append(configPaths, foundConfigs...)
  618. }
  619. func loadFromConfigPaths(configPaths []string) (*Config, error) {
  620. var configs [][]byte
  621. for _, path := range configPaths {
  622. data, err := os.ReadFile(path)
  623. if err != nil {
  624. if os.IsNotExist(err) {
  625. continue
  626. }
  627. return nil, fmt.Errorf("failed to open config file %s: %w", path, err)
  628. }
  629. if len(data) == 0 {
  630. continue
  631. }
  632. configs = append(configs, data)
  633. }
  634. return loadFromBytes(configs)
  635. }
  636. func loadFromBytes(configs [][]byte) (*Config, error) {
  637. if len(configs) == 0 {
  638. return &Config{}, nil
  639. }
  640. data, err := jsons.Merge(configs)
  641. if err != nil {
  642. return nil, err
  643. }
  644. var config Config
  645. if err := json.Unmarshal(data, &config); err != nil {
  646. return nil, err
  647. }
  648. return &config, nil
  649. }
  650. func hasAWSCredentials(env env.Env) bool {
  651. if env.Get("AWS_BEARER_TOKEN_BEDROCK") != "" {
  652. return true
  653. }
  654. if env.Get("AWS_ACCESS_KEY_ID") != "" && env.Get("AWS_SECRET_ACCESS_KEY") != "" {
  655. return true
  656. }
  657. if env.Get("AWS_PROFILE") != "" || env.Get("AWS_DEFAULT_PROFILE") != "" {
  658. return true
  659. }
  660. if env.Get("AWS_REGION") != "" || env.Get("AWS_DEFAULT_REGION") != "" {
  661. return true
  662. }
  663. if env.Get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") != "" ||
  664. env.Get("AWS_CONTAINER_CREDENTIALS_FULL_URI") != "" {
  665. return true
  666. }
  667. if _, err := os.Stat(filepath.Join(home.Dir(), ".aws/credentials")); err == nil && !testing.Testing() {
  668. return true
  669. }
  670. return false
  671. }
  672. // GlobalConfig returns the global configuration file path for the application.
  673. func GlobalConfig() string {
  674. if crushGlobal := os.Getenv("CRUSH_GLOBAL_CONFIG"); crushGlobal != "" {
  675. return filepath.Join(crushGlobal, fmt.Sprintf("%s.json", appName))
  676. }
  677. return filepath.Join(home.Config(), appName, fmt.Sprintf("%s.json", appName))
  678. }
  679. // GlobalCacheDir returns the path to the global cache directory for the
  680. // application.
  681. func GlobalCacheDir() string {
  682. if crushCache := os.Getenv("CRUSH_CACHE_DIR"); crushCache != "" {
  683. return crushCache
  684. }
  685. if xdgCacheHome := os.Getenv("XDG_CACHE_HOME"); xdgCacheHome != "" {
  686. return filepath.Join(xdgCacheHome, appName)
  687. }
  688. if runtime.GOOS == "windows" {
  689. localAppData := cmp.Or(
  690. os.Getenv("LOCALAPPDATA"),
  691. filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local"),
  692. )
  693. return filepath.Join(localAppData, appName, "cache")
  694. }
  695. return filepath.Join(home.Dir(), ".cache", appName)
  696. }
  697. // GlobalConfigData returns the path to the main data directory for the application.
  698. // this config is used when the app overrides configurations instead of updating the global config.
  699. func GlobalConfigData() string {
  700. if crushData := os.Getenv("CRUSH_GLOBAL_DATA"); crushData != "" {
  701. return filepath.Join(crushData, fmt.Sprintf("%s.json", appName))
  702. }
  703. if xdgDataHome := os.Getenv("XDG_DATA_HOME"); xdgDataHome != "" {
  704. return filepath.Join(xdgDataHome, appName, fmt.Sprintf("%s.json", appName))
  705. }
  706. // return the path to the main data directory
  707. // for windows, it should be in `%LOCALAPPDATA%/crush/`
  708. // for linux and macOS, it should be in `$HOME/.local/share/crush/`
  709. if runtime.GOOS == "windows" {
  710. localAppData := cmp.Or(
  711. os.Getenv("LOCALAPPDATA"),
  712. filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local"),
  713. )
  714. return filepath.Join(localAppData, appName, fmt.Sprintf("%s.json", appName))
  715. }
  716. return filepath.Join(home.Dir(), ".local", "share", appName, fmt.Sprintf("%s.json", appName))
  717. }
  718. // GlobalWorkspaceDir returns the path to the global server workspace
  719. // directory. This directory acts as a meta-workspace for the server
  720. // process, giving it a real workingDir so that config loading, scoped
  721. // writes, and provider resolution behave identically to project
  722. // workspaces.
  723. func GlobalWorkspaceDir() string {
  724. return filepath.Dir(GlobalConfigData())
  725. }
  726. func assignIfNil[T any](ptr **T, val T) {
  727. if *ptr == nil {
  728. *ptr = &val
  729. }
  730. }
  731. func isInsideWorktree() bool {
  732. bts, err := exec.CommandContext(
  733. context.Background(),
  734. "git", "rev-parse",
  735. "--is-inside-work-tree",
  736. ).CombinedOutput()
  737. return err == nil && strings.TrimSpace(string(bts)) == "true"
  738. }
  739. // GlobalSkillsDirs returns the default directories for Agent Skills.
  740. // Skills in these directories are auto-discovered and their files can be read
  741. // without permission prompts.
  742. func GlobalSkillsDirs() []string {
  743. if crushSkills := os.Getenv("CRUSH_SKILLS_DIR"); crushSkills != "" {
  744. return []string{crushSkills}
  745. }
  746. paths := []string{
  747. filepath.Join(home.Config(), appName, "skills"),
  748. filepath.Join(home.Config(), "agents", "skills"),
  749. }
  750. // On Windows, also load from app data on top of `$HOME/.config/crush`.
  751. // This is here mostly for backwards compatibility.
  752. if runtime.GOOS == "windows" {
  753. appData := cmp.Or(
  754. os.Getenv("LOCALAPPDATA"),
  755. filepath.Join(os.Getenv("USERPROFILE"), "AppData", "Local"),
  756. )
  757. paths = append(
  758. paths,
  759. filepath.Join(appData, appName, "skills"),
  760. filepath.Join(appData, "agents", "skills"),
  761. )
  762. }
  763. return paths
  764. }
  765. // ProjectSkillsDir returns the default project directories for which Crush
  766. // will look for skills.
  767. func ProjectSkillsDir(workingDir string) []string {
  768. return []string{
  769. filepath.Join(workingDir, ".agents/skills"),
  770. filepath.Join(workingDir, ".crush/skills"),
  771. filepath.Join(workingDir, ".claude/skills"),
  772. filepath.Join(workingDir, ".cursor/skills"),
  773. }
  774. }
  775. func isAppleTerminal() bool { return os.Getenv("TERM_PROGRAM") == "Apple_Terminal" }
  776. // normalizeHookEvent maps user-provided event names to their canonical
  777. // form. Matching is case-insensitive and accepts snake_case variants
  778. // (e.g. "pre_tool_use" → "PreToolUse").
  779. func normalizeHookEvent(name string) string {
  780. switch strings.ToLower(strings.ReplaceAll(name, "_", "")) {
  781. case "pretooluse":
  782. return "PreToolUse"
  783. default:
  784. return name
  785. }
  786. }
  787. // ValidateHooks normalizes event names and compiles matcher regexes for all
  788. // configured hooks. Returns an error if any regex is invalid.
  789. func (c *Config) ValidateHooks() error {
  790. // Normalize event name keys.
  791. for event, eventHooks := range c.Hooks {
  792. canonical := normalizeHookEvent(event)
  793. if canonical != event {
  794. c.Hooks[canonical] = append(c.Hooks[canonical], eventHooks...)
  795. delete(c.Hooks, event)
  796. }
  797. }
  798. for event, eventHooks := range c.Hooks {
  799. for i := range eventHooks {
  800. h := &c.Hooks[event][i]
  801. if h.Command == "" {
  802. return fmt.Errorf("hook %s[%d]: command is required", event, i)
  803. }
  804. if h.Matcher != "" {
  805. re, err := regexp.Compile(h.Matcher)
  806. if err != nil {
  807. return fmt.Errorf("hook %s[%d]: invalid matcher regex %q: %w", event, i, h.Matcher, err)
  808. }
  809. h.matcherRegex = re
  810. }
  811. }
  812. }
  813. return nil
  814. }