app.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. package app
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "fmt"
  7. "log/slog"
  8. "maps"
  9. "sync"
  10. "time"
  11. tea "github.com/charmbracelet/bubbletea/v2"
  12. "github.com/charmbracelet/crush/internal/config"
  13. "github.com/charmbracelet/crush/internal/db"
  14. "github.com/charmbracelet/crush/internal/format"
  15. "github.com/charmbracelet/crush/internal/history"
  16. "github.com/charmbracelet/crush/internal/llm/agent"
  17. "github.com/charmbracelet/crush/internal/log"
  18. "github.com/charmbracelet/crush/internal/pubsub"
  19. "github.com/charmbracelet/crush/internal/lsp"
  20. "github.com/charmbracelet/crush/internal/message"
  21. "github.com/charmbracelet/crush/internal/permission"
  22. "github.com/charmbracelet/crush/internal/session"
  23. )
  24. type App struct {
  25. Sessions session.Service
  26. Messages message.Service
  27. History history.Service
  28. Permissions permission.Service
  29. CoderAgent agent.Service
  30. LSPClients map[string]*lsp.Client
  31. clientsMutex sync.RWMutex
  32. watcherCancelFuncs []context.CancelFunc
  33. cancelFuncsMutex sync.Mutex
  34. lspWatcherWG sync.WaitGroup
  35. config *config.Config
  36. serviceEventsWG *sync.WaitGroup
  37. eventsCtx context.Context
  38. events chan tea.Msg
  39. tuiWG *sync.WaitGroup
  40. // global context and cleanup functions
  41. globalCtx context.Context
  42. cleanupFuncs []func()
  43. }
  44. // New initializes a new applcation instance.
  45. func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
  46. q := db.New(conn)
  47. sessions := session.NewService(q)
  48. messages := message.NewService(q)
  49. files := history.NewService(q, conn)
  50. skipPermissionsRequests := cfg.Permissions != nil && cfg.Permissions.SkipRequests
  51. allowedTools := []string{}
  52. if cfg.Permissions != nil && cfg.Permissions.AllowedTools != nil {
  53. allowedTools = cfg.Permissions.AllowedTools
  54. }
  55. app := &App{
  56. Sessions: sessions,
  57. Messages: messages,
  58. History: files,
  59. Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
  60. LSPClients: make(map[string]*lsp.Client),
  61. globalCtx: ctx,
  62. config: cfg,
  63. events: make(chan tea.Msg, 100),
  64. serviceEventsWG: &sync.WaitGroup{},
  65. tuiWG: &sync.WaitGroup{},
  66. }
  67. app.setupEvents()
  68. // Initialize LSP clients in the background.
  69. app.initLSPClients(ctx)
  70. // TODO: remove the concept of agent config, most likely.
  71. if cfg.IsConfigured() {
  72. if err := app.InitCoderAgent(); err != nil {
  73. return nil, fmt.Errorf("failed to initialize coder agent: %w", err)
  74. }
  75. } else {
  76. slog.Warn("No agent configuration found")
  77. }
  78. return app, nil
  79. }
  80. // Config returns the application configuration.
  81. func (app *App) Config() *config.Config {
  82. return app.config
  83. }
  84. // RunNonInteractive handles the execution flow when a prompt is provided via
  85. // CLI flag.
  86. func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool) error {
  87. slog.Info("Running in non-interactive mode")
  88. ctx, cancel := context.WithCancel(ctx)
  89. defer cancel()
  90. // Start spinner if not in quiet mode.
  91. var spinner *format.Spinner
  92. if !quiet {
  93. spinner = format.NewSpinner(ctx, cancel, "Generating")
  94. spinner.Start()
  95. }
  96. // Helper function to stop spinner once.
  97. stopSpinner := func() {
  98. if !quiet && spinner != nil {
  99. spinner.Stop()
  100. spinner = nil
  101. }
  102. }
  103. defer stopSpinner()
  104. const maxPromptLengthForTitle = 100
  105. titlePrefix := "Non-interactive: "
  106. var titleSuffix string
  107. if len(prompt) > maxPromptLengthForTitle {
  108. titleSuffix = prompt[:maxPromptLengthForTitle] + "..."
  109. } else {
  110. titleSuffix = prompt
  111. }
  112. title := titlePrefix + titleSuffix
  113. sess, err := app.Sessions.Create(ctx, title)
  114. if err != nil {
  115. return fmt.Errorf("failed to create session for non-interactive mode: %w", err)
  116. }
  117. slog.Info("Created session for non-interactive run", "session_id", sess.ID)
  118. // Automatically approve all permission requests for this non-interactive session
  119. app.Permissions.AutoApproveSession(sess.ID)
  120. done, err := app.CoderAgent.Run(ctx, sess.ID, prompt)
  121. if err != nil {
  122. return fmt.Errorf("failed to start agent processing stream: %w", err)
  123. }
  124. messageEvents := app.Messages.Subscribe(ctx)
  125. readBts := 0
  126. for {
  127. select {
  128. case result := <-done:
  129. stopSpinner()
  130. if result.Error != nil {
  131. if errors.Is(result.Error, context.Canceled) || errors.Is(result.Error, agent.ErrRequestCancelled) {
  132. slog.Info("Non-interactive: agent processing cancelled", "session_id", sess.ID)
  133. return nil
  134. }
  135. return fmt.Errorf("agent processing failed: %w", result.Error)
  136. }
  137. msgContent := result.Message.Content().String()
  138. if len(msgContent) < readBts {
  139. slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(msgContent), "read_bytes", readBts)
  140. return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(msgContent), readBts)
  141. }
  142. fmt.Println(msgContent[readBts:])
  143. slog.Info("Non-interactive: run completed", "session_id", sess.ID)
  144. return nil
  145. case event := <-messageEvents:
  146. msg := event.Payload
  147. if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 {
  148. stopSpinner()
  149. part := msg.Content().String()[readBts:]
  150. fmt.Print(part)
  151. readBts += len(part)
  152. }
  153. case <-ctx.Done():
  154. stopSpinner()
  155. return ctx.Err()
  156. }
  157. }
  158. }
  159. func (app *App) UpdateAgentModel() error {
  160. return app.CoderAgent.UpdateModel()
  161. }
  162. func (app *App) setupEvents() {
  163. ctx, cancel := context.WithCancel(app.globalCtx)
  164. app.eventsCtx = ctx
  165. setupSubscriber(ctx, app.serviceEventsWG, "sessions", app.Sessions.Subscribe, app.events)
  166. setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events)
  167. setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
  168. setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
  169. cleanupFunc := func() {
  170. cancel()
  171. app.serviceEventsWG.Wait()
  172. }
  173. app.cleanupFuncs = append(app.cleanupFuncs, cleanupFunc)
  174. }
  175. func setupSubscriber[T any](
  176. ctx context.Context,
  177. wg *sync.WaitGroup,
  178. name string,
  179. subscriber func(context.Context) <-chan pubsub.Event[T],
  180. outputCh chan<- tea.Msg,
  181. ) {
  182. wg.Add(1)
  183. go func() {
  184. defer wg.Done()
  185. subCh := subscriber(ctx)
  186. for {
  187. select {
  188. case event, ok := <-subCh:
  189. if !ok {
  190. slog.Debug("subscription channel closed", "name", name)
  191. return
  192. }
  193. var msg tea.Msg = event
  194. select {
  195. case outputCh <- msg:
  196. case <-time.After(2 * time.Second):
  197. slog.Warn("message dropped due to slow consumer", "name", name)
  198. case <-ctx.Done():
  199. slog.Debug("subscription cancelled", "name", name)
  200. return
  201. }
  202. case <-ctx.Done():
  203. slog.Debug("subscription cancelled", "name", name)
  204. return
  205. }
  206. }
  207. }()
  208. }
  209. func (app *App) InitCoderAgent() error {
  210. coderAgentCfg := app.config.Agents["coder"]
  211. if coderAgentCfg.ID == "" {
  212. return fmt.Errorf("coder agent configuration is missing")
  213. }
  214. var err error
  215. app.CoderAgent, err = agent.NewAgent(
  216. coderAgentCfg,
  217. app.Permissions,
  218. app.Sessions,
  219. app.Messages,
  220. app.History,
  221. app.LSPClients,
  222. )
  223. if err != nil {
  224. slog.Error("Failed to create coder agent", "err", err)
  225. return err
  226. }
  227. setupSubscriber(app.eventsCtx, app.serviceEventsWG, "coderAgent", app.CoderAgent.Subscribe, app.events)
  228. return nil
  229. }
  230. // Subscribe sends events to the TUI as tea.Msgs.
  231. func (app *App) Subscribe(program *tea.Program) {
  232. defer log.RecoverPanic("app.Subscribe", func() {
  233. slog.Info("TUI subscription panic: attempting graceful shutdown")
  234. program.Quit()
  235. })
  236. app.tuiWG.Add(1)
  237. tuiCtx, tuiCancel := context.WithCancel(app.globalCtx)
  238. app.cleanupFuncs = append(app.cleanupFuncs, func() {
  239. slog.Debug("Cancelling TUI message handler")
  240. tuiCancel()
  241. app.tuiWG.Wait()
  242. })
  243. defer app.tuiWG.Done()
  244. for {
  245. select {
  246. case <-tuiCtx.Done():
  247. slog.Debug("TUI message handler shutting down")
  248. return
  249. case msg, ok := <-app.events:
  250. if !ok {
  251. slog.Debug("TUI message channel closed")
  252. return
  253. }
  254. program.Send(msg)
  255. }
  256. }
  257. }
  258. // Shutdown performs a graceful shutdown of the application.
  259. func (app *App) Shutdown() {
  260. if app.CoderAgent != nil {
  261. app.CoderAgent.CancelAll()
  262. }
  263. app.cancelFuncsMutex.Lock()
  264. for _, cancel := range app.watcherCancelFuncs {
  265. cancel()
  266. }
  267. app.cancelFuncsMutex.Unlock()
  268. // Wait for all LSP watchers to finish.
  269. app.lspWatcherWG.Wait()
  270. // Get all LSP clients.
  271. app.clientsMutex.RLock()
  272. clients := make(map[string]*lsp.Client, len(app.LSPClients))
  273. maps.Copy(clients, app.LSPClients)
  274. app.clientsMutex.RUnlock()
  275. // Shutdown all LSP clients.
  276. for name, client := range clients {
  277. shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
  278. if err := client.Shutdown(shutdownCtx); err != nil {
  279. slog.Error("Failed to shutdown LSP client", "name", name, "error", err)
  280. }
  281. cancel()
  282. }
  283. // Call call cleanup functions.
  284. for _, cleanup := range app.cleanupFuncs {
  285. if cleanup != nil {
  286. cleanup()
  287. }
  288. }
  289. }