agent_test.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
  1. package agent
  2. import (
  3. "os"
  4. "path/filepath"
  5. "runtime"
  6. "strings"
  7. "testing"
  8. "charm.land/fantasy"
  9. "charm.land/x/vcr"
  10. "github.com/charmbracelet/crush/internal/agent/tools"
  11. "github.com/charmbracelet/crush/internal/message"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. _ "github.com/joho/godotenv/autoload"
  15. )
  16. var modelPairs = []modelPair{
  17. {"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
  18. {"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
  19. {"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
  20. {"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
  21. }
  22. func getModels(t *testing.T, r *vcr.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
  23. large, err := pair.largeModel(t, r)
  24. require.NoError(t, err)
  25. small, err := pair.smallModel(t, r)
  26. require.NoError(t, err)
  27. return large, small
  28. }
  29. func setupAgent(t *testing.T, pair modelPair) (SessionAgent, fakeEnv) {
  30. r := vcr.NewRecorder(t)
  31. large, small := getModels(t, r, pair)
  32. env := testEnv(t)
  33. createSimpleGoProject(t, env.workingDir)
  34. agent, err := coderAgent(r, env, large, small)
  35. require.NoError(t, err)
  36. return agent, env
  37. }
  38. func TestCoderAgent(t *testing.T) {
  39. if runtime.GOOS == "windows" {
  40. t.Skip("skipping on windows for now")
  41. }
  42. for _, pair := range modelPairs {
  43. t.Run(pair.name, func(t *testing.T) {
  44. t.Run("simple test", func(t *testing.T) {
  45. agent, env := setupAgent(t, pair)
  46. session, err := env.sessions.Create(t.Context(), "New Session")
  47. require.NoError(t, err)
  48. res, err := agent.Run(t.Context(), SessionAgentCall{
  49. Prompt: "Hello",
  50. SessionID: session.ID,
  51. MaxOutputTokens: 10000,
  52. })
  53. require.NoError(t, err)
  54. assert.NotNil(t, res)
  55. msgs, err := env.messages.List(t.Context(), session.ID)
  56. require.NoError(t, err)
  57. // Should have the agent and user message
  58. assert.Equal(t, len(msgs), 2)
  59. })
  60. t.Run("read a file", func(t *testing.T) {
  61. agent, env := setupAgent(t, pair)
  62. session, err := env.sessions.Create(t.Context(), "New Session")
  63. require.NoError(t, err)
  64. res, err := agent.Run(t.Context(), SessionAgentCall{
  65. Prompt: "Read the go mod",
  66. SessionID: session.ID,
  67. MaxOutputTokens: 10000,
  68. })
  69. require.NoError(t, err)
  70. assert.NotNil(t, res)
  71. msgs, err := env.messages.List(t.Context(), session.ID)
  72. require.NoError(t, err)
  73. foundFile := false
  74. var tcID string
  75. out:
  76. for _, msg := range msgs {
  77. if msg.Role == message.Assistant {
  78. for _, tc := range msg.ToolCalls() {
  79. if tc.Name == tools.ViewToolName {
  80. tcID = tc.ID
  81. }
  82. }
  83. }
  84. if msg.Role == message.Tool {
  85. for _, tr := range msg.ToolResults() {
  86. if tr.ToolCallID == tcID {
  87. if strings.Contains(tr.Content, "module example.com/testproject") {
  88. foundFile = true
  89. break out
  90. }
  91. }
  92. }
  93. }
  94. }
  95. require.True(t, foundFile)
  96. })
  97. t.Run("update a file", func(t *testing.T) {
  98. agent, env := setupAgent(t, pair)
  99. session, err := env.sessions.Create(t.Context(), "New Session")
  100. require.NoError(t, err)
  101. res, err := agent.Run(t.Context(), SessionAgentCall{
  102. Prompt: "update the main.go file by changing the print to say hello from crush",
  103. SessionID: session.ID,
  104. MaxOutputTokens: 10000,
  105. })
  106. require.NoError(t, err)
  107. assert.NotNil(t, res)
  108. msgs, err := env.messages.List(t.Context(), session.ID)
  109. require.NoError(t, err)
  110. foundRead := false
  111. foundWrite := false
  112. var readTCID, writeTCID string
  113. for _, msg := range msgs {
  114. if msg.Role == message.Assistant {
  115. for _, tc := range msg.ToolCalls() {
  116. if tc.Name == tools.ViewToolName {
  117. readTCID = tc.ID
  118. }
  119. if tc.Name == tools.EditToolName || tc.Name == tools.WriteToolName {
  120. writeTCID = tc.ID
  121. }
  122. }
  123. }
  124. if msg.Role == message.Tool {
  125. for _, tr := range msg.ToolResults() {
  126. if tr.ToolCallID == readTCID {
  127. foundRead = true
  128. }
  129. if tr.ToolCallID == writeTCID {
  130. foundWrite = true
  131. }
  132. }
  133. }
  134. }
  135. require.True(t, foundRead, "Expected to find a read operation")
  136. require.True(t, foundWrite, "Expected to find a write operation")
  137. mainGoPath := filepath.Join(env.workingDir, "main.go")
  138. content, err := os.ReadFile(mainGoPath)
  139. require.NoError(t, err)
  140. require.Contains(t, strings.ToLower(string(content)), "hello from crush")
  141. })
  142. t.Run("bash tool", func(t *testing.T) {
  143. agent, env := setupAgent(t, pair)
  144. session, err := env.sessions.Create(t.Context(), "New Session")
  145. require.NoError(t, err)
  146. res, err := agent.Run(t.Context(), SessionAgentCall{
  147. Prompt: "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp",
  148. SessionID: session.ID,
  149. MaxOutputTokens: 10000,
  150. })
  151. require.NoError(t, err)
  152. assert.NotNil(t, res)
  153. msgs, err := env.messages.List(t.Context(), session.ID)
  154. require.NoError(t, err)
  155. foundBash := false
  156. var bashTCID string
  157. for _, msg := range msgs {
  158. if msg.Role == message.Assistant {
  159. for _, tc := range msg.ToolCalls() {
  160. if tc.Name == tools.BashToolName {
  161. bashTCID = tc.ID
  162. }
  163. }
  164. }
  165. if msg.Role == message.Tool {
  166. for _, tr := range msg.ToolResults() {
  167. if tr.ToolCallID == bashTCID {
  168. foundBash = true
  169. }
  170. }
  171. }
  172. }
  173. require.True(t, foundBash, "Expected to find a bash operation")
  174. testFilePath := filepath.Join(env.workingDir, "test.txt")
  175. content, err := os.ReadFile(testFilePath)
  176. require.NoError(t, err)
  177. require.Contains(t, string(content), "hello bash")
  178. })
  179. t.Run("download tool", func(t *testing.T) {
  180. agent, env := setupAgent(t, pair)
  181. session, err := env.sessions.Create(t.Context(), "New Session")
  182. require.NoError(t, err)
  183. res, err := agent.Run(t.Context(), SessionAgentCall{
  184. Prompt: "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt",
  185. SessionID: session.ID,
  186. MaxOutputTokens: 10000,
  187. })
  188. require.NoError(t, err)
  189. assert.NotNil(t, res)
  190. msgs, err := env.messages.List(t.Context(), session.ID)
  191. require.NoError(t, err)
  192. foundDownload := false
  193. var downloadTCID string
  194. for _, msg := range msgs {
  195. if msg.Role == message.Assistant {
  196. for _, tc := range msg.ToolCalls() {
  197. if tc.Name == tools.DownloadToolName {
  198. downloadTCID = tc.ID
  199. }
  200. }
  201. }
  202. if msg.Role == message.Tool {
  203. for _, tr := range msg.ToolResults() {
  204. if tr.ToolCallID == downloadTCID {
  205. foundDownload = true
  206. }
  207. }
  208. }
  209. }
  210. require.True(t, foundDownload, "Expected to find a download operation")
  211. examplePath := filepath.Join(env.workingDir, "example.txt")
  212. _, err = os.Stat(examplePath)
  213. require.NoError(t, err, "Expected example.txt file to exist")
  214. })
  215. t.Run("fetch tool", func(t *testing.T) {
  216. agent, env := setupAgent(t, pair)
  217. session, err := env.sessions.Create(t.Context(), "New Session")
  218. require.NoError(t, err)
  219. res, err := agent.Run(t.Context(), SessionAgentCall{
  220. Prompt: "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'",
  221. SessionID: session.ID,
  222. MaxOutputTokens: 10000,
  223. })
  224. require.NoError(t, err)
  225. assert.NotNil(t, res)
  226. msgs, err := env.messages.List(t.Context(), session.ID)
  227. require.NoError(t, err)
  228. foundFetch := false
  229. var fetchTCID string
  230. for _, msg := range msgs {
  231. if msg.Role == message.Assistant {
  232. for _, tc := range msg.ToolCalls() {
  233. if tc.Name == tools.FetchToolName {
  234. fetchTCID = tc.ID
  235. }
  236. }
  237. }
  238. if msg.Role == message.Tool {
  239. for _, tr := range msg.ToolResults() {
  240. if tr.ToolCallID == fetchTCID {
  241. foundFetch = true
  242. }
  243. }
  244. }
  245. }
  246. require.True(t, foundFetch, "Expected to find a fetch operation")
  247. })
  248. t.Run("glob tool", func(t *testing.T) {
  249. agent, env := setupAgent(t, pair)
  250. session, err := env.sessions.Create(t.Context(), "New Session")
  251. require.NoError(t, err)
  252. res, err := agent.Run(t.Context(), SessionAgentCall{
  253. Prompt: "use glob to find all .go files in the current directory",
  254. SessionID: session.ID,
  255. MaxOutputTokens: 10000,
  256. })
  257. require.NoError(t, err)
  258. assert.NotNil(t, res)
  259. msgs, err := env.messages.List(t.Context(), session.ID)
  260. require.NoError(t, err)
  261. foundGlob := false
  262. var globTCID string
  263. for _, msg := range msgs {
  264. if msg.Role == message.Assistant {
  265. for _, tc := range msg.ToolCalls() {
  266. if tc.Name == tools.GlobToolName {
  267. globTCID = tc.ID
  268. }
  269. }
  270. }
  271. if msg.Role == message.Tool {
  272. for _, tr := range msg.ToolResults() {
  273. if tr.ToolCallID == globTCID {
  274. foundGlob = true
  275. require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
  276. }
  277. }
  278. }
  279. }
  280. require.True(t, foundGlob, "Expected to find a glob operation")
  281. })
  282. t.Run("grep tool", func(t *testing.T) {
  283. agent, env := setupAgent(t, pair)
  284. session, err := env.sessions.Create(t.Context(), "New Session")
  285. require.NoError(t, err)
  286. res, err := agent.Run(t.Context(), SessionAgentCall{
  287. Prompt: "use grep to search for the word 'package' in go files",
  288. SessionID: session.ID,
  289. MaxOutputTokens: 10000,
  290. })
  291. require.NoError(t, err)
  292. assert.NotNil(t, res)
  293. msgs, err := env.messages.List(t.Context(), session.ID)
  294. require.NoError(t, err)
  295. foundGrep := false
  296. var grepTCID string
  297. for _, msg := range msgs {
  298. if msg.Role == message.Assistant {
  299. for _, tc := range msg.ToolCalls() {
  300. if tc.Name == tools.GrepToolName {
  301. grepTCID = tc.ID
  302. }
  303. }
  304. }
  305. if msg.Role == message.Tool {
  306. for _, tr := range msg.ToolResults() {
  307. if tr.ToolCallID == grepTCID {
  308. foundGrep = true
  309. require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
  310. }
  311. }
  312. }
  313. }
  314. require.True(t, foundGrep, "Expected to find a grep operation")
  315. })
  316. t.Run("ls tool", func(t *testing.T) {
  317. agent, env := setupAgent(t, pair)
  318. session, err := env.sessions.Create(t.Context(), "New Session")
  319. require.NoError(t, err)
  320. res, err := agent.Run(t.Context(), SessionAgentCall{
  321. Prompt: "use ls to list the files in the current directory",
  322. SessionID: session.ID,
  323. MaxOutputTokens: 10000,
  324. })
  325. require.NoError(t, err)
  326. assert.NotNil(t, res)
  327. msgs, err := env.messages.List(t.Context(), session.ID)
  328. require.NoError(t, err)
  329. foundLS := false
  330. var lsTCID string
  331. for _, msg := range msgs {
  332. if msg.Role == message.Assistant {
  333. for _, tc := range msg.ToolCalls() {
  334. if tc.Name == tools.LSToolName {
  335. lsTCID = tc.ID
  336. }
  337. }
  338. }
  339. if msg.Role == message.Tool {
  340. for _, tr := range msg.ToolResults() {
  341. if tr.ToolCallID == lsTCID {
  342. foundLS = true
  343. require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
  344. require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
  345. }
  346. }
  347. }
  348. }
  349. require.True(t, foundLS, "Expected to find an ls operation")
  350. })
  351. t.Run("multiedit tool", func(t *testing.T) {
  352. agent, env := setupAgent(t, pair)
  353. session, err := env.sessions.Create(t.Context(), "New Session")
  354. require.NoError(t, err)
  355. res, err := agent.Run(t.Context(), SessionAgentCall{
  356. Prompt: "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go",
  357. SessionID: session.ID,
  358. MaxOutputTokens: 10000,
  359. })
  360. require.NoError(t, err)
  361. assert.NotNil(t, res)
  362. msgs, err := env.messages.List(t.Context(), session.ID)
  363. require.NoError(t, err)
  364. foundMultiEdit := false
  365. var multiEditTCID string
  366. for _, msg := range msgs {
  367. if msg.Role == message.Assistant {
  368. for _, tc := range msg.ToolCalls() {
  369. if tc.Name == tools.MultiEditToolName {
  370. multiEditTCID = tc.ID
  371. }
  372. }
  373. }
  374. if msg.Role == message.Tool {
  375. for _, tr := range msg.ToolResults() {
  376. if tr.ToolCallID == multiEditTCID {
  377. foundMultiEdit = true
  378. }
  379. }
  380. }
  381. }
  382. require.True(t, foundMultiEdit, "Expected to find a multiedit operation")
  383. mainGoPath := filepath.Join(env.workingDir, "main.go")
  384. content, err := os.ReadFile(mainGoPath)
  385. require.NoError(t, err)
  386. require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
  387. })
  388. t.Run("sourcegraph tool", func(t *testing.T) {
  389. if runtime.GOOS == "darwin" {
  390. t.Skip("skipping flacky test on macos for now")
  391. }
  392. agent, env := setupAgent(t, pair)
  393. session, err := env.sessions.Create(t.Context(), "New Session")
  394. require.NoError(t, err)
  395. res, err := agent.Run(t.Context(), SessionAgentCall{
  396. Prompt: "use sourcegraph to search for 'func main' in Go repositories",
  397. SessionID: session.ID,
  398. MaxOutputTokens: 10000,
  399. })
  400. require.NoError(t, err)
  401. assert.NotNil(t, res)
  402. msgs, err := env.messages.List(t.Context(), session.ID)
  403. require.NoError(t, err)
  404. foundSourcegraph := false
  405. var sourcegraphTCID string
  406. for _, msg := range msgs {
  407. if msg.Role == message.Assistant {
  408. for _, tc := range msg.ToolCalls() {
  409. if tc.Name == tools.SourcegraphToolName {
  410. sourcegraphTCID = tc.ID
  411. }
  412. }
  413. }
  414. if msg.Role == message.Tool {
  415. for _, tr := range msg.ToolResults() {
  416. if tr.ToolCallID == sourcegraphTCID {
  417. foundSourcegraph = true
  418. }
  419. }
  420. }
  421. }
  422. require.True(t, foundSourcegraph, "Expected to find a sourcegraph operation")
  423. })
  424. t.Run("write tool", func(t *testing.T) {
  425. agent, env := setupAgent(t, pair)
  426. session, err := env.sessions.Create(t.Context(), "New Session")
  427. require.NoError(t, err)
  428. res, err := agent.Run(t.Context(), SessionAgentCall{
  429. Prompt: "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'",
  430. SessionID: session.ID,
  431. MaxOutputTokens: 10000,
  432. })
  433. require.NoError(t, err)
  434. assert.NotNil(t, res)
  435. msgs, err := env.messages.List(t.Context(), session.ID)
  436. require.NoError(t, err)
  437. foundWrite := false
  438. var writeTCID string
  439. for _, msg := range msgs {
  440. if msg.Role == message.Assistant {
  441. for _, tc := range msg.ToolCalls() {
  442. if tc.Name == tools.WriteToolName {
  443. writeTCID = tc.ID
  444. }
  445. }
  446. }
  447. if msg.Role == message.Tool {
  448. for _, tr := range msg.ToolResults() {
  449. if tr.ToolCallID == writeTCID {
  450. foundWrite = true
  451. }
  452. }
  453. }
  454. }
  455. require.True(t, foundWrite, "Expected to find a write operation")
  456. configPath := filepath.Join(env.workingDir, "config.json")
  457. content, err := os.ReadFile(configPath)
  458. require.NoError(t, err)
  459. require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
  460. require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
  461. })
  462. t.Run("parallel tool calls", func(t *testing.T) {
  463. agent, env := setupAgent(t, pair)
  464. session, err := env.sessions.Create(t.Context(), "New Session")
  465. require.NoError(t, err)
  466. res, err := agent.Run(t.Context(), SessionAgentCall{
  467. Prompt: "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel",
  468. SessionID: session.ID,
  469. MaxOutputTokens: 10000,
  470. })
  471. require.NoError(t, err)
  472. assert.NotNil(t, res)
  473. msgs, err := env.messages.List(t.Context(), session.ID)
  474. require.NoError(t, err)
  475. var assistantMsg *message.Message
  476. var toolMsgs []message.Message
  477. for _, msg := range msgs {
  478. if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
  479. assistantMsg = &msg
  480. }
  481. if msg.Role == message.Tool {
  482. toolMsgs = append(toolMsgs, msg)
  483. }
  484. }
  485. require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
  486. require.NotNil(t, toolMsgs, "Expected to find a tool message")
  487. toolCalls := assistantMsg.ToolCalls()
  488. require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")
  489. foundGlob := false
  490. foundLS := false
  491. var globTCID, lsTCID string
  492. for _, tc := range toolCalls {
  493. if tc.Name == tools.GlobToolName {
  494. foundGlob = true
  495. globTCID = tc.ID
  496. }
  497. if tc.Name == tools.LSToolName {
  498. foundLS = true
  499. lsTCID = tc.ID
  500. }
  501. }
  502. require.True(t, foundGlob, "Expected to find a glob tool call")
  503. require.True(t, foundLS, "Expected to find an ls tool call")
  504. require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool results in the same message")
  505. foundGlobResult := false
  506. foundLSResult := false
  507. for _, msg := range toolMsgs {
  508. for _, tr := range msg.ToolResults() {
  509. if tr.ToolCallID == globTCID {
  510. foundGlobResult = true
  511. require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
  512. require.False(t, tr.IsError, "Expected glob result to not be an error")
  513. }
  514. if tr.ToolCallID == lsTCID {
  515. foundLSResult = true
  516. require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
  517. require.False(t, tr.IsError, "Expected ls result to not be an error")
  518. }
  519. }
  520. }
  521. require.True(t, foundGlobResult, "Expected to find glob tool result")
  522. require.True(t, foundLSResult, "Expected to find ls tool result")
  523. })
  524. })
  525. }
  526. }