common_test.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. package agent
  2. import (
  3. "context"
  4. "net/http"
  5. "os"
  6. "path/filepath"
  7. "testing"
  8. "time"
  9. "charm.land/catwalk/pkg/catwalk"
  10. "charm.land/fantasy"
  11. "charm.land/fantasy/providers/anthropic"
  12. "charm.land/fantasy/providers/openai"
  13. "charm.land/fantasy/providers/openaicompat"
  14. "charm.land/fantasy/providers/openrouter"
  15. "charm.land/x/vcr"
  16. "github.com/charmbracelet/crush/internal/agent/prompt"
  17. "github.com/charmbracelet/crush/internal/agent/tools"
  18. "github.com/charmbracelet/crush/internal/config"
  19. "github.com/charmbracelet/crush/internal/csync"
  20. "github.com/charmbracelet/crush/internal/db"
  21. "github.com/charmbracelet/crush/internal/filetracker"
  22. "github.com/charmbracelet/crush/internal/history"
  23. "github.com/charmbracelet/crush/internal/lsp"
  24. "github.com/charmbracelet/crush/internal/message"
  25. "github.com/charmbracelet/crush/internal/permission"
  26. "github.com/charmbracelet/crush/internal/session"
  27. "github.com/stretchr/testify/require"
  28. _ "github.com/joho/godotenv/autoload"
  29. )
