root.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. package cmd
  2. import (
  3. "context"
  4. "fmt"
  5. "os"
  6. "os/signal"
  7. "sync"
  8. "syscall"
  9. "time"
  10. tea "github.com/charmbracelet/bubbletea/v2"
  11. "github.com/charmbracelet/crush/internal/app"
  12. "github.com/charmbracelet/crush/internal/config"
  13. "github.com/charmbracelet/crush/internal/db"
  14. "github.com/charmbracelet/crush/internal/format"
  15. "github.com/charmbracelet/crush/internal/llm/agent"
  16. "github.com/charmbracelet/crush/internal/logging"
  17. "github.com/charmbracelet/crush/internal/pubsub"
  18. "github.com/charmbracelet/crush/internal/tui"
  19. "github.com/charmbracelet/crush/internal/version"
  20. "github.com/charmbracelet/fang"
  21. "github.com/spf13/cobra"
  22. )
  23. var rootCmd = &cobra.Command{
  24. Use: "crush",
  25. Short: "Terminal-based AI assistant for software development",
  26. Long: `Crush is a powerful terminal-based AI assistant that helps with software development tasks.
  27. It provides an interactive chat interface with AI capabilities, code analysis, and LSP integration
  28. to assist developers in writing, debugging, and understanding code directly from the terminal.`,
  29. Example: `
  30. # Run in interactive mode
  31. crush
  32. # Run with debug logging
  33. crush -d
  34. # Run with debug logging in a specific directory
  35. crush -d -c /path/to/project
  36. # Print version
  37. crush -v
  38. # Run a single non-interactive prompt
  39. crush -p "Explain the use of context in Go"
  40. # Run a single non-interactive prompt with JSON output format
  41. crush -p "Explain the use of context in Go" -f json
  42. `,
  43. RunE: func(cmd *cobra.Command, args []string) error {
  44. // Load the config
  45. debug, _ := cmd.Flags().GetBool("debug")
  46. cwd, _ := cmd.Flags().GetString("cwd")
  47. prompt, _ := cmd.Flags().GetString("prompt")
  48. outputFormat, _ := cmd.Flags().GetString("output-format")
  49. quiet, _ := cmd.Flags().GetBool("quiet")
  50. // Validate format option
  51. if !format.IsValid(outputFormat) {
  52. return fmt.Errorf("invalid format option: %s\n%s", outputFormat, format.GetHelpText())
  53. }
  54. if cwd != "" {
  55. err := os.Chdir(cwd)
  56. if err != nil {
  57. return fmt.Errorf("failed to change directory: %v", err)
  58. }
  59. }
  60. if cwd == "" {
  61. c, err := os.Getwd()
  62. if err != nil {
  63. return fmt.Errorf("failed to get current working directory: %v", err)
  64. }
  65. cwd = c
  66. }
  67. _, err := config.Load(cwd, debug)
  68. if err != nil {
  69. return err
  70. }
  71. // Connect DB, this will also run migrations
  72. conn, err := db.Connect()
  73. if err != nil {
  74. return err
  75. }
  76. // Create main context for the application with signal handling
  77. ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
  78. defer cancel()
  79. app, err := app.New(ctx, conn)
  80. if err != nil {
  81. logging.Error("Failed to create app: %v", err)
  82. return err
  83. }
  84. // Defer shutdown here so it runs for both interactive and non-interactive modes
  85. defer app.Shutdown()
  86. // Initialize MCP tools early for both modes
  87. initMCPTools(ctx, app)
  88. // Non-interactive mode
  89. if prompt != "" {
  90. // Run non-interactive flow using the App method
  91. return app.RunNonInteractive(ctx, prompt, outputFormat, quiet)
  92. }
  93. // Set up the TUI
  94. program := tea.NewProgram(
  95. tui.New(app),
  96. tea.WithAltScreen(),
  97. tea.WithKeyReleases(),
  98. tea.WithUniformKeyLayout(),
  99. )
  100. // Setup the subscriptions, this will send services events to the TUI
  101. ch, cancelSubs := setupSubscriptions(app, ctx)
  102. // Create a context for the TUI message handler
  103. tuiCtx, tuiCancel := context.WithCancel(ctx)
  104. var tuiWg sync.WaitGroup
  105. tuiWg.Add(1)
  106. // Set up message handling for the TUI
  107. go func() {
  108. defer tuiWg.Done()
  109. defer logging.RecoverPanic("TUI-message-handler", func() {
  110. attemptTUIRecovery(program)
  111. })
  112. for {
  113. select {
  114. case <-tuiCtx.Done():
  115. logging.Info("TUI message handler shutting down")
  116. return
  117. case msg, ok := <-ch:
  118. if !ok {
  119. logging.Info("TUI message channel closed")
  120. return
  121. }
  122. program.Send(msg)
  123. }
  124. }
  125. }()
  126. // Cleanup function for when the program exits
  127. cleanup := func() {
  128. // Shutdown the app
  129. app.Shutdown()
  130. // Cancel subscriptions first
  131. cancelSubs()
  132. // Then cancel TUI message handler
  133. tuiCancel()
  134. // Wait for TUI message handler to finish
  135. tuiWg.Wait()
  136. logging.Info("All goroutines cleaned up")
  137. }
  138. // Run the TUI
  139. result, err := program.Run()
  140. cleanup()
  141. if err != nil {
  142. logging.Error("TUI error: %v", err)
  143. return fmt.Errorf("TUI error: %v", err)
  144. }
  145. logging.Info("TUI exited with result: %v", result)
  146. return nil
  147. },
  148. }
  149. // attemptTUIRecovery tries to recover the TUI after a panic
  150. func attemptTUIRecovery(program *tea.Program) {
  151. logging.Info("Attempting to recover TUI after panic")
  152. // We could try to restart the TUI or gracefully exit
  153. // For now, we'll just quit the program to avoid further issues
  154. program.Quit()
  155. }
  156. func initMCPTools(ctx context.Context, app *app.App) {
  157. go func() {
  158. defer logging.RecoverPanic("MCP-goroutine", nil)
  159. // Create a context with timeout for the initial MCP tools fetch
  160. ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
  161. defer cancel()
  162. // Set this up once with proper error handling
  163. agent.GetMcpTools(ctxWithTimeout, app.Permissions)
  164. logging.Info("MCP message handling goroutine exiting")
  165. }()
  166. }
  167. func setupSubscriber[T any](
  168. ctx context.Context,
  169. wg *sync.WaitGroup,
  170. name string,
  171. subscriber func(context.Context) <-chan pubsub.Event[T],
  172. outputCh chan<- tea.Msg,
  173. ) {
  174. wg.Add(1)
  175. go func() {
  176. defer wg.Done()
  177. defer logging.RecoverPanic(fmt.Sprintf("subscription-%s", name), nil)
  178. subCh := subscriber(ctx)
  179. for {
  180. select {
  181. case event, ok := <-subCh:
  182. if !ok {
  183. logging.Info("subscription channel closed", "name", name)
  184. return
  185. }
  186. var msg tea.Msg = event
  187. select {
  188. case outputCh <- msg:
  189. case <-time.After(2 * time.Second):
  190. logging.Warn("message dropped due to slow consumer", "name", name)
  191. case <-ctx.Done():
  192. logging.Info("subscription cancelled", "name", name)
  193. return
  194. }
  195. case <-ctx.Done():
  196. logging.Info("subscription cancelled", "name", name)
  197. return
  198. }
  199. }
  200. }()
  201. }
  202. func setupSubscriptions(app *app.App, parentCtx context.Context) (chan tea.Msg, func()) {
  203. ch := make(chan tea.Msg, 100)
  204. wg := sync.WaitGroup{}
  205. ctx, cancel := context.WithCancel(parentCtx) // Inherit from parent context
  206. setupSubscriber(ctx, &wg, "logging", logging.Subscribe, ch)
  207. setupSubscriber(ctx, &wg, "sessions", app.Sessions.Subscribe, ch)
  208. setupSubscriber(ctx, &wg, "messages", app.Messages.Subscribe, ch)
  209. setupSubscriber(ctx, &wg, "permissions", app.Permissions.Subscribe, ch)
  210. setupSubscriber(ctx, &wg, "coderAgent", app.CoderAgent.Subscribe, ch)
  211. setupSubscriber(ctx, &wg, "history", app.History.Subscribe, ch)
  212. cleanupFunc := func() {
  213. logging.Info("Cancelling all subscriptions")
  214. cancel() // Signal all goroutines to stop
  215. waitCh := make(chan struct{})
  216. go func() {
  217. defer logging.RecoverPanic("subscription-cleanup", nil)
  218. wg.Wait()
  219. close(waitCh)
  220. }()
  221. select {
  222. case <-waitCh:
  223. logging.Info("All subscription goroutines completed successfully")
  224. close(ch) // Only close after all writers are confirmed done
  225. case <-time.After(5 * time.Second):
  226. logging.Warn("Timed out waiting for some subscription goroutines to complete")
  227. close(ch)
  228. }
  229. }
  230. return ch, cleanupFunc
  231. }
  232. func Execute() {
  233. if err := fang.Execute(
  234. context.Background(),
  235. rootCmd,
  236. fang.WithVersion(version.Version),
  237. ); err != nil {
  238. os.Exit(1)
  239. }
  240. }
  241. func init() {
  242. rootCmd.Flags().BoolP("help", "h", false, "Help")
  243. rootCmd.Flags().BoolP("debug", "d", false, "Debug")
  244. rootCmd.Flags().StringP("cwd", "c", "", "Current working directory")
  245. rootCmd.Flags().StringP("prompt", "p", "", "Prompt to run in non-interactive mode")
  246. // Add format flag with validation logic
  247. rootCmd.Flags().StringP("output-format", "f", format.Text.String(),
  248. "Output format for non-interactive mode (text, json)")
  249. // Add quiet flag to hide spinner in non-interactive mode
  250. rootCmd.Flags().BoolP("quiet", "q", false, "Hide spinner in non-interactive mode")
  251. // Register custom validation for the format flag
  252. rootCmd.RegisterFlagCompletionFunc("output-format", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
  253. return format.SupportedFormats, cobra.ShellCompDirectiveNoFileComp
  254. })
  255. }