root.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. package cmd
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "log/slog"
  9. "os"
  10. "path/filepath"
  11. "strconv"
  12. "strings"
  13. tea "charm.land/bubbletea/v2"
  14. "charm.land/lipgloss/v2"
  15. "github.com/charmbracelet/colorprofile"
  16. "github.com/charmbracelet/crush/internal/app"
  17. "github.com/charmbracelet/crush/internal/config"
  18. "github.com/charmbracelet/crush/internal/db"
  19. "github.com/charmbracelet/crush/internal/event"
  20. "github.com/charmbracelet/crush/internal/projects"
  21. "github.com/charmbracelet/crush/internal/stringext"
  22. "github.com/charmbracelet/crush/internal/tui"
  23. "github.com/charmbracelet/crush/internal/version"
  24. "github.com/charmbracelet/fang"
  25. uv "github.com/charmbracelet/ultraviolet"
  26. "github.com/charmbracelet/x/ansi"
  27. "github.com/charmbracelet/x/exp/charmtone"
  28. "github.com/charmbracelet/x/term"
  29. "github.com/spf13/cobra"
  30. )
  31. func init() {
  32. rootCmd.PersistentFlags().StringP("cwd", "c", "", "Current working directory")
  33. rootCmd.PersistentFlags().StringP("data-dir", "D", "", "Custom crush data directory")
  34. rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug")
  35. rootCmd.Flags().BoolP("help", "h", false, "Help")
  36. rootCmd.Flags().BoolP("yolo", "y", false, "Automatically accept all permissions (dangerous mode)")
  37. rootCmd.AddCommand(
  38. runCmd,
  39. dirsCmd,
  40. projectsCmd,
  41. updateProvidersCmd,
  42. logsCmd,
  43. schemaCmd,
  44. loginCmd,
  45. )
  46. }
