app.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. package app
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "log/slog"
  8. "maps"
  9. "sync"
  10. "time"
  11. tea "github.com/charmbracelet/bubbletea/v2"
  12. "github.com/charmbracelet/crush/internal/config"
  13. "github.com/charmbracelet/crush/internal/csync"
  14. "github.com/charmbracelet/crush/internal/db"
  15. "github.com/charmbracelet/crush/internal/format"
  16. "github.com/charmbracelet/crush/internal/history"
  17. "github.com/charmbracelet/crush/internal/llm/agent"
  18. "github.com/charmbracelet/crush/internal/log"
  19. "github.com/charmbracelet/crush/internal/pubsub"
  20. "github.com/charmbracelet/crush/internal/lsp"
  21. "github.com/charmbracelet/crush/internal/message"
  22. "github.com/charmbracelet/crush/internal/permission"
  23. "github.com/charmbracelet/crush/internal/session"
  24. )
  25. type App struct {
  26. Sessions session.Service
  27. Messages message.Service
  28. History history.Service
  29. Permissions permission.Service
  30. CoderAgent agent.Service
  31. LSPClients map[string]*lsp.Client
  32. clientsMutex sync.RWMutex
  33. watcherCancelFuncs *csync.Slice[context.CancelFunc]
  34. lspWatcherWG sync.WaitGroup
  35. config *config.Config
  36. serviceEventsWG *sync.WaitGroup
  37. eventsCtx context.Context
  38. events chan tea.Msg
  39. tuiWG *sync.WaitGroup
  40. // global context and cleanup functions
  41. globalCtx context.Context
  42. cleanupFuncs []func()
  43. }
  44. // New initializes a new applcation instance.
  45. func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
  46. q := db.New(conn)
  47. sessions := session.NewService(q)
  48. messages := message.NewService(q)
  49. files := history.NewService(q, conn)
  50. skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests
  51. allowedTools := []string{}
  52. if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil {
  53. allowedTools = cfg.Permissions.AllowedTools
  54. }
  55. app := &App{
  56. Sessions: sessions,
  57. Messages: messages,
  58. History: files,
  59. Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
  60. LSPClients: make(map[string]*lsp.Client),
  61. globalCtx: ctx,
  62. config: cfg,
  63. watcherCancelFuncs: csync.NewSlice[context.CancelFunc](),
  64. events: make(chan tea.Msg, 100),
  65. serviceEventsWG: &sync.WaitGroup{},
  66. tuiWG: &sync.WaitGroup{},
  67. }
  68. app.setupEvents()
  69. // Initialize LSP clients in the background.
  70. app.initLSPClients(ctx)
  71. // TODO: remove the concept of agent config, most likely.
  72. if cfg.IsConfigured() {
  73. if err := app.InitCoderAgent(); err != nil {
  74. return nil, fmt.Errorf("failed to initialize coder agent: %w", err)
  75. }
  76. } else {
  77. slog.Warn("No agent configuration found")
  78. }
  79. return app, nil
  80. }
  81. // Config returns the application configuration.
  82. func (app *App) Config() *config.Config {
  83. return app.config
  84. }
  85. // RunNonInteractive handles the execution flow when a prompt is provided via
  86. // CLI flag.
  87. func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool) error {
  88. slog.Info("Running in non-interactive mode")
  89. ctx, cancel := context.WithCancel(ctx)
  90. defer cancel()
  91. // Start spinner if not in quiet mode.
  92. var spinner *format.Spinner
  93. if !quiet {
  94. spinner = format.NewSpinner(ctx, cancel, "Generating")
  95. spinner.Start()
  96. }
  97. // Helper function to stop spinner once.
  98. stopSpinner := func() {
  99. if !quiet && spinner != nil {
  100. spinner.Stop()
  101. spinner = nil
  102. }
  103. }
  104. defer stopSpinner()
  105. const maxPromptLengthForTitle = 100
  106. titlePrefix := "Non-interactive: "
  107. var titleSuffix string
  108. if len(prompt) > maxPromptLengthForTitle {
  109. titleSuffix = prompt[:maxPromptLengthForTitle] + "..."
  110. } else {
  111. titleSuffix = prompt
  112. }
  113. title := titlePrefix + titleSuffix
  114. sess, err := app.Sessions.Create(ctx, title)
  115. if err != nil {
  116. return fmt.Errorf("failed to create session for non-interactive mode: %w", err)
  117. }
  118. slog.Info("Created session for non-interactive run", "session_id", sess.ID)
  119. // Automatically approve all permission requests for this non-interactive session
  120. app.Permissions.AutoApproveSession(sess.ID)
  121. done, err := app.CoderAgent.Run(ctx, sess.ID, prompt)
  122. if err != nil {
  123. return fmt.Errorf("failed to start agent processing stream: %w", err)
  124. }
  125. messageEvents := app.Messages.Subscribe(ctx)
  126. readBts := 0
  127. for {
  128. select {
  129. case result := <-done:
  130. stopSpinner()
  131. if result.Error != nil {
  132. if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) {
  133. slog.Info("Non-interactive: agent processing cancelled", "session_id", sess.ID)
  134. return nil
  135. }
  136. return fmt.Errorf("agent processing failed: %w", result.Error)
  137. }
  138. msgContent := result.Message.Content().String()
  139. if len(msgContent) < readBts {
  140. slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(msgContent), "read_bytes", readBts)
  141. return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(msgContent), readBts)
  142. }
  143. fmt.Println(msgContent[readBts:])
  144. slog.Info("Non-interactive: run completed", "session_id", sess.ID)
  145. return nil
  146. case event := <-messageEvents:
  147. msg := event.Payload
  148. if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 {
  149. stopSpinner()
  150. part := msg.Content().String()[readBts:]
  151. fmt.Print(part)
  152. readBts += len(part)
  153. }
  154. case <-ctx.Done():
  155. stopSpinner()
  156. return ctx.Err()
  157. }
  158. }
  159. }
  160. func (app *App) UpdateAgentModel() error {
  161. return app.CoderAgent.UpdateModel()
  162. }
  163. func (app *App) setupEvents() {
  164. ctx, cancel := context.WithCancel(app.globalCtx)
  165. app.eventsCtx = ctx
  166. setupSubscriber(ctx, app.serviceEventsWG, "sessions", app.Sessions.Subscribe, app.events)
  167. setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events)
  168. setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
  169. setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events)
  170. setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
  171. setupSubscriber(ctx, app.serviceEventsWG, "mcp", agent.SubscribeMCPEvents, app.events)
  172. setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events)
  173. cleanupFunc := func() {
  174. cancel()
  175. app.serviceEventsWG.Wait()
  176. }
  177. app.cleanupFuncs = append(app.cleanupFuncs, cleanupFunc)
  178. }
  179. func setupSubscriber[T any](
  180. ctx context.Context,
  181. wg *sync.WaitGroup,
  182. name string,
  183. subscriber func(context.Context) <-chan pubsub.Event[T],
  184. outputCh chan<- tea.Msg,
  185. ) {
  186. wg.Add(1)
  187. go func() {
  188. defer wg.Done()
  189. subCh := subscriber(ctx)
  190. for {
  191. select {
  192. case event, ok := <-subCh:
  193. if !ok {
  194. slog.Debug("subscription channel closed", "name", name)
  195. return
  196. }
  197. var msg tea.Msg = event
  198. select {
  199. case outputCh <- msg:
  200. case <-time.After(2 * time.Second):
  201. slog.Warn("message dropped due to slow consumer", "name", name)
  202. case <-ctx.Done():
  203. slog.Debug("subscription cancelled", "name", name)
  204. return
  205. }
  206. case <-ctx.Done():
  207. slog.Debug("subscription cancelled", "name", name)
  208. return
  209. }
  210. }
  211. }()
  212. }
  213. func (app *App) InitCoderAgent() error {
  214. coderAgentCfg := app.config.Agents["coder"]
  215. if coderAgentCfg.ID == "" {
  216. return fmt.Errorf("coder agent configuration is missing")
  217. }
  218. var err error
  219. app.CoderAgent, err = agent.NewAgent(
  220. app.globalCtx,
  221. coderAgentCfg,
  222. app.Permissions,
  223. app.Sessions,
  224. app.Messages,
  225. app.History,
  226. app.LSPClients,
  227. )
  228. if err != nil {
  229. slog.Error("Failed to create coder agent", "err", err)
  230. return err
  231. }
  232. // Add MCP client cleanup to shutdown process
  233. app.cleanupFuncs = append(app.cleanupFuncs, agent.CloseMCPClients)
  234. setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events)
  235. return nil
  236. }
  237. // Subscribe sends events to the TUI as tea.Msgs.
  238. func (app *App) Subscribe(program *tea.Program) {
  239. defer log.RecoverPanic("app.Subscribe", func() {
  240. slog.Info("TUI subscription panic: attempting graceful shutdown")
  241. program.Quit()
  242. })
  243. app.tuiWG.Add(1)
  244. tuiCtx, tuiCancel := context.WithCancel(app.globalCtx)
  245. app.cleanupFuncs = append(app.cleanupFuncs, func() {
  246. slog.Debug("Cancelling TUI message handler")
  247. tuiCancel()
  248. app.tuiWG.Wait()
  249. })
  250. defer app.tuiWG.Done()
  251. for {
  252. select {
  253. case <-tuiCtx.Done():
  254. slog.Debug("TUI message handler shutting down")
  255. return
  256. case msg, ok := <-app.events:
  257. if !ok {
  258. slog.Debug("TUI message channel closed")
  259. return
  260. }
  261. program.Send(msg)
  262. }
  263. }
  264. }
  265. // Shutdown performs a graceful shutdown of the application.
  266. func (app *App) Shutdown() {
  267. if app.CoderAgent != nil {
  268. app.CoderAgent.CancelAll()
  269. }
  270. for cancel := range app.watcherCancelFuncs.Seq() {
  271. cancel()
  272. }
  273. // Wait for all LSP watchers to finish.
  274. app.lspWatcherWG.Wait()
  275. // Get all LSP clients.
  276. app.clientsMutex.RLock()
  277. clients := make(map[string]*lsp.Client, len(app.LSPClients))
  278. maps.Copy(clients, app.LSPClients)
  279. app.clientsMutex.RUnlock()
  280. // Shutdown all LSP clients.
  281. for name, client := range clients {
  282. shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
  283. if err := client.Shutdown(shutdownCtx); err != nil {
  284. slog.Error("Failed to shutdown LSP client", "name", name, "error", err)
  285. }
  286. cancel()
  287. }
  288. // Call call cleanup functions.
  289. for _, cleanup := range app.cleanupFuncs {
  290. if cleanup != nil {
  291. cleanup()
  292. }
  293. }
  294. }