agentic_fetch_tool.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. package agent
import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"net/http"
	"os"
	"sync"
	"time"

	"charm.land/fantasy"
	"github.com/charmbracelet/crush/internal/agent/prompt"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/permission"
)
// agenticFetchToolDescription holds the contents of templates/agentic_fetch.md,
// embedded at build time. agenticFetchTool passes it (as a string) to
// fantasy.NewParallelAgentTool together with tools.AgenticFetchToolName.
//
//go:embed templates/agentic_fetch.md
var agenticFetchToolDescription []byte
  17. // agenticFetchValidationResult holds the validated parameters from the tool call context.
  18. type agenticFetchValidationResult struct {
  19. SessionID string
  20. AgentMessageID string
  21. }
  22. // validateAgenticFetchParams validates the tool call parameters and extracts required context values.
  23. func validateAgenticFetchParams(ctx context.Context, params tools.AgenticFetchParams) (agenticFetchValidationResult, error) {
  24. if params.Prompt == "" {
  25. return agenticFetchValidationResult{}, errors.New("prompt is required")
  26. }
  27. sessionID := tools.GetSessionFromContext(ctx)
  28. if sessionID == "" {
  29. return agenticFetchValidationResult{}, errors.New("session id missing from context")
  30. }
  31. agentMessageID := tools.GetMessageFromContext(ctx)
  32. if agentMessageID == "" {
  33. return agenticFetchValidationResult{}, errors.New("agent message id missing from context")
  34. }
  35. return agenticFetchValidationResult{
  36. SessionID: sessionID,
  37. AgentMessageID: agentMessageID,
  38. }, nil
  39. }
