agentic_fetch_tool.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. package agent
  2. import (
  3. "context"
  4. _ "embed"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "os"
  9. "time"
  10. "charm.land/fantasy"
  11. "github.com/charmbracelet/crush/internal/agent/prompt"
  12. "github.com/charmbracelet/crush/internal/agent/tools"
  13. "github.com/charmbracelet/crush/internal/permission"
  14. )
// agenticFetchToolDescription holds the embedded contents of
// templates/agentic_fetch.md. It is converted to a string and used as the
// description of the agentic fetch tool.
//
//go:embed templates/agentic_fetch.md
var agenticFetchToolDescription []byte
  17. // agenticFetchValidationResult holds the validated parameters from the tool call context.
  18. type agenticFetchValidationResult struct {
  19. SessionID string
  20. AgentMessageID string
  21. }
  22. // validateAgenticFetchParams validates the tool call parameters and extracts required context values.
  23. func validateAgenticFetchParams(ctx context.Context, params tools.AgenticFetchParams) (agenticFetchValidationResult, error) {
  24. if params.Prompt == "" {
  25. return agenticFetchValidationResult{}, errors.New("prompt is required")
  26. }
  27. sessionID := tools.GetSessionFromContext(ctx)
  28. if sessionID == "" {
  29. return agenticFetchValidationResult{}, errors.New("session id missing from context")
  30. }
  31. agentMessageID := tools.GetMessageFromContext(ctx)
  32. if agentMessageID == "" {
  33. return agenticFetchValidationResult{}, errors.New("agent message id missing from context")
  34. }
  35. return agenticFetchValidationResult{
  36. SessionID: sessionID,
  37. AgentMessageID: agentMessageID,
  38. }, nil
  39. }
