agent_test.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. package agent
  2. import (
  3. "fmt"
  4. "os"
  5. "path/filepath"
  6. "runtime"
  7. "strings"
  8. "testing"
  9. "charm.land/fantasy"
  10. "charm.land/x/vcr"
  11. "github.com/charmbracelet/crush/internal/agent/tools"
  12. "github.com/charmbracelet/crush/internal/message"
  13. "github.com/charmbracelet/crush/internal/session"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/stretchr/testify/require"
  16. _ "github.com/joho/godotenv/autoload"
  17. )
  18. var modelPairs = []modelPair{
  19. {"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-6"), anthropicBuilder("claude-haiku-4-5-20251001")},
  20. {"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
  21. {"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
  22. {"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
  23. }
  24. func getModels(t *testing.T, r *vcr.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
  25. large, err := pair.largeModel(t, r)
  26. require.NoError(t, err)
  27. small, err := pair.smallModel(t, r)
  28. require.NoError(t, err)
  29. return large, small
  30. }
  31. func setupAgent(t *testing.T, pair modelPair) (SessionAgent, fakeEnv) {
  32. r := vcr.NewRecorder(t)
  33. large, small := getModels(t, r, pair)
  34. env := testEnv(t)
  35. createSimpleGoProject(t, env.workingDir)
  36. agent, err := coderAgent(r, env, large, small)
  37. require.NoError(t, err)
  38. return agent, env
  39. }
  40. func TestCoderAgent(t *testing.T) {
  41. if runtime.GOOS == "windows" {
  42. t.Skip("skipping on windows for now")
  43. }
  44. for _, pair := range modelPairs {
  45. t.Run(pair.name, func(t *testing.T) {
  46. t.Run("simple test", func(t *testing.T) {
  47. agent, env := setupAgent(t, pair)
  48. session, err := env.sessions.Create(t.Context(), "New Session")
  49. require.NoError(t, err)
  50. res, err := agent.Run(t.Context(), SessionAgentCall{
  51. Prompt: "Hello",
  52. SessionID: session.ID,
  53. MaxOutputTokens: 10000,
  54. })
  55. require.NoError(t, err)
  56. assert.NotNil(t, res)
  57. msgs, err := env.messages.List(t.Context(), session.ID)
  58. require.NoError(t, err)
  59. // Should have the agent and user message
  60. assert.Equal(t, len(msgs), 2)
  61. })
  62. t.Run("read a file", func(t *testing.T) {
  63. agent, env := setupAgent(t, pair)
  64. session, err := env.sessions.Create(t.Context(), "New Session")
  65. require.NoError(t, err)
  66. res, err := agent.Run(t.Context(), SessionAgentCall{
  67. Prompt: "Read the go mod",
  68. SessionID: session.ID,
  69. MaxOutputTokens: 10000,
  70. })
  71. require.NoError(t, err)
  72. assert.NotNil(t, res)
  73. msgs, err := env.messages.List(t.Context(), session.ID)
  74. require.NoError(t, err)
  75. foundFile := false
  76. var tcID string
  77. out:
  78. for _, msg := range msgs {
  79. if msg.Role == message.Assistant {
  80. for _, tc := range msg.ToolCalls() {
  81. if tc.Name == tools.ViewToolName {
  82. tcID = tc.ID
  83. }
  84. }
  85. }
  86. if msg.Role == message.Tool {
  87. for _, tr := range msg.ToolResults() {
  88. if tr.ToolCallID == tcID {
  89. if strings.Contains(tr.Content, "module example.com/testproject") {
  90. foundFile = true
  91. break out
  92. }
  93. }
  94. }
  95. }
  96. }
  97. require.True(t, foundFile)
  98. })
  99. t.Run("update a file", func(t *testing.T) {
  100. agent, env := setupAgent(t, pair)
  101. session, err := env.sessions.Create(t.Context(), "New Session")
  102. require.NoError(t, err)
  103. res, err := agent.Run(t.Context(), SessionAgentCall{
  104. Prompt: "update the main.go file by changing the print to say hello from crush",
  105. SessionID: session.ID,
  106. MaxOutputTokens: 10000,
  107. })
  108. require.NoError(t, err)
  109. assert.NotNil(t, res)
  110. msgs, err := env.messages.List(t.Context(), session.ID)
  111. require.NoError(t, err)
  112. foundRead := false
  113. foundWrite := false
  114. var readTCID, writeTCID string
  115. for _, msg := range msgs {
  116. if msg.Role == message.Assistant {
  117. for _, tc := range msg.ToolCalls() {
  118. if tc.Name == tools.ViewToolName {
  119. readTCID = tc.ID
  120. }
  121. if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
  122. writeTCID = tc.ID
  123. }
  124. }
  125. }
  126. if msg.Role == message.Tool {
  127. for _, tr := range msg.ToolResults() {
  128. if tr.ToolCallID == readTCID {
  129. foundRead = true
  130. }
  131. if tr.ToolCallID == writeTCID {
  132. foundWrite = true
  133. }
  134. }
  135. }
  136. }
  137. require.True(t, foundRead, "Expected to find a read operation")
  138. require.True(t, foundWrite, "Expected to find a write operation")
  139. mainGoPath := filepath.Join(env.workingDir, "main.go")
  140. content, err := os.ReadFile(mainGoPath)
  141. require.NoError(t, err)
  142. require.Contains(t, strings.ToLower(string(content)), "hello from crush")
  143. })
  144. t.Run("bash tool", func(t *testing.T) {
  145. agent, env := setupAgent(t, pair)
  146. session, err := env.sessions.Create(t.Context(), "New Session")
  147. require.NoError(t, err)
  148. res, err := agent.Run(t.Context(), SessionAgentCall{
  149. Prompt: "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
  150. SessionID: session.ID,
  151. MaxOutputTokens: 10000,
  152. })
  153. require.NoError(t, err)
  154. assert.NotNil(t, res)
  155. msgs, err := env.messages.List(t.Context(), session.ID)
  156. require.NoError(t, err)
  157. foundBash := false
  158. var bashTCID string
  159. for _, msg := range msgs {
  160. if msg.Role == message.Assistant {
  161. for _, tc := range msg.ToolCalls() {
  162. if tc.Name == tools.BashToolName {
  163. bashTCID = tc.ID
  164. }
  165. }
  166. }
  167. if msg.Role == message.Tool {
  168. for _, tr := range msg.ToolResults() {
  169. if tr.ToolCallID == bashTCID {
  170. foundBash = true
  171. }
  172. }
  173. }
  174. }
  175. require.True(t, foundBash, "Expected to find a bash operation")
  176. testFilePath := filepath.Join(env.workingDir, "test.txt")
  177. content, err := os.ReadFile(testFilePath)
  178. require.NoError(t, err)
  179. require.Contains(t, string(content), "hello bash")
  180. })
  181. t.Run("download tool", func(t *testing.T) {
  182. agent, env := setupAgent(t, pair)
  183. session, err := env.sessions.Create(t.Context(), "New Session")
  184. require.NoError(t, err)
  185. res, err := agent.Run(t.Context(), SessionAgentCall{
  186. Prompt: "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
  187. SessionID: session.ID,
  188. MaxOutputTokens: 10000,
  189. })
  190. require.NoError(t, err)
  191. assert.NotNil(t, res)
  192. msgs, err := env.messages.List(t.Context(), session.ID)
  193. require.NoError(t, err)
  194. foundDownload := false
  195. var downloadTCID string
  196. for _, msg := range msgs {
  197. if msg.Role == message.Assistant {
  198. for _, tc := range msg.ToolCalls() {
  199. if tc.Name == tools.DownloadToolName {
  200. downloadTCID = tc.ID
  201. }
  202. }
  203. }
  204. if msg.Role == message.Tool {
  205. for _, tr := range msg.ToolResults() {
  206. if tr.ToolCallID == downloadTCID {
  207. foundDownload = true
  208. }
  209. }
  210. }
  211. }
  212. require.True(t, foundDownload, "Expected to find a download operation")
  213. examplePath := filepath.Join(env.workingDir, "example.txt")
  214. _, err = os.Stat(examplePath)
  215. require.NoError(t, err, "Expected example.txt file to exist")
  216. })
  217. t.Run("fetch tool", func(t *testing.T) {
  218. agent, env := setupAgent(t, pair)
  219. session, err := env.sessions.Create(t.Context(), "New Session")
  220. require.NoError(t, err)
  221. res, err := agent.Run(t.Context(), SessionAgentCall{
  222. Prompt: "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
  223. SessionID: session.ID,
  224. MaxOutputTokens: 10000,
  225. })
  226. require.NoError(t, err)
  227. assert.NotNil(t, res)
  228. msgs, err := env.messages.List(t.Context(), session.ID)
  229. require.NoError(t, err)
  230. foundFetch := false
  231. var fetchTCID string
  232. for _, msg := range msgs {
  233. if msg.Role == message.Assistant {
  234. for _, tc := range msg.ToolCalls() {
  235. if tc.Name == tools.FetchToolName {
  236. fetchTCID = tc.ID
  237. }
  238. }
  239. }
  240. if msg.Role == message.Tool {
  241. for _, tr := range msg.ToolResults() {
  242. if tr.ToolCallID == fetchTCID {
  243. foundFetch = true
  244. }
  245. }
  246. }
  247. }
  248. require.True(t, foundFetch, "Expected to find a fetch operation")
  249. })
  250. t.Run("glob tool", func(t *testing.T) {
  251. agent, env := setupAgent(t, pair)
  252. session, err := env.sessions.Create(t.Context(), "New Session")
  253. require.NoError(t, err)
  254. res, err := agent.Run(t.Context(), SessionAgentCall{
  255. Prompt: "use glob to find all .go files in the current directory",
  256. SessionID: session.ID,
  257. MaxOutputTokens: 10000,
  258. })
  259. require.NoError(t, err)
  260. assert.NotNil(t, res)
  261. msgs, err := env.messages.List(t.Context(), session.ID)
  262. require.NoError(t, err)
  263. foundGlob := false
  264. var globTCID string
  265. for _, msg := range msgs {
  266. if msg.Role == message.Assistant {
  267. for _, tc := range msg.ToolCalls() {
  268. if tc.Name == tools.GlobToolName {
  269. globTCID = tc.ID
  270. }
  271. }
  272. }
  273. if msg.Role == message.Tool {
  274. for _, tr := range msg.ToolResults() {
  275. if tr.ToolCallID == globTCID {
  276. foundGlob = true
  277. require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
  278. }
  279. }
  280. }
  281. }
  282. require.True(t, foundGlob, "Expected to find a glob operation")
  283. })
  284. t.Run("grep tool", func(t *testing.T) {
  285. agent, env := setupAgent(t, pair)
  286. session, err := env.sessions.Create(t.Context(), "New Session")
  287. require.NoError(t, err)
  288. res, err := agent.Run(t.Context(), SessionAgentCall{
  289. Prompt: "use grep to search for the word 'package' in go files",
  290. SessionID: session.ID,
  291. MaxOutputTokens: 10000,
  292. })
  293. require.NoError(t, err)
  294. assert.NotNil(t, res)
  295. msgs, err := env.messages.List(t.Context(), session.ID)
  296. require.NoError(t, err)
  297. foundGrep := false
  298. var grepTCID string
  299. for _, msg := range msgs {
  300. if msg.Role == message.Assistant {
  301. for _, tc := range msg.ToolCalls() {
  302. if tc.Name == tools.GrepToolName {
  303. grepTCID = tc.ID
  304. }
  305. }
  306. }
  307. if msg.Role == message.Tool {
  308. for _, tr := range msg.ToolResults() {
  309. if tr.ToolCallID == grepTCID {
  310. foundGrep = true
  311. require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
  312. }
  313. }
  314. }
  315. }
  316. require.True(t, foundGrep, "Expected to find a grep operation")
  317. })
  318. t.Run("ls tool", func(t *testing.T) {
  319. agent, env := setupAgent(t, pair)
  320. session, err := env.sessions.Create(t.Context(), "New Session")
  321. require.NoError(t, err)
  322. res, err := agent.Run(t.Context(), SessionAgentCall{
  323. Prompt: "use ls to list the files in the current directory",
  324. SessionID: session.ID,
  325. MaxOutputTokens: 10000,
  326. })
  327. require.NoError(t, err)
  328. assert.NotNil(t, res)
  329. msgs, err := env.messages.List(t.Context(), session.ID)
  330. require.NoError(t, err)
  331. foundLS := false
  332. var lsTCID string
  333. for _, msg := range msgs {
  334. if msg.Role == message.Assistant {
  335. for _, tc := range msg.ToolCalls() {
  336. if tc.Name == tools.LSToolName {
  337. lsTCID = tc.ID
  338. }
  339. }
  340. }
  341. if msg.Role == message.Tool {
  342. for _, tr := range msg.ToolResults() {
  343. if tr.ToolCallID == lsTCID {
  344. foundLS = true
  345. require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
  346. require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
  347. }
  348. }
  349. }
  350. }
  351. require.True(t, foundLS, "Expected to find an ls operation")
  352. })
  353. t.Run("multiedit tool", func(t *testing.T) {
  354. agent, env := setupAgent(t, pair)
  355. session, err := env.sessions.Create(t.Context(), "New Session")
  356. require.NoError(t, err)
  357. res, err := agent.Run(t.Context(), SessionAgentCall{
  358. Prompt: "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
  359. SessionID: session.ID,
  360. MaxOutputTokens: 10000,
  361. })
  362. require.NoError(t, err)
  363. assert.NotNil(t, res)
  364. msgs, err := env.messages.List(t.Context(), session.ID)
  365. require.NoError(t, err)
  366. foundMultiEdit := false
  367. var multiEditTCID string
  368. for _, msg := range msgs {
  369. if msg.Role == message.Assistant {
  370. for _, tc := range msg.ToolCalls() {
  371. if tc.Name == tools.MultiEditToolName {
  372. multiEditTCID = tc.ID
  373. }
  374. }
  375. }
  376. if msg.Role == message.Tool {
  377. for _, tr := range msg.ToolResults() {
  378. if tr.ToolCallID == multiEditTCID {
  379. foundMultiEdit = true
  380. }
  381. }
  382. }
  383. }
  384. require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
  385. mainGoPath := filepath.Join(env.workingDir, "main.go")
  386. content, err := os.ReadFile(mainGoPath)
  387. require.NoError(t, err)
  388. require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
  389. })
  390. t.Run("sourcegraph tool", func(t *testing.T) {
  391. agent, env := setupAgent(t, pair)
  392. session, err := env.sessions.Create(t.Context(), "New Session")
  393. require.NoError(t, err)
  394. res, err := agent.Run(t.Context(), SessionAgentCall{
  395. Prompt: "use sourcegraph to search for 'func main' in Go repositories",
  396. SessionID: session.ID,
  397. MaxOutputTokens: 10000,
  398. })
  399. require.NoError(t, err)
  400. assert.NotNil(t, res)
  401. msgs, err := env.messages.List(t.Context(), session.ID)
  402. require.NoError(t, err)
  403. foundSourcegraph := false
  404. var sourcegraphTCID string
  405. for _, msg := range msgs {
  406. if msg.Role == message.Assistant {
  407. for _, tc := range msg.ToolCalls() {
  408. if tc.Name == tools.SourcegraphToolName {
  409. sourcegraphTCID = tc.ID
  410. }
  411. }
  412. }
  413. if msg.Role == message.Tool {
  414. for _, tr := range msg.ToolResults() {
  415. if tr.ToolCallID == sourcegraphTCID {
  416. foundSourcegraph = true
  417. }
  418. }
  419. }
  420. }
  421. require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
  422. })
  423. t.Run("write tool", func(t *testing.T) {
  424. agent, env := setupAgent(t, pair)
  425. session, err := env.sessions.Create(t.Context(), "New Session")
  426. require.NoError(t, err)
  427. res, err := agent.Run(t.Context(), SessionAgentCall{
  428. Prompt: "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
  429. SessionID: session.ID,
  430. MaxOutputTokens: 10000,
  431. })
  432. require.NoError(t, err)
  433. assert.NotNil(t, res)
  434. msgs, err := env.messages.List(t.Context(), session.ID)
  435. require.NoError(t, err)
  436. foundWrite := false
  437. var writeTCID string
  438. for _, msg := range msgs {
  439. if msg.Role == message.Assistant {
  440. for _, tc := range msg.ToolCalls() {
  441. if tc.Name == tools.WriteToolName {
  442. writeTCID = tc.ID
  443. }
  444. }
  445. }
  446. if msg.Role == message.Tool {
  447. for _, tr := range msg.ToolResults() {
  448. if tr.ToolCallID == writeTCID {
  449. foundWrite = true
  450. }
  451. }
  452. }
  453. }
  454. require.True(t, foundWrite, "Expected to find a write operation")
  455. configPath := filepath.Join(env.workingDir, "config.json")
  456. content, err := os.ReadFile(configPath)
  457. require.NoError(t, err)
  458. require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
  459. require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
  460. })
  461. t.Run("parallel tool calls", func(t *testing.T) {
  462. agent, env := setupAgent(t, pair)
  463. session, err := env.sessions.Create(t.Context(), "New Session")
  464. require.NoError(t, err)
  465. res, err := agent.Run(t.Context(), SessionAgentCall{
  466. Prompt: "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
  467. SessionID: session.ID,
  468. MaxOutputTokens: 10000,
  469. })
  470. require.NoError(t, err)
  471. assert.NotNil(t, res)
  472. msgs, err := env.messages.List(t.Context(), session.ID)
  473. require.NoError(t, err)
  474. var assistantMsg *message.Message
  475. var toolMsgs []message.Message
  476. for _, msg := range msgs {
  477. if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
  478. assistantMsg = &msg
  479. }
  480. if msg.Role == message.Tool {
  481. toolMsgs = append(toolMsgs, msg)
  482. }
  483. }
  484. require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
  485. require.NotNil(t, toolMsgs, "Expected to find a tool message")
  486. toolCalls := assistantMsg.ToolCalls()
  487. require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
  488. foundGlob := false
  489. foundLS := false
  490. var globTCID, lsTCID string
  491. for _, tc := range toolCalls {
  492. if tc.Name == tools.GlobToolName {
  493. foundGlob = true
  494. globTCID = tc.ID
  495. }
  496. if tc.Name == tools.LSToolName {
  497. foundLS = true
  498. lsTCID = tc.ID
  499. }
  500. }
  501. require.True(t, foundGlob, "Expected to find a glob tool call")
  502. require.True(t, foundLS, "Expected to find an ls tool call")
  503. require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
  504. foundGlobResult := false
  505. foundLSResult := false
  506. for _, msg := range toolMsgs {
  507. for _, tr := range msg.ToolResults() {
  508. if tr.ToolCallID == globTCID {
  509. foundGlobResult = true
  510. require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
  511. require.False(t, tr.IsError, "Expected glob result to not be an error")
  512. }
  513. if tr.ToolCallID == lsTCID {
  514. foundLSResult = true
  515. require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
  516. require.False(t, tr.IsError, "Expected ls result to not be an error")
  517. }
  518. }
  519. }
  520. require.True(t, foundGlobResult, "Expected to find glob tool result")
  521. require.True(t, foundLSResult, "Expected to find ls tool result")
  522. })
  523. })
  524. }
  525. }
  526. func makeTestTodos(n int) []session.Todo {
  527. todos := make([]session.Todo, n)
  528. for i := range n {
  529. todos[i] = session.Todo{
  530. Status: session.TodoStatusPending,
  531. Content: fmt.Sprintf("Task %d: Implement feature with some description that makes it realistic", i),
  532. }
  533. }
  534. return todos
  535. }
  536. func BenchmarkBuildSummaryPrompt(b *testing.B) {
  537. cases := []struct {
  538. name string
  539. numTodos int
  540. }{
  541. {"0todos", 0},
  542. {"5todos", 5},
  543. {"10todos", 10},
  544. {"50todos", 50},
  545. }
  546. for _, tc := range cases {
  547. todos := makeTestTodos(tc.numTodos)
  548. b.Run(tc.name, func(b *testing.B) {
  549. b.ReportAllocs()
  550. for range b.N {
  551. _ = buildSummaryPrompt(todos)
  552. }
  553. })
  554. }
  555. }