coordinator_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. package agent
  2. import (
  3. "context"
  4. "errors"
  5. "testing"
  6. "charm.land/catwalk/pkg/catwalk"
  7. "charm.land/fantasy"
  8. "github.com/charmbracelet/crush/internal/config"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. )
  12. // mockSessionAgent is a minimal mock for the SessionAgent interface.
  13. type mockSessionAgent struct {
  14. model Model
  15. runFunc func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error)
  16. cancelled []string
  17. }
  18. func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  19. return m.runFunc(ctx, call)
  20. }
  21. func (m *mockSessionAgent) Model() Model { return m.model }
  22. func (m *mockSessionAgent) SetModels(large, small Model) {}
  23. func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool) {}
  24. func (m *mockSessionAgent) SetSystemPrompt(systemPrompt string) {}
  25. func (m *mockSessionAgent) Cancel(sessionID string) {
  26. m.cancelled = append(m.cancelled, sessionID)
  27. }
  28. func (m *mockSessionAgent) CancelAll() {}
  29. func (m *mockSessionAgent) IsSessionBusy(sessionID string) bool { return false }
  30. func (m *mockSessionAgent) IsBusy() bool { return false }
  31. func (m *mockSessionAgent) QueuedPrompts(sessionID string) int { return 0 }
  32. func (m *mockSessionAgent) QueuedPromptsList(sessionID string) []string { return nil }
  33. func (m *mockSessionAgent) ClearQueue(sessionID string) {}
  34. func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOptions) error {
  35. return nil
  36. }
  37. // newTestCoordinator creates a minimal coordinator for unit testing runSubAgent.
  38. func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator {
  39. cfg, err := config.Init(env.workingDir, "", false)
  40. require.NoError(t, err)
  41. cfg.Config().Providers.Set(providerID, providerCfg)
  42. return &coordinator{
  43. cfg: cfg,
  44. sessions: env.sessions,
  45. }
  46. }
  47. // newMockAgent creates a mockSessionAgent with the given provider and run function.
  48. func newMockAgent(providerID string, maxTokens int64, runFunc func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)) *mockSessionAgent {
  49. return &mockSessionAgent{
  50. model: Model{
  51. CatwalkCfg: catwalk.Model{
  52. DefaultMaxTokens: maxTokens,
  53. },
  54. ModelCfg: config.SelectedModel{
  55. Provider: providerID,
  56. },
  57. },
  58. runFunc: runFunc,
  59. }
  60. }
  61. // agentResultWithText creates a minimal AgentResult with the given text response.
  62. func agentResultWithText(text string) *fantasy.AgentResult {
  63. return &fantasy.AgentResult{
  64. Response: fantasy.Response{
  65. Content: fantasy.ResponseContent{
  66. fantasy.TextContent{Text: text},
  67. },
  68. },
  69. }
  70. }
  71. func TestRunSubAgent(t *testing.T) {
  72. const providerID = "test-provider"
  73. providerCfg := config.ProviderConfig{ID: providerID}
  74. t.Run("happy path", func(t *testing.T) {
  75. env := testEnv(t)
  76. coord := newTestCoordinator(t, env, providerID, providerCfg)
  77. parentSession, err := env.sessions.Create(t.Context(), "Parent")
  78. require.NoError(t, err)
  79. agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  80. assert.Equal(t, "do something", call.Prompt)
  81. assert.Equal(t, int64(4096), call.MaxOutputTokens)
  82. return agentResultWithText("done"), nil
  83. })
  84. resp, err := coord.runSubAgent(t.Context(), subAgentParams{
  85. Agent: agent,
  86. SessionID: parentSession.ID,
  87. AgentMessageID: "msg-1",
  88. ToolCallID: "call-1",
  89. Prompt: "do something",
  90. SessionTitle: "Test Session",
  91. })
  92. require.NoError(t, err)
  93. assert.Equal(t, "done", resp.Content)
  94. assert.False(t, resp.IsError)
  95. })
  96. t.Run("ModelCfg.MaxTokens overrides default", func(t *testing.T) {
  97. env := testEnv(t)
  98. coord := newTestCoordinator(t, env, providerID, providerCfg)
  99. parentSession, err := env.sessions.Create(t.Context(), "Parent")
  100. require.NoError(t, err)
  101. agent := &mockSessionAgent{
  102. model: Model{
  103. CatwalkCfg: catwalk.Model{
  104. DefaultMaxTokens: 4096,
  105. },
  106. ModelCfg: config.SelectedModel{
  107. Provider: providerID,
  108. MaxTokens: 8192,
  109. },
  110. },
  111. runFunc: func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  112. assert.Equal(t, int64(8192), call.MaxOutputTokens)
  113. return agentResultWithText("ok"), nil
  114. },
  115. }
  116. resp, err := coord.runSubAgent(t.Context(), subAgentParams{
  117. Agent: agent,
  118. SessionID: parentSession.ID,
  119. AgentMessageID: "msg-1",
  120. ToolCallID: "call-1",
  121. Prompt: "test",
  122. SessionTitle: "Test",
  123. })
  124. require.NoError(t, err)
  125. assert.Equal(t, "ok", resp.Content)
  126. })
  127. t.Run("session creation failure with canceled context", func(t *testing.T) {
  128. env := testEnv(t)
  129. coord := newTestCoordinator(t, env, providerID, providerCfg)
  130. parentSession, err := env.sessions.Create(t.Context(), "Parent")
  131. require.NoError(t, err)
  132. agent := newMockAgent(providerID, 4096, nil)
  133. // Use a canceled context to trigger CreateTaskSession failure.
  134. ctx, cancel := context.WithCancel(t.Context())
  135. cancel()
  136. _, err = coord.runSubAgent(ctx, subAgentParams{
  137. Agent: agent,
  138. SessionID: parentSession.ID,
  139. AgentMessageID: "msg-1",
  140. ToolCallID: "call-1",
  141. Prompt: "test",
  142. SessionTitle: "Test",
  143. })
  144. require.Error(t, err)
  145. })
  146. t.Run("provider not configured", func(t *testing.T) {
  147. env := testEnv(t)
  148. coord := newTestCoordinator(t, env, providerID, providerCfg)
  149. parentSession, err := env.sessions.Create(t.Context(), "Parent")
  150. require.NoError(t, err)
  151. // Agent references a provider that doesn't exist in config.
  152. agent := newMockAgent("unknown-provider", 4096, nil)
  153. _, err = coord.runSubAgent(t.Context(), subAgentParams{
  154. Agent: agent,
  155. SessionID: parentSession.ID,
  156. AgentMessageID: "msg-1",
  157. ToolCallID: "call-1",
  158. Prompt: "test",
  159. SessionTitle: "Test",
  160. })
  161. require.Error(t, err)
  162. assert.Contains(t, err.Error(), "model provider not configured")
  163. })
  164. t.Run("agent run error returns error response", func(t *testing.T) {
  165. env := testEnv(t)
  166. coord := newTestCoordinator(t, env, providerID, providerCfg)
  167. parentSession, err := env.sessions.Create(t.Context(), "Parent")
  168. require.NoError(t, err)
  169. agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
  170. return nil, errors.New("agent exploded")
  171. })
  172. resp, err := coord.runSubAgent(t.Context(), subAgentParams{
  173. Agent: agent,
  174. SessionID: parentSession.ID,
  175. AgentMessageID: "msg-1",
  176. ToolCallID: "call-1",
  177. Prompt: "test",
  178. SessionTitle: "Test",
  179. })
  180. // runSubAgent returns (errorResponse, nil) when agent.Run fails — not a Go error.
  181. require.NoError(t, err)
  182. assert.True(t, resp.IsError)
  183. assert.Equal(t, "error generating response", resp.Content)
  184. })
  185. t.Run("session setup callback is invoked", func(t *testing.T) {
  186. env := testEnv(t)
  187. coord := newTestCoordinator(t, env, providerID, providerCfg)
  188. parentSession, err := env.sessions.Create(t.Context(), "Parent")
  189. require.NoError(t, err)
  190. var setupCalledWith string
  191. agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
  192. return agentResultWithText("ok"), nil
  193. })
  194. _, err = coord.runSubAgent(t.Context(), subAgentParams{
  195. Agent: agent,
  196. SessionID: parentSession.ID,
  197. AgentMessageID: "msg-1",
  198. ToolCallID: "call-1",
  199. Prompt: "test",
  200. SessionTitle: "Test",
  201. SessionSetup: func(sessionID string) {
  202. setupCalledWith = sessionID
  203. },
  204. })
  205. require.NoError(t, err)
  206. assert.NotEmpty(t, setupCalledWith, "SessionSetup should have been called")
  207. })
  208. t.Run("cost propagation to parent session", func(t *testing.T) {
  209. env := testEnv(t)
  210. coord := newTestCoordinator(t, env, providerID, providerCfg)
  211. parentSession, err := env.sessions.Create(t.Context(), "Parent")
  212. require.NoError(t, err)
  213. agent := newMockAgent(providerID, 4096, func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
  214. // Simulate the agent incurring cost by updating the child session.
  215. childSession, err := env.sessions.Get(ctx, call.SessionID)
  216. if err != nil {
  217. return nil, err
  218. }
  219. childSession.Cost = 0.05
  220. _, err = env.sessions.Save(ctx, childSession)
  221. if err != nil {
  222. return nil, err
  223. }
  224. return agentResultWithText("ok"), nil
  225. })
  226. _, err = coord.runSubAgent(t.Context(), subAgentParams{
  227. Agent: agent,
  228. SessionID: parentSession.ID,
  229. AgentMessageID: "msg-1",
  230. ToolCallID: "call-1",
  231. Prompt: "test",
  232. SessionTitle: "Test",
  233. })
  234. require.NoError(t, err)
  235. updated, err := env.sessions.Get(t.Context(), parentSession.ID)
  236. require.NoError(t, err)
  237. assert.InDelta(t, 0.05, updated.Cost, 1e-9)
  238. })
  239. }
  240. func TestUpdateParentSessionCost(t *testing.T) {
  241. t.Run("accumulates cost correctly", func(t *testing.T) {
  242. env := testEnv(t)
  243. cfg, err := config.Init(env.workingDir, "", false)
  244. require.NoError(t, err)
  245. coord := &coordinator{cfg: cfg, sessions: env.sessions}
  246. parent, err := env.sessions.Create(t.Context(), "Parent")
  247. require.NoError(t, err)
  248. child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
  249. require.NoError(t, err)
  250. // Set child cost.
  251. child.Cost = 0.10
  252. _, err = env.sessions.Save(t.Context(), child)
  253. require.NoError(t, err)
  254. err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
  255. require.NoError(t, err)
  256. updated, err := env.sessions.Get(t.Context(), parent.ID)
  257. require.NoError(t, err)
  258. assert.InDelta(t, 0.10, updated.Cost, 1e-9)
  259. })
  260. t.Run("accumulates multiple child costs", func(t *testing.T) {
  261. env := testEnv(t)
  262. cfg, err := config.Init(env.workingDir, "", false)
  263. require.NoError(t, err)
  264. coord := &coordinator{cfg: cfg, sessions: env.sessions}
  265. parent, err := env.sessions.Create(t.Context(), "Parent")
  266. require.NoError(t, err)
  267. child1, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child1")
  268. require.NoError(t, err)
  269. child1.Cost = 0.05
  270. _, err = env.sessions.Save(t.Context(), child1)
  271. require.NoError(t, err)
  272. child2, err := env.sessions.CreateTaskSession(t.Context(), "tool-2", parent.ID, "Child2")
  273. require.NoError(t, err)
  274. child2.Cost = 0.03
  275. _, err = env.sessions.Save(t.Context(), child2)
  276. require.NoError(t, err)
  277. err = coord.updateParentSessionCost(t.Context(), child1.ID, parent.ID)
  278. require.NoError(t, err)
  279. err = coord.updateParentSessionCost(t.Context(), child2.ID, parent.ID)
  280. require.NoError(t, err)
  281. updated, err := env.sessions.Get(t.Context(), parent.ID)
  282. require.NoError(t, err)
  283. assert.InDelta(t, 0.08, updated.Cost, 1e-9)
  284. })
  285. t.Run("child session not found", func(t *testing.T) {
  286. env := testEnv(t)
  287. cfg, err := config.Init(env.workingDir, "", false)
  288. require.NoError(t, err)
  289. coord := &coordinator{cfg: cfg, sessions: env.sessions}
  290. parent, err := env.sessions.Create(t.Context(), "Parent")
  291. require.NoError(t, err)
  292. err = coord.updateParentSessionCost(t.Context(), "non-existent", parent.ID)
  293. require.Error(t, err)
  294. assert.Contains(t, err.Error(), "get child session")
  295. })
  296. t.Run("parent session not found", func(t *testing.T) {
  297. env := testEnv(t)
  298. cfg, err := config.Init(env.workingDir, "", false)
  299. require.NoError(t, err)
  300. coord := &coordinator{cfg: cfg, sessions: env.sessions}
  301. parent, err := env.sessions.Create(t.Context(), "Parent")
  302. require.NoError(t, err)
  303. child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
  304. require.NoError(t, err)
  305. err = coord.updateParentSessionCost(t.Context(), child.ID, "non-existent")
  306. require.Error(t, err)
  307. assert.Contains(t, err.Error(), "get parent session")
  308. })
  309. t.Run("zero cost handled correctly", func(t *testing.T) {
  310. env := testEnv(t)
  311. cfg, err := config.Init(env.workingDir, "", false)
  312. require.NoError(t, err)
  313. coord := &coordinator{cfg: cfg, sessions: env.sessions}
  314. parent, err := env.sessions.Create(t.Context(), "Parent")
  315. require.NoError(t, err)
  316. child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
  317. require.NoError(t, err)
  318. err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
  319. require.NoError(t, err)
  320. updated, err := env.sessions.Get(t.Context(), parent.ID)
  321. require.NoError(t, err)
  322. assert.InDelta(t, 0.0, updated.Cost, 1e-9)
  323. })
  324. }