// agenticFetchPromptTmpl holds the embedded contents of
// templates/agentic_fetch_prompt.md.tpl. It is converted to a string and
// passed to prompt.NewPrompt as the template from which the sub-agent's system
// prompt is built.
//
//go:embed templates/agentic_fetch_prompt.md.tpl
var agenticFetchPromptTmpl []byte
  42. func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (fantasy.AgentTool, error) {
  43. if client == nil {
  44. client = &http.Client{
  45. Timeout: 30 * time.Second,
  46. Transport: &http.Transport{
  47. MaxIdleConns: 100,
  48. MaxIdleConnsPerHost: 10,
  49. IdleConnTimeout: 90 * time.Second,
  50. },
  51. }
  52. }
  53. return fantasy.NewParallelAgentTool(
  54. tools.AgenticFetchToolName,
  55. string(agenticFetchToolDescription),
  56. func(ctx context.Context, params tools.AgenticFetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  57. validationResult, err := validateAgenticFetchParams(ctx, params)
  58. if err != nil {
  59. return fantasy.NewTextErrorResponse(err.Error()), nil
  60. }
  61. // Determine description based on mode.
  62. var description string
  63. if params.URL != "" {
  64. description = fmt.Sprintf("Fetch and analyze content from URL: %s", params.URL)
  65. } else {
  66. description = "Search the web and analyze results"
  67. }
  68. p := c.permissions.Request(
  69. permission.CreatePermissionRequest{
  70. SessionID: validationResult.SessionID,
  71. Path: c.cfg.WorkingDir(),
  72. ToolCallID: call.ID,
  73. ToolName: tools.AgenticFetchToolName,
  74. Action: "fetch",
  75. Description: description,
  76. Params: tools.AgenticFetchPermissionsParams(params),
  77. },
  78. )
  79. if !p {
  80. return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
  81. }
  82. tmpDir, err := os.MkdirTemp(c.cfg.Options.DataDirectory, "crush-fetch-*")
  83. if err != nil {
  84. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary directory: %s", err)), nil
  85. }
  86. defer os.RemoveAll(tmpDir)
  87. var fullPrompt string
  88. if params.URL != "" {
  89. // URL mode: fetch the URL content first.
  90. content, err := tools.FetchURLAndConvert(ctx, client, params.URL)
  91. if err != nil {
  92. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to fetch URL: %s", err)), nil
  93. }
  94. hasLargeContent := len(content) > tools.LargeContentThreshold
  95. if hasLargeContent {
  96. tempFile, err := os.CreateTemp(tmpDir, "page-*.md")
  97. if err != nil {
  98. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary file: %s", err)), nil
  99. }
  100. tempFilePath := tempFile.Name()
  101. if _, err := tempFile.WriteString(content); err != nil {
  102. tempFile.Close()
  103. return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to write content to file: %s", err)), nil
  104. }
  105. tempFile.Close()
  106. fullPrompt = fmt.Sprintf("%s\n\nThe web page from %s has been saved to: %s\n\nUse the view and grep tools to analyze this file and extract the requested information.", params.Prompt, params.URL, tempFilePath)
  107. } else {
  108. fullPrompt = fmt.Sprintf("%s\n\nWeb page URL: %s\n\n<webpage_content>\n%s\n</webpage_content>", params.Prompt, params.URL, content)
  109. }
  110. } else {
  111. // Search mode: let the sub-agent search and fetch as needed.
  112. fullPrompt = fmt.Sprintf("%s\n\nUse the web_search tool to find relevant information. Break down the question into smaller, focused searches if needed. After searching, use web_fetch to get detailed content from the most relevant results.", params.Prompt)
  113. }
  114. promptOpts := []prompt.Option{
  115. prompt.WithWorkingDir(tmpDir),
  116. }
  117. promptTemplate, err := prompt.NewPrompt("agentic_fetch", string(agenticFetchPromptTmpl), promptOpts...)
  118. if err != nil {
  119. return fantasy.ToolResponse{}, fmt.Errorf("error creating prompt: %s", err)
  120. }
  121. _, small, err := c.buildAgentModels(ctx)
  122. if err != nil {
  123. return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err)
  124. }
  125. systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), *c.cfg)
  126. if err != nil {
  127. return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err)
  128. }
  129. smallProviderCfg, ok := c.cfg.Providers.Get(small.ModelCfg.Provider)
  130. if !ok {
  131. return fantasy.ToolResponse{}, errors.New("small model provider not configured")
  132. }
  133. webFetchTool := tools.NewWebFetchTool(tmpDir, client)
  134. webSearchTool := tools.NewWebSearchTool(client)
  135. fetchTools := []fantasy.AgentTool{
  136. webFetchTool,
  137. webSearchTool,
  138. tools.NewGlobTool(tmpDir),
  139. tools.NewGrepTool(tmpDir),
  140. tools.NewSourcegraphTool(client),
  141. tools.NewViewTool(c.lspClients, c.permissions, tmpDir),
  142. }
  143. agent := NewSessionAgent(SessionAgentOptions{
  144. LargeModel: small, // Use small model for both (fetch doesn't need large)
  145. SmallModel: small,
  146. SystemPromptPrefix: smallProviderCfg.SystemPromptPrefix,
  147. SystemPrompt: systemPrompt,
  148. DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize,
  149. IsYolo: c.permissions.SkipRequests(),
  150. Sessions: c.sessions,
  151. Messages: c.messages,
  152. Tools: fetchTools,
  153. })
  154. agentToolSessionID := c.sessions.CreateAgentToolSessionID(validationResult.AgentMessageID, call.ID)
  155. session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, validationResult.SessionID, "Fetch Analysis")
  156. if err != nil {
  157. return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
  158. }
  159. c.permissions.AutoApproveSession(session.ID)
  160. // Use small model for web content analysis (faster and cheaper)
  161. maxTokens := small.CatwalkCfg.DefaultMaxTokens
  162. if small.ModelCfg.MaxTokens != 0 {
  163. maxTokens = small.ModelCfg.MaxTokens
  164. }
  165. result, err := agent.Run(ctx, SessionAgentCall{
  166. SessionID: session.ID,
  167. Prompt: fullPrompt,
  168. MaxOutputTokens: maxTokens,
  169. ProviderOptions: getProviderOptions(small, smallProviderCfg),
  170. Temperature: small.ModelCfg.Temperature,
  171. TopP: small.ModelCfg.TopP,
  172. TopK: small.ModelCfg.TopK,
  173. FrequencyPenalty: small.ModelCfg.FrequencyPenalty,
  174. PresencePenalty: small.ModelCfg.PresencePenalty,
  175. })
  176. if err != nil {
  177. return fantasy.NewTextErrorResponse("error generating response"), nil
  178. }
  179. updatedSession, err := c.sessions.Get(ctx, session.ID)
  180. if err != nil {
  181. return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
  182. }
  183. parentSession, err := c.sessions.Get(ctx, validationResult.SessionID)
  184. if err != nil {
  185. return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
  186. }
  187. parentSession.Cost += updatedSession.Cost
  188. _, err = c.sessions.Save(ctx, parentSession)
  189. if err != nil {
  190. return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
  191. }
  192. return fantasy.NewTextResponse(result.Response.Content.Text()), nil
  193. }), nil
  194. }