common_test.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. package agent
  2. import (
  3. "context"
  4. "net/http"
  5. "os"
  6. "path/filepath"
  7. "testing"
  8. "time"
  9. "charm.land/fantasy"
  10. "charm.land/fantasy/providers/anthropic"
  11. "charm.land/fantasy/providers/openai"
  12. "charm.land/fantasy/providers/openaicompat"
  13. "charm.land/fantasy/providers/openrouter"
  14. "charm.land/x/vcr"
  15. "github.com/charmbracelet/catwalk/pkg/catwalk"
  16. "github.com/charmbracelet/crush/internal/agent/prompt"
  17. "github.com/charmbracelet/crush/internal/agent/tools"
  18. "github.com/charmbracelet/crush/internal/config"
  19. "github.com/charmbracelet/crush/internal/csync"
  20. "github.com/charmbracelet/crush/internal/db"
  21. "github.com/charmbracelet/crush/internal/history"
  22. "github.com/charmbracelet/crush/internal/lsp"
  23. "github.com/charmbracelet/crush/internal/message"
  24. "github.com/charmbracelet/crush/internal/permission"
  25. "github.com/charmbracelet/crush/internal/session"
  26. "github.com/stretchr/testify/require"
  27. _ "github.com/joho/godotenv/autoload"
  28. )
// fakeEnv is an environment for testing.
//
// It bundles the working directory and the services that the agent and its
// tools are constructed with in these tests. Keep the field order stable:
// the struct may be built with positional composite literals.
type fakeEnv struct {
	// workingDir is the directory handed to the permission service and to
	// the file/shell tools as their working directory.
	workingDir string
	// sessions is passed to NewSessionAgent.
	sessions session.Service
	// messages is passed to NewSessionAgent.
	messages message.Service
	// permissions is shared by the tools that take a permission.Service.
	permissions permission.Service
	// history is passed to the edit, multi-edit and write tools.
	history history.Service
	// lspClients is passed to the edit, multi-edit, view and write tools.
	// NOTE(review): keys are presumably LSP server/language names — confirm.
	lspClients *csync.Map[string, *lsp.Client]
}
  38. type builderFunc func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error)
  39. type modelPair struct {
  40. name string
  41. largeModel builderFunc
  42. smallModel builderFunc
  43. }
  44. func anthropicBuilder(model string) builderFunc {
  45. return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
  46. provider, err := anthropic.New(
  47. anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
  48. anthropic.WithHTTPClient(&http.Client{Transport: r}),
  49. )
  50. if err != nil {
  51. return nil, err
  52. }
  53. return provider.LanguageModel(t.Context(), model)
  54. }
  55. }
  56. func openaiBuilder(model string) builderFunc {
  57. return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
  58. provider, err := openai.New(
  59. openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
  60. openai.WithHTTPClient(&http.Client{Transport: r}),
  61. )
  62. if err != nil {
  63. return nil, err
  64. }
  65. return provider.LanguageModel(t.Context(), model)
  66. }
  67. }
  68. func openRouterBuilder(model string) builderFunc {
  69. return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
  70. provider, err := openrouter.New(
  71. openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
  72. openrouter.WithHTTPClient(&http.Client{Transport: r}),
  73. )
  74. if err != nil {
  75. return nil, err
  76. }
  77. return provider.LanguageModel(t.Context(), model)
  78. }
  79. }
  80. func zAIBuilder(model string) builderFunc {
  81. return func(t *testing.T, r *vcr.Recorder) (fantasy.LanguageModel, error) {
  82. provider, err := openaicompat.New(
  83. openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
  84. openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
  85. openaicompat.WithHTTPClient(&http.Client{Transport: r}),
  86. )
  87. if err != nil {
  88. return nil, err
  89. }
  90. return provider.LanguageModel(t.Context(), model)
  91. }
  92. }
  93. func testEnv(t *testing.T) fakeEnv {
  94. workingDir := filepath.Join("/tmp/crush-test/", t.Name())
  95. os.RemoveAll(workingDir)
  96. err := os.MkdirAll(workingDir, 0o755)
  97. require.NoError(t, err)
  98. conn, err := db.Connect(t.Context(), t.TempDir())
  99. require.NoError(t, err)
  100. q := db.New(conn)
  101. sessions := session.NewService(q)
  102. messages := message.NewService(q)
  103. permissions := permission.NewPermissionService(workingDir, true, []string{})
  104. history := history.NewService(q, conn)
  105. lspClients := csync.NewMap[string, *lsp.Client]()
  106. t.Cleanup(func() {
  107. conn.Close()
  108. os.RemoveAll(workingDir)
  109. })
  110. return fakeEnv{
  111. workingDir,
  112. sessions,
  113. messages,
  114. permissions,
  115. history,
  116. lspClients,
  117. }
  118. }
  119. func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
  120. largeModel := Model{
  121. Model: large,
  122. CatwalkCfg: catwalk.Model{
  123. ContextWindow: 200000,
  124. DefaultMaxTokens: 10000,
  125. },
  126. }
  127. smallModel := Model{
  128. Model: small,
  129. CatwalkCfg: catwalk.Model{
  130. ContextWindow: 200000,
  131. DefaultMaxTokens: 10000,
  132. },
  133. }
  134. agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, tools})
  135. return agent
  136. }
  137. func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
  138. fixedTime := func() time.Time {
  139. t, _ := time.Parse("1/2/2006", "1/1/2025")
  140. return t
  141. }
  142. prompt, err := coderPrompt(
  143. prompt.WithTimeFunc(fixedTime),
  144. prompt.WithPlatform("linux"),
  145. prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
  146. )
  147. if err != nil {
  148. return nil, err
  149. }
  150. cfg, err := config.Init(env.workingDir, "", false)
  151. if err != nil {
  152. return nil, err
  153. }
  154. // NOTE(@andreynering): Set a fixed config to ensure cassettes match
  155. // independently of user config on `$HOME/.config/crush/crush.json`.
  156. cfg.Options.Attribution = &config.Attribution{
  157. TrailerStyle: "co-authored-by",
  158. GeneratedWith: true,
  159. }
  160. systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
  161. if err != nil {
  162. return nil, err
  163. }
  164. // Get the model name for the bash tool
  165. modelName := large.Model() // fallback to ID if Name not available
  166. if model := cfg.GetModel(large.Provider(), large.Model()); model != nil {
  167. modelName = model.Name
  168. }
  169. allTools := []fantasy.AgentTool{
  170. tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
  171. tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
  172. tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
  173. tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
  174. tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
  175. tools.NewGlobTool(env.workingDir),
  176. tools.NewGrepTool(env.workingDir),
  177. tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
  178. tools.NewSourcegraphTool(r.GetDefaultClient()),
  179. tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
  180. tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
  181. }
  182. return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
  183. }
  184. // createSimpleGoProject creates a simple Go project structure in the given directory.
  185. // It creates a go.mod file and a main.go file with a basic hello world program.
  186. func createSimpleGoProject(t *testing.T, dir string) {
  187. goMod := `module example.com/testproject
  188. go 1.23
  189. `
  190. err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
  191. require.NoError(t, err)
  192. mainGo := `package main
  193. import "fmt"
  194. func main() {
  195. fmt.Println("Hello, World!")
  196. }
  197. `
  198. err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
  199. require.NoError(t, err)
  200. }