| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621 |
- package agent
- import (
- "os"
- "path/filepath"
- "runtime"
- "strings"
- "testing"
- "charm.land/fantasy"
- "charm.land/x/vcr"
- "github.com/charmbracelet/crush/internal/agent/tools"
- "github.com/charmbracelet/crush/internal/message"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- _ "github.com/joho/godotenv/autoload"
- )
- var modelPairs = []modelPair{
- {"anthropic-sonnet", anthropicBuilder("claude-sonnet-4-5-20250929"), anthropicBuilder("claude-3-5-haiku-20241022")},
- {"openai-gpt-5", openaiBuilder("gpt-5"), openaiBuilder("gpt-4o")},
- {"openrouter-kimi-k2", openRouterBuilder("moonshotai/kimi-k2-0905"), openRouterBuilder("qwen/qwen3-next-80b-a3b-instruct")},
- {"zai-glm4.6", zAIBuilder("glm-4.6"), zAIBuilder("glm-4.5-air")},
- }
// getModels builds the large and small language models described by pair,
// passing the VCR recorder r to each builder. The test fails immediately if
// either builder returns an error.
//
// It returns the large model first and the small model second.
func getModels(t *testing.T, r *vcr.Recorder, pair modelPair) (fantasy.LanguageModel, fantasy.LanguageModel) {
	// Report require failures at the caller's line, not inside this helper.
	t.Helper()

	large, err := pair.largeModel(t, r)
	require.NoError(t, err, "building large model for %s", pair.name)

	small, err := pair.smallModel(t, r)
	require.NoError(t, err, "building small model for %s", pair.name)

	return large, small
}
// setupAgent prepares an isolated coder agent for a single test.
//
// The steps are:
//   - create a VCR recorder bound to t;
//   - build pair's models through that recorder;
//   - create a fresh test environment;
//   - populate env.workingDir with a simple Go project;
//   - construct the coder agent from all of the above.
//
// It returns the agent together with the environment it operates on. Any
// construction error fails the test immediately.
func setupAgent(t *testing.T, pair modelPair) (SessionAgent, fakeEnv) {
	// Report require failures at the caller's line, not inside this helper.
	t.Helper()

	r := vcr.NewRecorder(t)
	large, small := getModels(t, r, pair)

	env := testEnv(t)
	createSimpleGoProject(t, env.workingDir)

	agent, err := coderAgent(r, env, large, small)
	require.NoError(t, err, "creating coder agent for %s", pair.name)
	return agent, env
}
// TestCoderAgent runs the coder agent end to end against every entry in
// modelPairs.
//
// Each subtest follows the same pattern:
//   - build a fresh agent and session;
//   - send a single prompt;
//   - inspect the persisted session messages to confirm the expected tool was
//     called and answered;
//   - where relevant, check the file it was supposed to produce or modify in
//     the working directory.
//
// NOTE(review): the VCR recorder is created from each subtest's *testing.T.
// Recordings are presumably keyed on the subtest name and matched on the
// request contents. Keep subtest names and prompt strings stable unless the
// recordings are regenerated. TODO confirm against the vcr package.
func TestCoderAgent(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("skipping on windows for now")
	}

	for _, pair := range modelPairs {
		t.Run(pair.name, func(t *testing.T) {
			t.Run("simple test", func(t *testing.T) {
				_, msgs := runCoderAgentPrompt(t, pair, "Hello")
				// Should have the agent and user message
				assert.Len(t, msgs, 2)
			})

			t.Run("read a file", func(t *testing.T) {
				_, msgs := runCoderAgentPrompt(t, pair, "Read the go mod")

				foundFile := false
				for _, tr := range latestToolCallResults(msgs, tools.ViewToolName) {
					if strings.Contains(tr.Content, "module example.com/testproject") {
						foundFile = true
						break
					}
				}
				require.True(t, foundFile, "Expected a view result containing the go.mod module line")
			})

			t.Run("update a file", func(t *testing.T) {
				env, msgs := runCoderAgentPrompt(t, pair, "update the main.go file by changing the print to say hello from crush")

				require.NotEmpty(t, latestToolCallResults(msgs, tools.ViewToolName), "Expected to find a read operation")
				require.NotEmpty(t, latestToolCallResults(msgs, tools.EditToolName, tools.WriteToolName), "Expected to find a write operation")

				content, err := os.ReadFile(filepath.Join(env.workingDir, "main.go"))
				require.NoError(t, err)
				require.Contains(t, strings.ToLower(string(content)), "hello from crush")
			})

			t.Run("bash tool", func(t *testing.T) {
				env, msgs := runCoderAgentPrompt(t, pair, "use bash to create a file named test.txt with content 'hello bash'. do not print its timestamp")

				require.NotEmpty(t, latestToolCallResults(msgs, tools.BashToolName), "Expected to find a bash operation")

				content, err := os.ReadFile(filepath.Join(env.workingDir, "test.txt"))
				require.NoError(t, err)
				require.Contains(t, string(content), "hello bash")
			})

			t.Run("download tool", func(t *testing.T) {
				env, msgs := runCoderAgentPrompt(t, pair, "download the file from https://example-files.online-convert.com/document/txt/example.txt and save it as example.txt")

				require.NotEmpty(t, latestToolCallResults(msgs, tools.DownloadToolName), "Expected to find a download operation")

				_, err := os.Stat(filepath.Join(env.workingDir, "example.txt"))
				require.NoError(t, err, "Expected example.txt file to exist")
			})

			t.Run("fetch tool", func(t *testing.T) {
				_, msgs := runCoderAgentPrompt(t, pair, "fetch the content from https://example-files.online-convert.com/website/html/example.html and tell me if it contains the word 'John Doe'")

				require.NotEmpty(t, latestToolCallResults(msgs, tools.FetchToolName), "Expected to find a fetch operation")
			})

			t.Run("glob tool", func(t *testing.T) {
				_, msgs := runCoderAgentPrompt(t, pair, "use glob to find all .go files in the current directory")

				results := latestToolCallResults(msgs, tools.GlobToolName)
				require.NotEmpty(t, results, "Expected to find a glob operation")
				for _, tr := range results {
					require.Contains(t, tr.Content, "main.go", "Expected glob to find main.go")
				}
			})

			t.Run("grep tool", func(t *testing.T) {
				_, msgs := runCoderAgentPrompt(t, pair, "use grep to search for the word 'package' in go files")

				results := latestToolCallResults(msgs, tools.GrepToolName)
				require.NotEmpty(t, results, "Expected to find a grep operation")
				for _, tr := range results {
					require.Contains(t, tr.Content, "main.go", "Expected grep to find main.go")
				}
			})

			t.Run("ls tool", func(t *testing.T) {
				_, msgs := runCoderAgentPrompt(t, pair, "use ls to list the files in the current directory")

				results := latestToolCallResults(msgs, tools.LSToolName)
				require.NotEmpty(t, results, "Expected to find an ls operation")
				for _, tr := range results {
					require.Contains(t, tr.Content, "main.go", "Expected ls to list main.go")
					require.Contains(t, tr.Content, "go.mod", "Expected ls to list go.mod")
				}
			})

			t.Run("multiedit tool", func(t *testing.T) {
				env, msgs := runCoderAgentPrompt(t, pair, "use multiedit to change 'Hello, World!' to 'Hello, Crush!' and add a comment '// Greeting' above the fmt.Println line in main.go")

				require.NotEmpty(t, latestToolCallResults(msgs, tools.MultiEditToolName), "Expected to find a multiedit operation")

				content, err := os.ReadFile(filepath.Join(env.workingDir, "main.go"))
				require.NoError(t, err)
				require.Contains(t, string(content), "Hello, Crush!", "Expected file to contain 'Hello, Crush!'")
			})

			t.Run("sourcegraph tool", func(t *testing.T) {
				if runtime.GOOS == "darwin" {
					t.Skip("skipping flaky test on macos for now")
				}

				_, msgs := runCoderAgentPrompt(t, pair, "use sourcegraph to search for 'func main' in Go repositories")

				require.NotEmpty(t, latestToolCallResults(msgs, tools.SourcegraphToolName), "Expected to find a sourcegraph operation")
			})

			t.Run("write tool", func(t *testing.T) {
				env, msgs := runCoderAgentPrompt(t, pair, "use write to create a new file called config.json with content '{\"name\": \"test\", \"version\": \"1.0.0\"}'")

				require.NotEmpty(t, latestToolCallResults(msgs, tools.WriteToolName), "Expected to find a write operation")

				content, err := os.ReadFile(filepath.Join(env.workingDir, "config.json"))
				require.NoError(t, err)
				require.Contains(t, string(content), "test", "Expected config.json to contain 'test'")
				require.Contains(t, string(content), "1.0.0", "Expected config.json to contain '1.0.0'")
			})

			t.Run("parallel tool calls", func(t *testing.T) {
				_, msgs := runCoderAgentPrompt(t, pair, "use glob to find all .go files and use ls to list the current directory, it is very important that you run both tool calls in parallel")

				// Remember the last assistant message that issued tool calls and
				// collect every tool message in the session. Indexing into msgs
				// (rather than taking &msg) keeps the pointer valid regardless of
				// the Go version's loop-variable semantics.
				var assistantMsg *message.Message
				var toolMsgs []message.Message
				for i := range msgs {
					if msgs[i].Role == message.Assistant && len(msgs[i].ToolCalls()) > 0 {
						assistantMsg = &msgs[i]
					}
					if msgs[i].Role == message.Tool {
						toolMsgs = append(toolMsgs, msgs[i])
					}
				}
				require.NotNil(t, assistantMsg, "Expected to find an assistant message with tool calls")
				require.NotEmpty(t, toolMsgs, "Expected to find a tool message")

				toolCalls := assistantMsg.ToolCalls()
				require.GreaterOrEqual(t, len(toolCalls), 2, "Expected at least 2 tool calls in parallel")

				foundGlob, foundLS := false, false
				var globTCID, lsTCID string
				for _, tc := range toolCalls {
					if tc.Name == tools.GlobToolName {
						foundGlob = true
						globTCID = tc.ID
					}
					if tc.Name == tools.LSToolName {
						foundLS = true
						lsTCID = tc.ID
					}
				}
				require.True(t, foundGlob, "Expected to find a glob tool call")
				require.True(t, foundLS, "Expected to find an ls tool call")
				require.GreaterOrEqual(t, len(toolMsgs), 2, "Expected at least 2 tool result messages")

				foundGlobResult, foundLSResult := false, false
				for _, msg := range toolMsgs {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == globTCID {
							foundGlobResult = true
							require.Contains(t, tr.Content, "main.go", "Expected glob result to contain main.go")
							require.False(t, tr.IsError, "Expected glob result to not be an error")
						}
						if tr.ToolCallID == lsTCID {
							foundLSResult = true
							require.Contains(t, tr.Content, "main.go", "Expected ls result to contain main.go")
							require.False(t, tr.IsError, "Expected ls result to not be an error")
						}
					}
				}
				require.True(t, foundGlobResult, "Expected to find glob tool result")
				require.True(t, foundLSResult, "Expected to find ls tool result")
			})
		})
	}
}

