coder.go 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. package agent
  2. import (
  3. "errors"
  4. "github.com/kujtimiihoxha/termai/internal/app"
  5. "github.com/kujtimiihoxha/termai/internal/config"
  6. "github.com/kujtimiihoxha/termai/internal/llm/models"
  7. "github.com/kujtimiihoxha/termai/internal/llm/tools"
  8. )
  9. type coderAgent struct {
  10. *agent
  11. }
  12. func (c *coderAgent) setAgentTool(sessionID string) {
  13. inx := -1
  14. for i, tool := range c.tools {
  15. if tool.Info().Name == AgentToolName {
  16. inx = i
  17. break
  18. }
  19. }
  20. if inx == -1 {
  21. c.tools = append(c.tools, NewAgentTool(sessionID, c.App))
  22. } else {
  23. c.tools[inx] = NewAgentTool(sessionID, c.App)
  24. }
  25. }
  26. func (c *coderAgent) Generate(sessionID string, content string) error {
  27. c.setAgentTool(sessionID)
  28. return c.generate(sessionID, content)
  29. }
  30. func NewCoderAgent(app *app.App) (Agent, error) {
  31. model, ok := models.SupportedModels[config.Get().Model.Coder]
  32. if !ok {
  33. return nil, errors.New("model not supported")
  34. }
  35. agentProvider, titleGenerator, err := getAgentProviders(app.Context, model)
  36. if err != nil {
  37. return nil, err
  38. }
  39. otherTools := GetMcpTools(app.Context)
  40. if len(app.LSPClients) > 0 {
  41. otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients))
  42. }
  43. return &coderAgent{
  44. agent: &agent{
  45. App: app,
  46. tools: append(
  47. []tools.BaseTool{
  48. tools.NewBashTool(),
  49. tools.NewEditTool(app.LSPClients),
  50. tools.NewFetchTool(),
  51. tools.NewGlobTool(),
  52. tools.NewGrepTool(),
  53. tools.NewLsTool(),
  54. tools.NewSourcegraphTool(),
  55. tools.NewViewTool(app.LSPClients),
  56. tools.NewWriteTool(app.LSPClients),
  57. }, otherTools...,
  58. ),
  59. model: model,
  60. agent: agentProvider,
  61. titleGenerator: titleGenerator,
  62. },
  63. }, nil
  64. }