// agent.go (3.5 KB)
  1. package backend
  2. import (
  3. "context"
  4. "github.com/charmbracelet/crush/internal/config"
  5. "github.com/charmbracelet/crush/internal/proto"
  6. )
  // SendMessage runs msg.Prompt through the agent coordinator of the
// workspace identified by workspaceID, within the session msg.SessionID.
//
// The coordinator's run result is discarded; only its error is returned.
// A failed workspace lookup is returned unchanged, and
// ErrAgentNotInitialized is returned when the workspace has no agent
// coordinator.
func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto.AgentMessage) error {
	workspace, lookupErr := b.GetWorkspace(workspaceID)
	if lookupErr != nil {
		return lookupErr
	}

	coordinator := workspace.AgentCoordinator
	if coordinator == nil {
		return ErrAgentNotInitialized
	}

	// Callers only care whether the run failed, so the result value is dropped.
	_, runErr := coordinator.Run(ctx, msg.SessionID, msg.Prompt)
	return runErr
}
  20. // GetAgentInfo returns the agent's model and busy status.
  21. func (b *Backend) GetAgentInfo(workspaceID string) (proto.AgentInfo, error) {
  22. ws, err := b.GetWorkspace(workspaceID)
  23. if err != nil {
  24. return proto.AgentInfo{}, err
  25. }
  26. var agentInfo proto.AgentInfo
  27. if ws.AgentCoordinator != nil {
  28. m := ws.AgentCoordinator.Model()
  29. agentInfo = proto.AgentInfo{
  30. Model: m.CatwalkCfg,
  31. ModelCfg: m.ModelCfg,
  32. IsBusy: ws.AgentCoordinator.IsBusy(),
  33. IsReady: true,
  34. }
  35. }
  36. return agentInfo, nil
  37. }
  // InitAgent initializes the coder agent of the workspace identified by
// workspaceID by delegating to that workspace's InitCoderAgent. A failed
// workspace lookup is returned without attempting initialization.
func (b *Backend) InitAgent(ctx context.Context, workspaceID string) error {
	ws, err := b.GetWorkspace(workspaceID)
	if err == nil {
		err = ws.InitCoderAgent(ctx)
	}
	return err
}
  // UpdateAgent reloads the agent model configuration of the workspace
// identified by workspaceID through the workspace's UpdateAgentModel. A
// failed workspace lookup is returned without attempting the reload.
func (b *Backend) UpdateAgent(ctx context.Context, workspaceID string) error {
	workspace, err := b.GetWorkspace(workspaceID)
	if err == nil {
		err = workspace.UpdateAgentModel(ctx)
	}
	return err
}
  54. // CancelSession cancels an ongoing agent operation for the given
  55. // session.
  56. func (b *Backend) CancelSession(workspaceID, sessionID string) error {
  57. ws, err := b.GetWorkspace(workspaceID)
  58. if err != nil {
  59. return err
  60. }
  61. if ws.AgentCoordinator != nil {
  62. ws.AgentCoordinator.Cancel(sessionID)
  63. }
  64. return nil
  65. }
  // SummarizeSession asks the agent coordinator of the workspace identified
// by workspaceID to summarize the session sessionID.
//
// It returns the workspace lookup error if any, ErrAgentNotInitialized when
// the workspace has no agent coordinator, and otherwise whatever the
// coordinator's Summarize returns.
func (b *Backend) SummarizeSession(ctx context.Context, workspaceID, sessionID string) error {
	ws, err := b.GetWorkspace(workspaceID)
	switch {
	case err != nil:
		return err
	case ws.AgentCoordinator == nil:
		return ErrAgentNotInitialized
	default:
		return ws.AgentCoordinator.Summarize(ctx, sessionID)
	}
}
  77. // QueuedPrompts returns the number of queued prompts for the session.
  78. func (b *Backend) QueuedPrompts(workspaceID, sessionID string) (int, error) {
  79. ws, err := b.GetWorkspace(workspaceID)
  80. if err != nil {
  81. return 0, err
  82. }
  83. if ws.AgentCoordinator == nil {
  84. return 0, nil
  85. }
  86. return ws.AgentCoordinator.QueuedPrompts(sessionID), nil
  87. }
  88. // ClearQueue clears the prompt queue for the session.
  89. func (b *Backend) ClearQueue(workspaceID, sessionID string) error {
  90. ws, err := b.GetWorkspace(workspaceID)
  91. if err != nil {
  92. return err
  93. }
  94. if ws.AgentCoordinator != nil {
  95. ws.AgentCoordinator.ClearQueue(sessionID)
  96. }
  97. return nil
  98. }
  99. // QueuedPromptsList returns the list of queued prompt strings for a
  100. // session.
  101. func (b *Backend) QueuedPromptsList(workspaceID, sessionID string) ([]string, error) {
  102. ws, err := b.GetWorkspace(workspaceID)
  103. if err != nil {
  104. return nil, err
  105. }
  106. if ws.AgentCoordinator == nil {
  107. return nil, nil
  108. }
  109. return ws.AgentCoordinator.QueuedPromptsList(sessionID), nil
  110. }
  111. // GetDefaultSmallModel returns the default small model for a provider.
  112. func (b *Backend) GetDefaultSmallModel(workspaceID, providerID string) (config.SelectedModel, error) {
  113. ws, err := b.GetWorkspace(workspaceID)
  114. if err != nil {
  115. return config.SelectedModel{}, err
  116. }
  117. return ws.GetDefaultSmallModel(providerID), nil
  118. }