agent_tool.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package agent
import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"sync"

	"charm.land/fantasy"
	"github.com/charmbracelet/crush/internal/agent/prompt"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
)
// agentToolDescription holds the contents of templates/agent_tool.md,
// embedded into the binary at build time. agentTool converts it to a string
// and uses it as the description of the tool it builds.
//
//go:embed templates/agent_tool.md
var agentToolDescription []byte
  14. type AgentParams struct {
  15. Prompt string `json:"prompt" description:"The task for the agent to perform"`
  16. }
  17. const (
  18. AgentToolName = "agent"
  19. )
  20. func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) {
  21. agentCfg, ok := c.cfg.Agents[config.AgentTask]
  22. if !ok {
  23. return nil, errors.New("task agent not configured")
  24. }
  25. prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
  26. if err != nil {
  27. return nil, err
  28. }
  29. agent, err := c.buildAgent(ctx, prompt, agentCfg, true)
  30. if err != nil {
  31. return nil, err
  32. }
  33. return fantasy.NewParallelAgentTool(
  34. AgentToolName,
  35. string(agentToolDescription),
  36. func(ctx context.Context, params AgentParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  37. if params.Prompt == "" {
  38. return fantasy.NewTextErrorResponse("prompt is required"), nil
  39. }
  40. sessionID := tools.GetSessionFromContext(ctx)
  41. if sessionID == "" {
  42. return fantasy.ToolResponse{}, errors.New("session id missing from context")
  43. }
  44. agentMessageID := tools.GetMessageFromContext(ctx)
  45. if agentMessageID == "" {
  46. return fantasy.ToolResponse{}, errors.New("agent message id missing from context")
  47. }
  48. agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
  49. session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session")
  50. if err != nil {
  51. return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
  52. }
  53. model := agent.Model()
  54. maxTokens := model.CatwalkCfg.DefaultMaxTokens
  55. if model.ModelCfg.MaxTokens != 0 {
  56. maxTokens = model.ModelCfg.MaxTokens
  57. }
  58. providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
  59. if !ok {
  60. return fantasy.ToolResponse{}, errors.New("model provider not configured")
  61. }
  62. result, err := agent.Run(ctx, SessionAgentCall{
  63. SessionID: session.ID,
  64. Prompt: params.Prompt,
  65. MaxOutputTokens: maxTokens,
  66. ProviderOptions: getProviderOptions(model, providerCfg),
  67. Temperature: model.ModelCfg.Temperature,
  68. TopP: model.ModelCfg.TopP,
  69. TopK: model.ModelCfg.TopK,
  70. FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
  71. PresencePenalty: model.ModelCfg.PresencePenalty,
  72. })
  73. if err != nil {
  74. return fantasy.NewTextErrorResponse("error generating response"), nil
  75. }
  76. updatedSession, err := c.sessions.Get(ctx, session.ID)
  77. if err != nil {
  78. return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
  79. }
  80. parentSession, err := c.sessions.Get(ctx, sessionID)
  81. if err != nil {
  82. return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
  83. }
  84. parentSession.Cost += updatedSession.Cost
  85. _, err = c.sessions.Save(ctx, parentSession)
  86. if err != nil {
  87. return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
  88. }
  89. return fantasy.NewTextResponse(result.Response.Content.Text()), nil
  90. }), nil
  91. }