// fakeEnv is an environment for testing. It carries the working directory
// and the service instances that testEnv creates for a single test.
type fakeEnv struct {
	// workingDir is the directory handed to the permission service and to
	// the tools built in coderAgent.
	workingDir string
	// sessions and messages are passed to the session agent.
	sessions session.Service
	messages message.Service
	// permissions is passed to the tools built in coderAgent.
	permissions permission.Service
	// history is passed to the edit, multi-edit and write tools.
	history history.Service
	// filetracker is dereferenced and passed to the edit, multi-edit, view
	// and write tools.
	filetracker *filetracker.Service
	// lspClients starts out as an empty map in testEnv.
	lspClients *csync.Map[string, *lsp.Client]
}
// builderFunc constructs a language model for a test, given the test handle
// and the VCR recorder. The builders in this file use the recorder as the
// HTTP transport of the provider's client.
type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
  41. type modelPair struct {
  42. name string
  43. largeModel builderFunc
  44. smallModel builderFunc
  45. }
  46. func anthropicBuilder(model string) builderFunc {
  47. return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
  48. provider, err := anthropic.New(
  49. anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
  50. anthropic.WithHTTPClient(&http.Client{Transport: r}),
  51. )
  52. if err != nil {
  53. return nil, err
  54. }
  55. return provider.LanguageModel(t.Context(), model)
  56. }
  57. }
  58. func openaiBuilder(model string) builderFunc {
  59. return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
  60. provider, err := openai.New(
  61. openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
  62. openai.WithHTTPClient(&http.Client{Transport: r}),
  63. )
  64. if err != nil {
  65. return nil, err
  66. }
  67. return provider.LanguageModel(t.Context(), model)
  68. }
  69. }
  70. func openRouterBuilder(model string) builderFunc {
  71. return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
  72. provider, err := openrouter.New(
  73. openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
  74. openrouter.WithHTTPClient(&http.Client{Transport: r}),
  75. )
  76. if err != nil {
  77. return nil, err
  78. }
  79. return provider.LanguageModel(t.Context(), model)
  80. }
  81. }
  82. func zAIBuilder(model string) builderFunc {
  83. return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
  84. provider, err := openaicompat.New(
  85. openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
  86. openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
  87. openaicompat.WithHTTPClient(&http.Client{Transport: r}),
  88. )
  89. if err != nil {
  90. return nil, err
  91. }
  92. return provider.LanguageModel(t.Context(), model)
  93. }
  94. }
  95. func testEnv(t *testing.T) fakeEnv {
  96. workingDir := filepath.Join("/tmp/crush-test/", t.Name())
  97. os.RemoveAll(workingDir)
  98. err := os.MkdirAll(workingDir, 0o755)
  99. require.NoError(t, err)
  100. conn, err := db.Connect(t.Context(), t.TempDir())
  101. require.NoError(t, err)
  102. q := db.New(conn)
  103. sessions := session.NewService(q, conn)
  104. messages := message.NewService(q)
  105. permissions := permission.NewPermissionService(workingDir, true, []string{})
  106. history := history.NewService(q, conn)
  107. filetrackerService := filetracker.NewService(q)
  108. lspClients := csync.NewMap[string, *lsp.Client]()
  109. t.Cleanup(func() {
  110. conn.Close()
  111. os.RemoveAll(workingDir)
  112. })
  113. return fakeEnv{
  114. workingDir,
  115. sessions,
  116. messages,
  117. permissions,
  118. history,
  119. &filetrackerService,
  120. lspClients,
  121. }
  122. }
  123. func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
  124. largeModel := Model{
  125. Model: large,
  126. CatwalkCfg: catwalk.Model{
  127. ContextWindow: 200000,
  128. DefaultMaxTokens: 10000,
  129. },
  130. }
  131. smallModel := Model{
  132. Model: small,
  133. CatwalkCfg: catwalk.Model{
  134. ContextWindow: 200000,
  135. DefaultMaxTokens: 10000,
  136. },
  137. }
  138. agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, false, true, env.sessions, env.messages, tools})
  139. return agent
  140. }
  141. func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
  142. fixedTime := func() time.Time {
  143. t, _ := time.Parse("1/2/2006", "1/1/2025")
  144. return t
  145. }
  146. prompt, err := coderPrompt(
  147. prompt.WithTimeFunc(fixedTime),
  148. prompt.WithPlatform("linux"),
  149. prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
  150. )
  151. if err != nil {
  152. return nil, err
  153. }
  154. cfg, err := config.Init(env.workingDir, "", false)
  155. if err != nil {
  156. return nil, err
  157. }
  158. // NOTE(@andreynering): Set a fixed config to ensure cassettes match
  159. // independently of user config on `$HOME/.config/crush/crush.json`.
  160. cfg.Options.Attribution = &config.Attribution{
  161. TrailerStyle: "co-authored-by",
  162. GeneratedWith: true,
  163. }
  164. // Clear skills paths to ensure test reproducibility - user's skills
  165. // would be included in prompt and break VCR cassette matching.
  166. cfg.Options.SkillsPaths = []string{}
  167. // Clear LSP config to ensure test reproducibility - user's LSP config
  168. // would be included in prompt and break VCR cassette matching.
  169. cfg.LSP = nil
  170. systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
  171. if err != nil {
  172. return nil, err
  173. }
  174. // Get the model name for the bash tool
  175. modelName := large.Model() // fallback to ID if Name not available
  176. if model := cfg.GetModel(large.Provider(), large.Model()); model != nil {
  177. modelName = model.Name
  178. }
  179. allTools := []fantasy.AgentTool{
  180. tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
  181. tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
  182. tools.NewEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
  183. tools.NewMultiEditTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
  184. tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
  185. tools.NewGlobTool(env.workingDir),
  186. tools.NewGrepTool(env.workingDir, cfg.Tools.Grep),
  187. tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
  188. tools.NewSourcegraphTool(r.GetDefaultClient()),
  189. tools.NewViewTool(nil, env.permissions, *env.filetracker, env.workingDir),
  190. tools.NewWriteTool(nil, env.permissions, env.history, *env.filetracker, env.workingDir),
  191. }
  192. return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
  193. }
  194. // createSimpleGoProject creates a simple Go project structure in the given directory.
  195. // It creates a go.mod file and a main.go file with a basic hello world program.
  196. func createSimpleGoProject(t *testing.T, dir string) {
  197. goMod := `module example.com/testproject
  198. go 1.23
  199. `
  200. err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
  201. require.NoError(t, err)
  202. mainGo := `package main
  203. import "fmt"
  204. func main() {
  205. fmt.Println("Hello, World!")
  206. }
  207. `
  208. err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
  209. require.NoError(t, err)
  210. }