| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- package agent
- import (
- "context"
- _ "embed"
- "errors"
- "fmt"
- "charm.land/fantasy"
- "github.com/charmbracelet/crush/internal/agent/prompt"
- "github.com/charmbracelet/crush/internal/agent/tools"
- "github.com/charmbracelet/crush/internal/config"
- )
- //go:embed templates/agent_tool.md
- var agentToolDescription []byte
// AgentParams is the parameter payload accepted by the agent tool's handler.
type AgentParams struct {
	// Prompt is the task handed to the agent. The handler rejects an empty
	// value with a "prompt is required" error response.
	Prompt string `json:"prompt" description:"The task for the agent to perform"`
}
// AgentToolName is the name under which the agent tool is registered.
const AgentToolName = "agent"
- func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) {
- agentCfg, ok := c.cfg.Agents[config.AgentTask]
- if !ok {
- return nil, errors.New("task agent not configured")
- }
- prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
- if err != nil {
- return nil, err
- }
- agent, err := c.buildAgent(ctx, prompt, agentCfg)
- if err != nil {
- return nil, err
- }
- return fantasy.NewAgentTool(
- AgentToolName,
- string(agentToolDescription),
- func(ctx context.Context, params AgentParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
- if params.Prompt == "" {
- return fantasy.NewTextErrorResponse("prompt is required"), nil
- }
- sessionID := tools.GetSessionFromContext(ctx)
- if sessionID == "" {
- return fantasy.ToolResponse{}, errors.New("session id missing from context")
- }
- agentMessageID := tools.GetMessageFromContext(ctx)
- if agentMessageID == "" {
- return fantasy.ToolResponse{}, errors.New("agent message id missing from context")
- }
- agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
- session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session")
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
- }
- model := agent.Model()
- maxTokens := model.CatwalkCfg.DefaultMaxTokens
- if model.ModelCfg.MaxTokens != 0 {
- maxTokens = model.ModelCfg.MaxTokens
- }
- providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
- if !ok {
- return fantasy.ToolResponse{}, errors.New("model provider not configured")
- }
- result, err := agent.Run(ctx, SessionAgentCall{
- SessionID: session.ID,
- Prompt: params.Prompt,
- MaxOutputTokens: maxTokens,
- ProviderOptions: getProviderOptions(model, providerCfg),
- Temperature: model.ModelCfg.Temperature,
- TopP: model.ModelCfg.TopP,
- TopK: model.ModelCfg.TopK,
- FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
- PresencePenalty: model.ModelCfg.PresencePenalty,
- })
- if err != nil {
- return fantasy.NewTextErrorResponse("error generating response"), nil
- }
- updatedSession, err := c.sessions.Get(ctx, session.ID)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
- }
- parentSession, err := c.sessions.Get(ctx, sessionID)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
- }
- parentSession.Cost += updatedSession.Cost
- _, err = c.sessions.Save(ctx, parentSession)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
- }
- return fantasy.NewTextResponse(result.Response.Content.Text()), nil
- }), nil
- }
|