// rootCmd is the top-level "crush" command. Invoked without a subcommand it
// sets up the application and runs the interactive TUI until it exits.
var rootCmd = &cobra.Command{
	Use: "crush",
	Short: "An AI assistant for software development",
	Long: "An AI assistant for software development and similar tasks with direct access to the terminal",
	Example: `
# Run in interactive mode
crush
# Run with debug logging
crush -d
# Run with debug logging in a specific directory
crush -d -c /path/to/project
# Run with custom data directory
crush -D /path/to/custom/.crush
# Print version
crush -v
# Run a single non-interactive prompt
crush run "Explain the use of context in Go"
# Run in dangerous mode (auto-accept all permissions)
crush -y
`,
	// RunE builds the app (showing a progress bar where supported), starts the
	// TUI program and blocks until the program finishes.
	RunE: func(cmd *cobra.Command, args []string) error {
		app, err := setupAppWithProgressBar(cmd)
		if err != nil {
			return err
		}
		// Shut the app down whenever RunE returns, including on a TUI error.
		defer app.Shutdown()
		event.AppInitialized()
		// Set up the TUI.
		var env uv.Environ = os.Environ()
		ui := tui.New(app)
		ui.QueryVersion = shouldQueryTerminalVersion(env)
		program := tea.NewProgram(
			ui,
			tea.WithEnvironment(env),
			tea.WithContext(cmd.Context()),
			tea.WithFilter(tui.MouseEventFilter)) // Filter mouse events based on focus state
		// Subscribe runs concurrently with the program's main loop.
		// NOTE(review): presumably it forwards app events into the program — confirm in app.Subscribe.
		go app.Subscribe(program)
		if _, err := program.Run(); err != nil {
			// Report and log the underlying error, then return a fixed
			// user-facing message instead of the raw error.
			event.Error(err)
			slog.Error("TUI run error", "error", err)
			return errors.New("Crush crashed. If metrics are enabled, we were notified about it. If you'd like to report it, please copy the stacktrace above and open an issue at https://github.com/charmbracelet/crush/issues/new?template=bug.yml") //nolint:staticcheck
		}
		return nil
	},
	// PostRun records the app-exited event once the run phase is over.
	PostRun: func(cmd *cobra.Command, args []string) {
		event.AppExited()
	},
}
// heartbit is the block-character logo, styled with the charmtone.Dolly
// foreground color. Execute prepends its rendered form to the version
// template when stdout is a terminal.
var heartbit = lipgloss.NewStyle().Foreground(charmtone.Dolly).SetString(`
▄▄▄▄▄▄▄▄ ▄▄▄▄▄▄▄▄
███████████ ███████████
████████████████████████████
████████████████████████████
██████████▀██████▀██████████
██████████ ██████ ██████████
▀▀██████▄████▄▄████▄██████▀▀
████████████████████████
████████████████████
▀▀██████████▀▀
▀▀▀▀▀▀
`)
// defaultVersionTemplate is the version output template, copied from cobra:
// the display name (when set) followed by "version <Version>" and a newline.
// It is kept here so Execute can prepend the logo to it.
const defaultVersionTemplate = `{{with .DisplayName}}{{printf "%s " .}}{{end}}{{printf "version %s" .Version}}
`
  111. func Execute() {
  112. // NOTE: very hacky: we create a colorprofile writer with STDOUT, then make
  113. // it forward to a bytes.Buffer, write the colored heartbit to it, and then
  114. // finally prepend it in the version template.
  115. // Unfortunately cobra doesn't give us a way to set a function to handle
  116. // printing the version, and PreRunE runs after the version is already
  117. // handled, so that doesn't work either.
  118. // This is the only way I could find that works relatively well.
  119. if term.IsTerminal(os.Stdout.Fd()) {
  120. var b bytes.Buffer
  121. w := colorprofile.NewWriter(os.Stdout, os.Environ())
  122. w.Forward = &b
  123. _, _ = w.WriteString(heartbit.String())
  124. rootCmd.SetVersionTemplate(b.String() + "\n" + defaultVersionTemplate)
  125. }
  126. if err := fang.Execute(
  127. context.Background(),
  128. rootCmd,
  129. fang.WithVersion(version.Version),
  130. fang.WithNotifySignal(os.Interrupt),
  131. ); err != nil {
  132. os.Exit(1)
  133. }
  134. }
  135. // supportsProgressBar tries to determine whether the current terminal supports
  136. // progress bars by looking into environment variables.
  137. func supportsProgressBar() bool {
  138. if !term.IsTerminal(os.Stderr.Fd()) {
  139. return false
  140. }
  141. termProg := os.Getenv("TERM_PROGRAM")
  142. _, isWindowsTerminal := os.LookupEnv("WT_SESSION")
  143. return isWindowsTerminal || strings.Contains(strings.ToLower(termProg), "ghostty")
  144. }
  145. func setupAppWithProgressBar(cmd *cobra.Command) (*app.App, error) {
  146. if supportsProgressBar() {
  147. _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
  148. defer func() { _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar) }()
  149. }
  150. return setupApp(cmd)
  151. }
  152. // setupApp handles the common setup logic for both interactive and non-interactive modes.
  153. // It returns the app instance, config, cleanup function, and any error.
  154. func setupApp(cmd *cobra.Command) (*app.App, error) {
  155. debug, _ := cmd.Flags().GetBool("debug")
  156. yolo, _ := cmd.Flags().GetBool("yolo")
  157. dataDir, _ := cmd.Flags().GetString("data-dir")
  158. ctx := cmd.Context()
  159. cwd, err := ResolveCwd(cmd)
  160. if err != nil {
  161. return nil, err
  162. }
  163. cfg, err := config.Init(cwd, dataDir, debug)
  164. if err != nil {
  165. return nil, err
  166. }
  167. if cfg.Permissions == nil {
  168. cfg.Permissions = &config.Permissions{}
  169. }
  170. cfg.Permissions.SkipRequests = yolo
  171. if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
  172. return nil, err
  173. }
  174. // Register this project in the centralized projects list.
  175. if err := projects.Register(cwd, cfg.Options.DataDirectory); err != nil {
  176. slog.Warn("Failed to register project", "error", err)
  177. // Non-fatal: continue even if registration fails
  178. }
  179. // Connect to DB; this will also run migrations.
  180. conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
  181. if err != nil {
  182. return nil, err
  183. }
  184. appInstance, err := app.New(ctx, conn, cfg)
  185. if err != nil {
  186. slog.Error("Failed to create app instance", "error", err)
  187. return nil, err
  188. }
  189. if shouldEnableMetrics() {
  190. event.Init()
  191. }
  192. return appInstance, nil
  193. }
  194. func shouldEnableMetrics() bool {
  195. if v, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_METRICS")); v {
  196. return false
  197. }
  198. if v, _ := strconv.ParseBool(os.Getenv("DO_NOT_TRACK")); v {
  199. return false
  200. }
  201. if config.Get().Options.DisableMetrics {
  202. return false
  203. }
  204. return true
  205. }
  206. func MaybePrependStdin(prompt string) (string, error) {
  207. if term.IsTerminal(os.Stdin.Fd()) {
  208. return prompt, nil
  209. }
  210. fi, err := os.Stdin.Stat()
  211. if err != nil {
  212. return prompt, err
  213. }
  214. // Check if stdin is a named pipe ( | ) or regular file ( < ).
  215. if fi.Mode()&os.ModeNamedPipe == 0 && !fi.Mode().IsRegular() {
  216. return prompt, nil
  217. }
  218. bts, err := io.ReadAll(os.Stdin)
  219. if err != nil {
  220. return prompt, err
  221. }
  222. return string(bts) + "\n\n" + prompt, nil
  223. }
  224. func ResolveCwd(cmd *cobra.Command) (string, error) {
  225. cwd, _ := cmd.Flags().GetString("cwd")
  226. if cwd != "" {
  227. err := os.Chdir(cwd)
  228. if err != nil {
  229. return "", fmt.Errorf("failed to change directory: %v", err)
  230. }
  231. return cwd, nil
  232. }
  233. cwd, err := os.Getwd()
  234. if err != nil {
  235. return "", fmt.Errorf("failed to get current working directory: %v", err)
  236. }
  237. return cwd, nil
  238. }
  239. func createDotCrushDir(dir string) error {
  240. if err := os.MkdirAll(dir, 0o700); err != nil {
  241. return fmt.Errorf("failed to create data directory: %q %w", dir, err)
  242. }
  243. gitIgnorePath := filepath.Join(dir, ".gitignore")
  244. if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) {
  245. if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil {
  246. return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err)
  247. }
  248. }
  249. return nil
  250. }
  251. func shouldQueryTerminalVersion(env uv.Environ) bool {
  252. termType := env.Getenv("TERM")
  253. termProg, okTermProg := env.LookupEnv("TERM_PROGRAM")
  254. _, okSSHTTY := env.LookupEnv("SSH_TTY")
  255. return (!okTermProg && !okSSHTTY) ||
  256. (!strings.Contains(termProg, "Apple") && !okSSHTTY) ||
  257. // Terminals that do support XTVERSION.
  258. stringext.ContainsAny(termType, "alacritty", "ghostty", "kitty", "rio", "wezterm")
  259. }