common_test.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. package agent
  2. import (
  3. "context"
  4. "net/http"
  5. "os"
  6. "path/filepath"
  7. "testing"
  8. "time"
  9. "charm.land/fantasy"
  10. "charm.land/fantasy/providers/anthropic"
  11. "charm.land/fantasy/providers/openai"
  12. "charm.land/fantasy/providers/openaicompat"
  13. "charm.land/fantasy/providers/openrouter"
  14. "github.com/charmbracelet/catwalk/pkg/catwalk"
  15. "github.com/charmbracelet/crush/internal/agent/prompt"
  16. "github.com/charmbracelet/crush/internal/agent/tools"
  17. "github.com/charmbracelet/crush/internal/config"
  18. "github.com/charmbracelet/crush/internal/csync"
  19. "github.com/charmbracelet/crush/internal/db"
  20. "github.com/charmbracelet/crush/internal/history"
  21. "github.com/charmbracelet/crush/internal/lsp"
  22. "github.com/charmbracelet/crush/internal/message"
  23. "github.com/charmbracelet/crush/internal/permission"
  24. "github.com/charmbracelet/crush/internal/session"
  25. "github.com/stretchr/testify/require"
  26. "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
  27. _ "github.com/joho/godotenv/autoload"
  28. )
  29. // fakeEnv is an environment for testing.
  30. type fakeEnv struct {
  31. workingDir string
  32. sessions session.Service
  33. messages message.Service
  34. permissions permission.Service
  35. history history.Service
  36. lspClients *csync.Map[string, *lsp.Client]
  37. }
  38. type builderFunc func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error)
  39. type modelPair struct {
  40. name string
  41. largeModel builderFunc
  42. smallModel builderFunc
  43. }
  44. func anthropicBuilder(model string) builderFunc {
  45. return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
  46. provider, err := anthropic.New(
  47. anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
  48. anthropic.WithHTTPClient(&http.Client{Transport: r}),
  49. )
  50. if err != nil {
  51. return nil, err
  52. }
  53. return provider.LanguageModel(t.Context(), model)
  54. }
  55. }
  56. func openaiBuilder(model string) builderFunc {
  57. return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
  58. provider, err := openai.New(
  59. openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
  60. openai.WithHTTPClient(&http.Client{Transport: r}),
  61. )
  62. if err != nil {
  63. return nil, err
  64. }
  65. return provider.LanguageModel(t.Context(), model)
  66. }
  67. }
  68. func openRouterBuilder(model string) builderFunc {
  69. return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
  70. provider, err := openrouter.New(
  71. openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
  72. openrouter.WithHTTPClient(&http.Client{Transport: r}),
  73. )
  74. if err != nil {
  75. return nil, err
  76. }
  77. return provider.LanguageModel(t.Context(), model)
  78. }
  79. }
  80. func zAIBuilder(model string) builderFunc {
  81. return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
  82. provider, err := openaicompat.New(
  83. openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
  84. openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
  85. openaicompat.WithHTTPClient(&http.Client{Transport: r}),
  86. )
  87. if err != nil {
  88. return nil, err
  89. }
  90. return provider.LanguageModel(t.Context(), model)
  91. }
  92. }
  93. func testEnv(t *testing.T) fakeEnv {
  94. workingDir := filepath.Join("/tmp/crush-test/", t.Name())
  95. os.RemoveAll(workingDir)
  96. err := os.MkdirAll(workingDir, 0o755)
  97. require.NoError(t, err)
  98. conn, err := db.Connect(t.Context(), t.TempDir())
  99. require.NoError(t, err)
  100. q := db.New(conn)
  101. sessions := session.NewService(q)
  102. messages := message.NewService(q)
  103. permissions := permission.NewPermissionService(workingDir, true, []string{})
  104. history := history.NewService(q, conn)
  105. lspClients := csync.NewMap[string, *lsp.Client]()
  106. t.Cleanup(func() {
  107. conn.Close()
  108. os.RemoveAll(workingDir)
  109. })
  110. return fakeEnv{
  111. workingDir,
  112. sessions,
  113. messages,
  114. permissions,
  115. history,
  116. lspClients,
  117. }
  118. }
  119. func testSessionAgent(env fakeEnv, large, small fantasy.LanguageModel, systemPrompt string, tools ...fantasy.AgentTool) SessionAgent {
  120. largeModel := Model{
  121. Model: large,
  122. CatwalkCfg: catwalk.Model{
  123. ContextWindow: 200000,
  124. DefaultMaxTokens: 10000,
  125. },
  126. }
  127. smallModel := Model{
  128. Model: small,
  129. CatwalkCfg: catwalk.Model{
  130. ContextWindow: 200000,
  131. DefaultMaxTokens: 10000,
  132. },
  133. }
  134. agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, "", systemPrompt, false, true, env.sessions, env.messages, tools})
  135. return agent
  136. }
  137. func coderAgent(r *recorder.Recorder, env fakeEnv, large, small fantasy.LanguageModel) (SessionAgent, error) {
  138. fixedTime := func() time.Time {
  139. t, _ := time.Parse("1/2/2006", "1/1/2025")
  140. return t
  141. }
  142. prompt, err := coderPrompt(
  143. prompt.WithTimeFunc(fixedTime),
  144. prompt.WithPlatform("linux"),
  145. prompt.WithWorkingDir(filepath.ToSlash(env.workingDir)),
  146. )
  147. if err != nil {
  148. return nil, err
  149. }
  150. cfg, err := config.Init(env.workingDir, "", false)
  151. if err != nil {
  152. return nil, err
  153. }
  154. systemPrompt, err := prompt.Build(context.TODO(), large.Provider(), large.Model(), *cfg)
  155. if err != nil {
  156. return nil, err
  157. }
  158. allTools := []fantasy.AgentTool{
  159. tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution),
  160. tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
  161. tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
  162. tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
  163. tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
  164. tools.NewGlobTool(env.workingDir),
  165. tools.NewGrepTool(env.workingDir),
  166. tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
  167. tools.NewSourcegraphTool(r.GetDefaultClient()),
  168. tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
  169. tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
  170. }
  171. return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
  172. }
  173. // createSimpleGoProject creates a simple Go project structure in the given directory.
  174. // It creates a go.mod file and a main.go file with a basic hello world program.
  175. func createSimpleGoProject(t *testing.T, dir string) {
  176. goMod := `module example.com/testproject
  177. go 1.23
  178. `
  179. err := os.WriteFile(dir+"/go.mod", []byte(goMod), 0o644)
  180. require.NoError(t, err)
  181. mainGo := `package main
  182. import "fmt"
  183. func main() {
  184. fmt.Println("Hello, World!")
  185. }
  186. `
  187. err = os.WriteFile(dir+"/main.go", []byte(mainGo), 0o644)
  188. require.NoError(t, err)
  189. }