app.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. // Package app wires together services, coordinates agents, and manages
  2. // application lifecycle.
  3. package app
  4. import (
  5. "context"
  6. "database/sql"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "os"
  12. "sync"
  13. "time"
  14. tea "charm.land/bubbletea/v2"
  15. "charm.land/fantasy"
  16. "charm.land/lipgloss/v2"
  17. "github.com/charmbracelet/crush/internal/agent"
  18. "github.com/charmbracelet/crush/internal/agent/tools/mcp"
  19. "github.com/charmbracelet/crush/internal/config"
  20. "github.com/charmbracelet/crush/internal/csync"
  21. "github.com/charmbracelet/crush/internal/db"
  22. "github.com/charmbracelet/crush/internal/format"
  23. "github.com/charmbracelet/crush/internal/history"
  24. "github.com/charmbracelet/crush/internal/log"
  25. "github.com/charmbracelet/crush/internal/lsp"
  26. "github.com/charmbracelet/crush/internal/message"
  27. "github.com/charmbracelet/crush/internal/permission"
  28. "github.com/charmbracelet/crush/internal/pubsub"
  29. "github.com/charmbracelet/crush/internal/session"
  30. "github.com/charmbracelet/crush/internal/shell"
  31. "github.com/charmbracelet/crush/internal/term"
  32. "github.com/charmbracelet/crush/internal/tui/components/anim"
  33. "github.com/charmbracelet/crush/internal/tui/styles"
  34. "github.com/charmbracelet/crush/internal/update"
  35. "github.com/charmbracelet/crush/internal/version"
  36. "github.com/charmbracelet/x/ansi"
  37. "github.com/charmbracelet/x/exp/charmtone"
  38. )
  39. type App struct {
  40. Sessions session.Service
  41. Messages message.Service
  42. History history.Service
  43. Permissions permission.Service
  44. AgentCoordinator agent.Coordinator
  45. LSPClients *csync.Map[string, *lsp.Client]
  46. config *config.Config
  47. serviceEventsWG *sync.WaitGroup
  48. eventsCtx context.Context
  49. events chan tea.Msg
  50. tuiWG *sync.WaitGroup
  51. // global context and cleanup functions
  52. globalCtx context.Context
  53. cleanupFuncs []func() error
  54. }
  55. // New initializes a new applcation instance.
  56. func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
  57. q := db.New(conn)
  58. sessions := session.NewService(q)
  59. messages := message.NewService(q)
  60. files := history.NewService(q, conn)
  61. skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests
  62. allowedTools := []string{}
  63. if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil {
  64. allowedTools = cfg.Permissions.AllowedTools
  65. }
  66. app := &App{
  67. Sessions: sessions,
  68. Messages: messages,
  69. History: files,
  70. Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
  71. LSPClients: csync.NewMap[string, *lsp.Client](),
  72. globalCtx: ctx,
  73. config: cfg,
  74. events: make(chan tea.Msg, 100),
  75. serviceEventsWG: &sync.WaitGroup{},
  76. tuiWG: &sync.WaitGroup{},
  77. }
  78. app.setupEvents()
  79. // Initialize LSP clients in the background.
  80. app.initLSPClients(ctx)
  81. // Check for updates in the background.
  82. go app.checkForUpdates(ctx)
  83. go func() {
  84. slog.Info("Initializing MCP clients")
  85. mcp.Initialize(ctx, app.Permissions, cfg)
  86. }()
  87. // cleanup database upon app shutdown
  88. app.cleanupFuncs = append(app.cleanupFuncs, conn.Close, mcp.Close)
  89. // TODO: remove the concept of agent config, most likely.
  90. if !cfg.IsConfigured() {
  91. slog.Warn("No agent configuration found")
  92. return app, nil
  93. }
  94. if err := app.InitCoderAgent(ctx); err != nil {
  95. return nil, fmt.Errorf("failed to initialize coder agent: %w", err)
  96. }
  97. return app, nil
  98. }
  99. // Config returns the application configuration.
  100. func (app *App) Config() *config.Config {
  101. return app.config
  102. }
  103. // RunNonInteractive runs the application in non-interactive mode with the
  104. // given prompt, printing to stdout.
  105. func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt string, quiet bool) error {
  106. slog.Info("Running in non-interactive mode")
  107. ctx, cancel := context.WithCancel(ctx)
  108. defer cancel()
  109. var spinner *format.Spinner
  110. if !quiet {
  111. t := styles.CurrentTheme()
  112. // Detect background color to set the appropriate color for the
  113. // spinner's 'Generating...' text. Without this, that text would be
  114. // unreadable in light terminals.
  115. hasDarkBG := true
  116. if f, ok := output.(*os.File); ok {
  117. hasDarkBG = lipgloss.HasDarkBackground(os.Stdin, f)
  118. }
  119. defaultFG := lipgloss.LightDark(hasDarkBG)(charmtone.Pepper, t.FgBase)
  120. spinner = format.NewSpinner(ctx, cancel, anim.Settings{
  121. Size: 10,
  122. Label: "Generating",
  123. LabelColor: defaultFG,
  124. GradColorA: t.Primary,
  125. GradColorB: t.Secondary,
  126. CycleColors: true,
  127. })
  128. spinner.Start()
  129. }
  130. // Helper function to stop spinner once.
  131. stopSpinner := func() {
  132. if !quiet && spinner != nil {
  133. spinner.Stop()
  134. spinner = nil
  135. }
  136. }
  137. defer stopSpinner()
  138. const maxPromptLengthForTitle = 100
  139. const titlePrefix = "Non-interactive: "
  140. var titleSuffix string
  141. if len(prompt) > maxPromptLengthForTitle {
  142. titleSuffix = prompt[:maxPromptLengthForTitle] + "..."
  143. } else {
  144. titleSuffix = prompt
  145. }
  146. title := titlePrefix + titleSuffix
  147. sess, err := app.Sessions.Create(ctx, title)
  148. if err != nil {
  149. return fmt.Errorf("failed to create session for non-interactive mode: %w", err)
  150. }
  151. slog.Info("Created session for non-interactive run", "session_id", sess.ID)
  152. // Automatically approve all permission requests for this non-interactive
  153. // session.
  154. app.Permissions.AutoApproveSession(sess.ID)
  155. type response struct {
  156. result *fantasy.AgentResult
  157. err error
  158. }
  159. done := make(chan response, 1)
  160. go func(ctx context.Context, sessionID, prompt string) {
  161. result, err := app.AgentCoordinator.Run(ctx, sess.ID, prompt)
  162. if err != nil {
  163. done <- response{
  164. err: fmt.Errorf("failed to start agent processing stream: %w", err),
  165. }
  166. }
  167. done <- response{
  168. result: result,
  169. }
  170. }(ctx, sess.ID, prompt)
  171. messageEvents := app.Messages.Subscribe(ctx)
  172. messageReadBytes := make(map[string]int)
  173. supportsProgressBar := term.SupportsProgressBar()
  174. defer func() {
  175. if supportsProgressBar {
  176. _, _ = fmt.Fprintf(os.Stderr, ansi.ResetProgressBar)
  177. }
  178. // Always print a newline at the end. If output is a TTY this will
  179. // prevent the prompt from overwriting the last line of output.
  180. _, _ = fmt.Fprintln(output)
  181. }()
  182. for {
  183. if supportsProgressBar {
  184. // HACK: Reinitialize the terminal progress bar on every iteration so
  185. // it doesn't get hidden by the terminal due to inactivity.
  186. _, _ = fmt.Fprintf(os.Stderr, ansi.SetIndeterminateProgressBar)
  187. }
  188. select {
  189. case result := <-done:
  190. stopSpinner()
  191. if result.err != nil {
  192. if errors.Is(result.err, context.Canceled) || errors.Is(result.err, agent.ErrRequestCancelled) {
  193. slog.Info("Non-interactive: agent processing cancelled", "session_id", sess.ID)
  194. return nil
  195. }
  196. return fmt.Errorf("agent processing failed: %w", result.err)
  197. }
  198. return nil
  199. case event := <-messageEvents:
  200. msg := event.Payload
  201. if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 {
  202. stopSpinner()
  203. content := msg.Content().String()
  204. readBytes := messageReadBytes[msg.ID]
  205. if len(content) < readBytes {
  206. slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(content), "read_bytes", readBytes)
  207. return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
  208. }
  209. part := content[readBytes:]
  210. fmt.Fprint(output, part)
  211. messageReadBytes[msg.ID] = len(content)
  212. }
  213. case <-ctx.Done():
  214. stopSpinner()
  215. return ctx.Err()
  216. }
  217. }
  218. }
  219. func (app *App) UpdateAgentModel(ctx context.Context) error {
  220. if app.AgentCoordinator == nil {
  221. return fmt.Errorf("agent configuration is missing")
  222. }
  223. return app.AgentCoordinator.UpdateModels(ctx)
  224. }
  225. func (app *App) setupEvents() {
  226. ctx, cancel := context.WithCancel(app.globalCtx)
  227. app.eventsCtx = ctx
  228. setupSubscriber(ctx, app.serviceEventsWG, "sessions", app.Sessions.Subscribe, app.events)
  229. setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events)
  230. setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
  231. setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events)
  232. setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
  233. setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events)
  234. setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events)
  235. cleanupFunc := func() error {
  236. cancel()
  237. app.serviceEventsWG.Wait()
  238. return nil
  239. }
  240. app.cleanupFuncs = append(app.cleanupFuncs, cleanupFunc)
  241. }
  242. func setupSubscriber[T any](
  243. ctx context.Context,
  244. wg *sync.WaitGroup,
  245. name string,
  246. subscriber func(context.Context) <-chan pubsub.Event[T],
  247. outputCh chan<- tea.Msg,
  248. ) {
  249. wg.Go(func() {
  250. subCh := subscriber(ctx)
  251. for {
  252. select {
  253. case event, ok := <-subCh:
  254. if !ok {
  255. slog.Debug("subscription channel closed", "name", name)
  256. return
  257. }
  258. var msg tea.Msg = event
  259. select {
  260. case outputCh <- msg:
  261. case <-time.After(2 * time.Second):
  262. slog.Warn("message dropped due to slow consumer", "name", name)
  263. case <-ctx.Done():
  264. slog.Debug("subscription cancelled", "name", name)
  265. return
  266. }
  267. case <-ctx.Done():
  268. slog.Debug("subscription cancelled", "name", name)
  269. return
  270. }
  271. }
  272. })
  273. }
  274. func (app *App) InitCoderAgent(ctx context.Context) error {
  275. coderAgentCfg := app.config.Agents[config.AgentCoder]
  276. if coderAgentCfg.ID == "" {
  277. return fmt.Errorf("coder agent configuration is missing")
  278. }
  279. var err error
  280. app.AgentCoordinator, err = agent.NewCoordinator(
  281. ctx,
  282. app.config,
  283. app.Sessions,
  284. app.Messages,
  285. app.Permissions,
  286. app.History,
  287. app.LSPClients,
  288. )
  289. if err != nil {
  290. slog.Error("Failed to create coder agent", "err", err)
  291. return err
  292. }
  293. return nil
  294. }
  295. // Subscribe sends events to the TUI as tea.Msgs.
  296. func (app *App) Subscribe(program *tea.Program) {
  297. defer log.RecoverPanic("app.Subscribe", func() {
  298. slog.Info("TUI subscription panic: attempting graceful shutdown")
  299. program.Quit()
  300. })
  301. app.tuiWG.Add(1)
  302. tuiCtx, tuiCancel := context.WithCancel(app.globalCtx)
  303. app.cleanupFuncs = append(app.cleanupFuncs, func() error {
  304. slog.Debug("Cancelling TUI message handler")
  305. tuiCancel()
  306. app.tuiWG.Wait()
  307. return nil
  308. })
  309. defer app.tuiWG.Done()
  310. for {
  311. select {
  312. case <-tuiCtx.Done():
  313. slog.Debug("TUI message handler shutting down")
  314. return
  315. case msg, ok := <-app.events:
  316. if !ok {
  317. slog.Debug("TUI message channel closed")
  318. return
  319. }
  320. program.Send(msg)
  321. }
  322. }
  323. }
  324. // Shutdown performs a graceful shutdown of the application.
  325. func (app *App) Shutdown() {
  326. start := time.Now()
  327. defer func() { slog.Info("Shutdown took " + time.Since(start).String()) }()
  328. var wg sync.WaitGroup
  329. if app.AgentCoordinator != nil {
  330. wg.Go(func() {
  331. app.AgentCoordinator.CancelAll()
  332. })
  333. }
  334. // Kill all background shells.
  335. wg.Go(func() {
  336. shell.GetBackgroundShellManager().KillAll()
  337. })
  338. // Shutdown all LSP clients.
  339. for name, client := range app.LSPClients.Seq2() {
  340. wg.Go(func() {
  341. shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
  342. defer cancel()
  343. if err := client.Close(shutdownCtx); err != nil {
  344. slog.Error("Failed to shutdown LSP client", "name", name, "error", err)
  345. }
  346. })
  347. }
  348. // Call call cleanup functions.
  349. for _, cleanup := range app.cleanupFuncs {
  350. if cleanup != nil {
  351. wg.Go(func() {
  352. if err := cleanup(); err != nil {
  353. slog.Error("Failed to cleanup app properly on shutdown", "error", err)
  354. }
  355. })
  356. }
  357. }
  358. wg.Wait()
  359. }
  360. // checkForUpdates checks for available updates.
  361. func (app *App) checkForUpdates(ctx context.Context) {
  362. checkCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
  363. defer cancel()
  364. info, err := update.Check(checkCtx, version.Version, update.Default)
  365. if err != nil || !info.Available() {
  366. return
  367. }
  368. app.events <- pubsub.UpdateAvailableMsg{
  369. CurrentVersion: info.Current,
  370. LatestVersion: info.Latest,
  371. IsDevelopment: info.IsDevelopment(),
  372. }
  373. }