// agenticFetchPromptTmpl holds the contents of
// templates/agentic_fetch_prompt.md.tpl, embedded at build time.
// agenticFetchTool passes it to prompt.NewPrompt and builds the fetch
// sub-agent's system prompt from the result.
//
//go:embed templates/agentic_fetch_prompt.md.tpl
var agenticFetchPromptTmpl []byte
  42. func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (fantasy.AgentTool, error) {
  43. if client == nil {
  44. client = &http.Client{
  45. Timeout: 30 * time.Second,
  46. Transport: &http.Transport{
  47. MaxIdleConns: 100,
  48. MaxIdleConnsPerHost: 10,
  49. IdleConnTimeout: 90 * time.Second,
  50. },
  51. }
  52. }
  53. return fantasy.NewParallelAgentTool(
  54. tools.AgenticFetchToolName,
  55. string(agenticFetchToolDescription),
  56. func(ctx context.Context, params tools.AgenticFetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  57. validationResult, err := validateAgenticFetchParams(ctx, params)
  58. if err != nil {
  59. return fantasy.NewTextErrorResponse(err.Error()), nil
  60. }
  61. // Determine description based on mode.
  62. var description string
  63. if params.URL != "" {
  64. description = fmt.Sprintf("Fetch and analyze content from URL: %s", params.URL)
  65. } else {
  66. description = "Search the web and analyze results"
  67. }
  68. p, err := c.permissions.Request(ctx,
  69. permission.CreatePermissionRequest{
  70. SessionID: validationResult.SessionID,
  71. Path: c.cfg.WorkingDir(),
  72. ToolCallID: call.ID,
  73. ToolName: tools.AgenticFetchToolName,
  74. Action: "fetch",
  75. Description: description,
  76. Params: tools.AgenticFetchPermissionsParams(params),
  77. },
  78. )
  79. if err != nil {
  80. return fantasy.ToolResponse{}, err
  81. }
  82. if !p {
  83. return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
  84. }
  85. tmpDir, err := os.MkdirTemp(c.cfg.Options.DataDirectory, "crush-fetch-*")
  86. if err != nil {
  87. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary directory: %s", err)), nil
  88. }
  89. defer os.RemoveAll(tmpDir)
  90. var fullPrompt string
  91. if params.URL != "" {
  92. // URL mode: fetch the URL content first.
  93. content, err := tools.FetchURLAndConvert(ctx, client, params.URL)
  94. if err != nil {
  95. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to fetch URL: %s", err)), nil
  96. }
  97. hasLargeContent := len(content) > tools.LargeContentThreshold
  98. if hasLargeContent {
  99. tempFile, err := os.CreateTemp(tmpDir, "page-*.md")
  100. if err != nil {
  101. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary file: %s", err)), nil
  102. }
  103. tempFilePath := tempFile.Name()
  104. if _, err := tempFile.WriteString(content); err != nil {
  105. tempFile.Close()
  106. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to write content to file: %s", err)), nil
  107. }
  108. tempFile.Close()
  109. fullPrompt = fmt.Sprintf("%s\n\nThe web page from %s has been saved to: %s\n\nUse the view and grep tools to analyze this file and extract the requested information.", params.Prompt, params.URL, tempFilePath)
  110. } else {
  111. fullPrompt = fmt.Sprintf("%s\n\nWeb page URL: %s\n\n<webpage_content>\n%s\n</webpage_content>", params.Prompt, params.URL, content)
  112. }
  113. } else {
  114. // Search mode: let the sub-agent search and fetch as needed.
  115. fullPrompt = fmt.Sprintf("%s\n\nUse the web_search tool to find relevant information. Break down the question into smaller, focused searches if needed. After searching, use web_fetch to get detailed content from the most relevant results.", params.Prompt)
  116. }
  117. promptOpts := []prompt.Option{
  118. prompt.WithWorkingDir(tmpDir),
  119. }
  120. promptTemplate, err := prompt.NewPrompt("agentic_fetch", string(agenticFetchPromptTmpl), promptOpts...)
  121. if err != nil {
  122. return fantasy.ToolResponse{}, fmt.Errorf("error creating prompt: %s", err)
  123. }
  124. _, small, err := c.buildAgentModels(ctx, true)
  125. if err != nil {
  126. return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err)
  127. }
  128. systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), *c.cfg)
  129. if err != nil {
  130. return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err)
  131. }
  132. smallProviderCfg, ok := c.cfg.Providers.Get(small.ModelCfg.Provider)
  133. if !ok {
  134. return fantasy.ToolResponse{}, errors.New("small model provider not configured")
  135. }
  136. webFetchTool := tools.NewWebFetchTool(tmpDir, client)
  137. webSearchTool := tools.NewWebSearchTool(client)
  138. fetchTools := []fantasy.AgentTool{
  139. webFetchTool,
  140. webSearchTool,
  141. tools.NewGlobTool(tmpDir),
  142. tools.NewGrepTool(tmpDir),
  143. tools.NewSourcegraphTool(client),
  144. tools.NewViewTool(c.lspClients, c.permissions, c.filetracker, tmpDir),
  145. }
  146. agent := NewSessionAgent(SessionAgentOptions{
  147. LargeModel: small, // Use small model for both (fetch doesn't need large)
  148. SmallModel: small,
  149. SystemPromptPrefix: smallProviderCfg.SystemPromptPrefix,
  150. SystemPrompt: systemPrompt,
  151. DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize,
  152. IsYolo: c.permissions.SkipRequests(),
  153. Sessions: c.sessions,
  154. Messages: c.messages,
  155. Tools: fetchTools,
  156. })
  157. agentToolSessionID := c.sessions.CreateAgentToolSessionID(validationResult.AgentMessageID, call.ID)
  158. session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, validationResult.SessionID, "Fetch Analysis")
  159. if err != nil {
  160. return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
  161. }
  162. c.permissions.AutoApproveSession(session.ID)
  163. // Use small model for web content analysis (faster and cheaper)
  164. maxTokens := small.CatwalkCfg.DefaultMaxTokens
  165. if small.ModelCfg.MaxTokens != 0 {
  166. maxTokens = small.ModelCfg.MaxTokens
  167. }
  168. result, err := agent.Run(ctx, SessionAgentCall{
  169. SessionID: session.ID,
  170. Prompt: fullPrompt,
  171. MaxOutputTokens: maxTokens,
  172. ProviderOptions: getProviderOptions(small, smallProviderCfg),
  173. Temperature: small.ModelCfg.Temperature,
  174. TopP: small.ModelCfg.TopP,
  175. TopK: small.ModelCfg.TopK,
  176. FrequencyPenalty: small.ModelCfg.FrequencyPenalty,
  177. PresencePenalty: small.ModelCfg.PresencePenalty,
  178. })
  179. if err != nil {
  180. return fantasy.NewTextErrorResponse("error generating response"), nil
  181. }
  182. updatedSession, err := c.sessions.Get(ctx, session.ID)
  183. if err != nil {
  184. return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
  185. }
  186. parentSession, err := c.sessions.Get(ctx, validationResult.SessionID)
  187. if err != nil {
  188. return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
  189. }
  190. parentSession.Cost += updatedSession.Cost
  191. _, err = c.sessions.Save(ctx, parentSession)
  192. if err != nil {
  193. return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
  194. }
  195. return fantasy.NewTextResponse(result.Response.Content.Text()), nil
  196. }), nil
  197. }