// runCoderAgentPrompt sends prompt to a freshly built agent and returns the
// resulting session messages.
//
// In order, it:
//   - builds the agent for pair via setupAgent;
//   - creates a session titled "New Session";
//   - sends prompt with a 10000-token output cap;
//   - lists every message stored in that session.
//
// It returns the test environment and those messages. Setup or run errors
// fail the test immediately.
func runCoderAgentPrompt(t *testing.T, pair modelPair, prompt string) (fakeEnv, []message.Message) {
	t.Helper()

	agent, env := setupAgent(t, pair)
	session, err := env.sessions.Create(t.Context(), "New Session")
	require.NoError(t, err)

	res, err := agent.Run(t.Context(), SessionAgentCall{
		Prompt:          prompt,
		SessionID:       session.ID,
		MaxOutputTokens: 10000,
	})
	require.NoError(t, err)
	assert.NotNil(t, res)

	msgs, err := env.messages.List(t.Context(), session.ID)
	require.NoError(t, err)
	return env, msgs
}

// latestToolCallResults walks msgs in order, tracking the ID of the most
// recent assistant tool call whose name is one of toolNames. It returns every
// tool result answering the tracked ID at the point the result is seen. Until
// a matching call has been seen, no result matches, so an empty ToolCallID
// never counts as a hit.
//
// NOTE(review): this assumes msgs is in conversation order, with each tool
// result appearing after the assistant message that requested it. That is
// presumably how env.messages.List orders them — confirm.
func latestToolCallResults(msgs []message.Message, toolNames ...string) []message.ToolResult {
	var (
		callID  string
		results []message.ToolResult
	)
	for _, msg := range msgs {
		switch msg.Role {
		case message.Assistant:
			for _, tc := range msg.ToolCalls() {
				for _, name := range toolNames {
					if tc.Name == name {
						callID = tc.ID
					}
				}
			}
		case message.Tool:
			for _, tr := range msg.ToolResults() {
				if callID != "" && tr.ToolCallID == callID {
					results = append(results, tr)
				}
			}
		}
	}
	return results
}
|