agentic_fetch_tool.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. package agent
  2. import (
  3. "context"
  4. _ "embed"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "os"
  9. "time"
  10. "charm.land/fantasy"
  11. "github.com/charmbracelet/crush/internal/agent/prompt"
  12. "github.com/charmbracelet/crush/internal/agent/tools"
  13. "github.com/charmbracelet/crush/internal/permission"
  14. )
  // agenticFetchToolDescription is the tool description passed to
  // fantasy.NewAgentTool when agenticFetchTool builds the agentic fetch tool.
  15. //go:embed templates/agentic_fetch.md
  16. var agenticFetchToolDescription []byte
  // agenticFetchValidationResult carries the identifiers taken from the
  // tool-call context once validateAgenticFetchParams has accepted a call:
  // SessionID is the session the call belongs to, and AgentMessageID is the
  // message ID read from the same context.
  type agenticFetchValidationResult struct {
  	SessionID, AgentMessageID string
  }
  22. // validateAgenticFetchParams validates the tool call parameters and extracts required context values.
  23. func validateAgenticFetchParams(ctx context.Context, params tools.AgenticFetchParams) (agenticFetchValidationResult, error) {
  24. if params.URL == "" {
  25. return agenticFetchValidationResult{}, errors.New("url is required")
  26. }
  27. if params.Prompt == "" {
  28. return agenticFetchValidationResult{}, errors.New("prompt is required")
  29. }
  30. sessionID := tools.GetSessionFromContext(ctx)
  31. if sessionID == "" {
  32. return agenticFetchValidationResult{}, errors.New("session id missing from context")
  33. }
  34. agentMessageID := tools.GetMessageFromContext(ctx)
  35. if agentMessageID == "" {
  36. return agenticFetchValidationResult{}, errors.New("agent message id missing from context")
  37. }
  38. return agenticFetchValidationResult{
  39. SessionID: sessionID,
  40. AgentMessageID: agentMessageID,
  41. }, nil
  42. }
  // agenticFetchPromptTmpl is the system-prompt template source; agenticFetchTool
  // parses it with prompt.NewPrompt under the name "agentic_fetch".
  43. //go:embed templates/agentic_fetch_prompt.md.tpl
  44. var agenticFetchPromptTmpl []byte
  45. func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (fantasy.AgentTool, error) {
  46. if client == nil {
  47. client = &http.Client{
  48. Timeout: 30 * time.Second,
  49. Transport: &http.Transport{
  50. MaxIdleConns: 100,
  51. MaxIdleConnsPerHost: 10,
  52. IdleConnTimeout: 90 * time.Second,
  53. },
  54. }
  55. }
  56. return fantasy.NewAgentTool(
  57. tools.AgenticFetchToolName,
  58. string(agenticFetchToolDescription),
  59. func(ctx context.Context, params tools.AgenticFetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  60. validationResult, err := validateAgenticFetchParams(ctx, params)
  61. if err != nil {
  62. return fantasy.NewTextErrorResponse(err.Error()), nil
  63. }
  64. p := c.permissions.Request(
  65. permission.CreatePermissionRequest{
  66. SessionID: validationResult.SessionID,
  67. Path: c.cfg.WorkingDir(),
  68. ToolCallID: call.ID,
  69. ToolName: tools.AgenticFetchToolName,
  70. Action: "fetch",
  71. Description: fmt.Sprintf("Fetch and analyze content from URL: %s", params.URL),
  72. Params: tools.AgenticFetchPermissionsParams(params),
  73. },
  74. )
  75. if !p {
  76. return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
  77. }
  78. content, err := tools.FetchURLAndConvert(ctx, client, params.URL)
  79. if err != nil {
  80. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to fetch URL: %s", err)), nil
  81. }
  82. tmpDir, err := os.MkdirTemp(c.cfg.Options.DataDirectory, "crush-fetch-*")
  83. if err != nil {
  84. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary directory: %s", err)), nil
  85. }
  86. defer os.RemoveAll(tmpDir)
  87. hasLargeContent := len(content) > tools.LargeContentThreshold
  88. var fullPrompt string
  89. if hasLargeContent {
  90. tempFile, err := os.CreateTemp(tmpDir, "page-*.md")
  91. if err != nil {
  92. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary file: %s", err)), nil
  93. }
  94. tempFilePath := tempFile.Name()
  95. if _, err := tempFile.WriteString(content); err != nil {
  96. tempFile.Close()
  97. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to write content to file: %s", err)), nil
  98. }
  99. tempFile.Close()
  100. fullPrompt = fmt.Sprintf("%s\n\nThe web page from %s has been saved to: %s\n\nUse the view and grep tools to analyze this file and extract the requested information.", params.Prompt, params.URL, tempFilePath)
  101. } else {
  102. fullPrompt = fmt.Sprintf("%s\n\nWeb page URL: %s\n\n<webpage_content>\n%s\n</webpage_content>", params.Prompt, params.URL, content)
  103. }
  104. promptOpts := []prompt.Option{
  105. prompt.WithWorkingDir(tmpDir),
  106. }
  107. promptTemplate, err := prompt.NewPrompt("agentic_fetch", string(agenticFetchPromptTmpl), promptOpts...)
  108. if err != nil {
  109. return fantasy.ToolResponse{}, fmt.Errorf("error creating prompt: %s", err)
  110. }
  111. _, small, err := c.buildAgentModels(ctx)
  112. if err != nil {
  113. return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err)
  114. }
  115. systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), *c.cfg)
  116. if err != nil {
  117. return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err)
  118. }
  119. smallProviderCfg, ok := c.cfg.Providers.Get(small.ModelCfg.Provider)
  120. if !ok {
  121. return fantasy.ToolResponse{}, errors.New("small model provider not configured")
  122. }
  123. webFetchTool := tools.NewWebFetchTool(tmpDir, client)
  124. fetchTools := []fantasy.AgentTool{
  125. webFetchTool,
  126. tools.NewGlobTool(tmpDir),
  127. tools.NewGrepTool(tmpDir),
  128. tools.NewViewTool(c.lspClients, c.permissions, tmpDir),
  129. }
  130. agent := NewSessionAgent(SessionAgentOptions{
  131. LargeModel: small, // Use small model for both (fetch doesn't need large)
  132. SmallModel: small,
  133. SystemPromptPrefix: smallProviderCfg.SystemPromptPrefix,
  134. SystemPrompt: systemPrompt,
  135. DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize,
  136. IsYolo: c.permissions.SkipRequests(),
  137. Sessions: c.sessions,
  138. Messages: c.messages,
  139. Tools: fetchTools,
  140. })
  141. agentToolSessionID := c.sessions.CreateAgentToolSessionID(validationResult.AgentMessageID, call.ID)
  142. session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, validationResult.SessionID, "Fetch Analysis")
  143. if err != nil {
  144. return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
  145. }
  146. c.permissions.AutoApproveSession(session.ID)
  147. // Use small model for web content analysis (faster and cheaper)
  148. maxTokens := small.CatwalkCfg.DefaultMaxTokens
  149. if small.ModelCfg.MaxTokens != 0 {
  150. maxTokens = small.ModelCfg.MaxTokens
  151. }
  152. result, err := agent.Run(ctx, SessionAgentCall{
  153. SessionID: session.ID,
  154. Prompt: fullPrompt,
  155. MaxOutputTokens: maxTokens,
  156. ProviderOptions: getProviderOptions(small, smallProviderCfg),
  157. Temperature: small.ModelCfg.Temperature,
  158. TopP: small.ModelCfg.TopP,
  159. TopK: small.ModelCfg.TopK,
  160. FrequencyPenalty: small.ModelCfg.FrequencyPenalty,
  161. PresencePenalty: small.ModelCfg.PresencePenalty,
  162. })
  163. if err != nil {
  164. return fantasy.NewTextErrorResponse("error generating response"), nil
  165. }
  166. updatedSession, err := c.sessions.Get(ctx, session.ID)
  167. if err != nil {
  168. return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
  169. }
  170. parentSession, err := c.sessions.Get(ctx, validationResult.SessionID)
  171. if err != nil {
  172. return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
  173. }
  174. parentSession.Cost += updatedSession.Cost
  175. _, err = c.sessions.Save(ctx, parentSession)
  176. if err != nil {
  177. return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
  178. }
  179. return fantasy.NewTextResponse(result.Response.Content.Text()), nil
  180. }), nil
  181. }