| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241 |
- package agent
- import (
- "context"
- "net/http"
- "os"
- "path/filepath"
- "testing"
- "time"
- "charm.land/catwalk/pkg/catwalk"
- "charm.land/fantasy"
- "charm.land/fantasy/providers/anthropic"
- "charm.land/fantasy/providers/openai"
- "charm.land/fantasy/providers/openaicompat"
- "charm.land/fantasy/providers/openrouter"
- "charm.land/x/vcr"
- "github.com/charmbracelet/crush/internal/agent/prompt"
- "github.com/charmbracelet/crush/internal/agent/tools"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/csync"
- "github.com/charmbracelet/crush/internal/db"
- "github.com/charmbracelet/crush/internal/filetracker"
- "github.com/charmbracelet/crush/internal/history"
- "github.com/charmbracelet/crush/internal/lsp"
- "github.com/charmbracelet/crush/internal/message"
- "github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/session"
- "github.com/stretchr/testify/require"
- _ "github.com/joho/godotenv/autoload"
- )
- // fakeEnv is an environment for testing.
- type fakeEnv struct {
- workingDir string
- sessions session.Service
- messages message.Service
- permissions permission.Service
- history history.Service
- filetracker *filetracker.Service
- lspClients *csync.Map[string, *lsp.Client]
- }
- type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
- type modelPair struct {
- name string
- largeModel builderFunc
- smallModel builderFunc
- }
- func anthropicBuilder(model string) builderFunc {
- return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
- provider, err := anthropic.New(
- anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
- anthropic.WithHTTPClient(&http.Client{Transport: r}),
- )
- if err != nil {
- return nil, err
- }
- return provider.LanguageModel(t.Context(), model)
- }
- }
- func openaiBuilder(model string) builderFunc {
- return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
- provider, err := openai.New(
- openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
- openai.WithHTTPClient(&http.Client{Transport: r}),
- )
- if err != nil {
- return nil, err
- }
- return provider.LanguageModel(t.Context(), model)
- }
- }
- func openRouterBuilder(model string) builderFunc {
- return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
- provider, err := openrouter.New(
- openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
- openrouter.WithHTTPClient(&http.Client{Transport: r}),
- )
- if err != nil {
- return nil, err
- }
- return provider.LanguageModel(t.Context(), model)
- }
- }
- func zAIBuilder(model string) builderFunc {
- return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
- provider, err := openaicompat.New(
- openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
- openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
- openaicompat.WithHTTPClient(&http.Client{Transport: r}),
- )
- if err != nil {
- return nil, err
- }
- return provider.LanguageModel(t.Context(), model)
- }
- }
- func testEnv(t *testing.T) fakeEnv {
- workingDir := filepath.Join("/tmp/crush-test/", t.Name())
- os.RemoveAll(workingDir)
- err := os.MkdirAll(workingDir, 0o755)
- require.NoError(t, err)
- conn, err := db.Connect(t.Context(), t.TempDir())
- require.NoError(t, err)
- q := db.New(conn)
- sessions := session.NewService(q, conn)
- messages := message.NewService(q)
- permissions := permission.NewPermissionService(workingDir, true, []string{})
- history := history.NewService(q, conn)
- filetrackerService := filetracker.NewService(q)
- lspClients := csync.NewMap[string, *lsp.Client]()
- t.Cleanup(func() {
- conn.Close()
- os.RemoveAll(workingDir)
- })
- return fakeEnv{
- workingDir,
- sessions,
- messages,
- permissions,
- history,
- &filetrackerService,
- lspClients,
- }
- }
- func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
- largeModel := Model{
- Model: large,
- CatwalkCfg: catwalk.Model{
- ContextWindow: 200000,
- DefaultMaxTokens: 10000,
- },
- }
- smallModel := Model{
- Model: small,
- CatwalkCfg: catwalk.Model{
- ContextWindow: 200000,
- DefaultMaxTokens: 10000,
- },
- }
- agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, false, true, env.sessions, env.messages, tools})
- return agent
- }
- func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
- fixedTime := func() time.Time {
- t, _ := time.Parse("1/2/2006", "1/1/2025")
- return t
- }
- prompt, err := coderPrompt(
- prompt.WithTimeFunc(fixedTime),
- prompt.WithPlatform("linux"),
- prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
- )
- if err != nil {
- return nil, err
- }
- cfg, err := config.Init(env.workingDir, "", false)
- if err != nil {
- return nil, err
- }
- // NOTE(@andreynering): Set a fixed config to ensure cassettes match
- // independently of user config on `$HOME/.config/crush/crush.json`.
- cfg.Options.Attribution = &config.Attribution{
- TrailerStyle: "co-authored-by",
- GeneratedWith: true,
- }
- // Clear skills paths to ensure test reproducibility - user's skills
- // would be included in prompt and break VCR cassette matching.
- cfg.Options.SkillsPaths = []string{}
- // Clear LSP config to ensure test reproducibility - user's LSP config
- // would be included in prompt and break VCR cassette matching.
- cfg.LSP = nil
- systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
- if err != nil {
- return nil, err
- }
- // Get the model name for the bash tool
- modelName := large.Model() // fallback to ID if Name not available
- if model := cfg.GetModel(large.Provider(), large.Model()); model != nil {
- modelName = model.Name
- }
- allTools := []fantasy.AgentTool{
- tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
- tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
- tools.NewEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
- tools.NewMultiEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
- tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
- tools.NewGlobTool(env.workingDir),
- tools.NewGrepTool(env.workingDir, cfg.Tools.Grep),
- tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
- tools.NewSourcegraphTool(r.GetDefaultClient()),
- tools.NewViewTool(nil, env.permissions, *env.filetracker, env.workingDir),
- tools.NewWriteTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
- }
- return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
- }
- // createSimpleGoProject creates a simple Go project structure in the given directory.
- // It creates a go.mod file and a main.go file with a basic hello world program.
- func createSimpleGoProject(t *testing.T, dir string) {
- goMod := `module example.com/testproject
- go 1.23
- `
- err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
- require.NoError(t, err)
- mainGo := `package main
- import "fmt"
- func main() {
- fmt.Println("Hello, World!")
- }
- `
- err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
- require.NoError(t, err)
- }
|