fetch.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. package tools
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "unicode/utf8"
  11. md "github.com/JohannesKaufmann/html-to-markdown"
  12. "github.com/PuerkitoBio/goquery"
  13. "github.com/charmbracelet/crush/internal/permission"
  14. )
  // FetchParams is the JSON input accepted by the fetch tool.
  type FetchParams struct {
  	// URL is the address to download.
  	URL string `json:"url"`
  	// Format names the output representation requested by the caller.
  	Format string `json:"format"`
  	// Timeout is an optional per-request timeout in seconds; it is left out
  	// of the encoded JSON when zero.
  	Timeout int `json:"timeout,omitempty"`
  }
  // FetchPermissionsParams is the payload attached to the permission request
  // raised before a fetch. Its fields mirror FetchParams one-to-one so that a
  // FetchParams value can be converted to it directly.
  type FetchPermissionsParams struct {
  	// URL is the address the tool wants to download.
  	URL string `json:"url"`
  	// Format names the requested output representation.
  	Format string `json:"format"`
  	// Timeout is the requested timeout in seconds; omitted from JSON when zero.
  	Timeout int `json:"timeout,omitempty"`
  }
  25. type fetchTool struct {
  26. client *http.Client
  27. permissions permission.Service
  28. workingDir string
  29. }
  // Identity and model-facing description of the fetch tool.
  const (
  	// FetchToolName is the name under which the fetch tool is registered.
  	FetchToolName = "fetch"

  	// fetchToolDescription is the usage text returned as part of the tool's
  	// Info.
  	fetchToolDescription = `Fetches content from a URL and returns it in the specified format.
WHEN TO USE THIS TOOL:
- Use when you need to download content from a URL
- Helpful for retrieving documentation, API responses, or web content
- Useful for getting external information to assist with tasks
HOW TO USE:
- Provide the URL to fetch content from
- Specify the desired output format (text, markdown, or html)
- Optionally set a timeout for the request
FEATURES:
- Supports three output formats: text, markdown, and html
- Automatically handles HTTP redirects
- Sets reasonable timeouts to prevent hanging
- Validates input parameters before making requests
LIMITATIONS:
- Maximum response size is 5MB
- Only supports HTTP and HTTPS protocols
- Cannot handle authentication or cookies
- Some websites may block automated requests
TIPS:
- Use text format for plain text content or simple API responses
- Use markdown format for content that should be rendered with formatting
- Use html format when you need the raw HTML structure
- Set appropriate timeouts for potentially slow websites`
  )
  57. func NewFetchTool(permissions permission.Service, workingDir string) BaseTool {
  58. return &fetchTool{
  59. client: &http.Client{
  60. Timeout: 30 * time.Second,
  61. Transport: &http.Transport{
  62. MaxIdleConns: 100,
  63. MaxIdleConnsPerHost: 10,
  64. IdleConnTimeout: 90 * time.Second,
  65. },
  66. },
  67. permissions: permissions,
  68. workingDir: workingDir,
  69. }
  70. }
  71. func (t *fetchTool) Name() string {
  72. return FetchToolName
  73. }
  74. func (t *fetchTool) Info() ToolInfo {
  75. return ToolInfo{
  76. Name: FetchToolName,
  77. Description: fetchToolDescription,
  78. Parameters: map[string]any{
  79. "url": map[string]any{
  80. "type": "string",
  81. "description": "The URL to fetch content from",
  82. },
  83. "format": map[string]any{
  84. "type": "string",
  85. "description": "The format to return the content in (text, markdown, or html)",
  86. "enum": []string{"text", "markdown", "html"},
  87. },
  88. "timeout": map[string]any{
  89. "type": "number",
  90. "description": "Optional timeout in seconds (max 120)",
  91. },
  92. },
  93. Required: []string{"url", "format"},
  94. }
  95. }
  96. func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
  97. var params FetchParams
  98. if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
  99. return NewTextErrorResponse("Failed to parse fetch parameters: " + err.Error()), nil
  100. }
  101. if params.URL == "" {
  102. return NewTextErrorResponse("URL parameter is required"), nil
  103. }
  104. format := strings.ToLower(params.Format)
  105. if format != "text" && format != "markdown" && format != "html" {
  106. return NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
  107. }
  108. if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
  109. return NewTextErrorResponse("URL must start with http:// or https://"), nil
  110. }
  111. sessionID, messageID := GetContextValues(ctx)
  112. if sessionID == "" || messageID == "" {
  113. return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
  114. }
  115. p := t.permissions.Request(
  116. permission.CreatePermissionRequest{
  117. SessionID: sessionID,
  118. Path: t.workingDir,
  119. ToolCallID: call.ID,
  120. ToolName: FetchToolName,
  121. Action: "fetch",
  122. Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
  123. Params: FetchPermissionsParams(params),
  124. },
  125. )
  126. if !p {
  127. return ToolResponse{}, permission.ErrorPermissionDenied
  128. }
  129. // Handle timeout with context
  130. requestCtx := ctx
  131. if params.Timeout > 0 {
  132. maxTimeout := 120 // 2 minutes
  133. if params.Timeout > maxTimeout {
  134. params.Timeout = maxTimeout
  135. }
  136. var cancel context.CancelFunc
  137. requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
  138. defer cancel()
  139. }
  140. req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
  141. if err != nil {
  142. return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
  143. }
  144. req.Header.Set("User-Agent", "crush/1.0")
  145. resp, err := t.client.Do(req)
  146. if err != nil {
  147. return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
  148. }
  149. defer resp.Body.Close()
  150. if resp.StatusCode != http.StatusOK {
  151. return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
  152. }
  153. maxSize := int64(5 * 1024 * 1024) // 5MB
  154. body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
  155. if err != nil {
  156. return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
  157. }
  158. content := string(body)
  159. isValidUt8 := utf8.ValidString(content)
  160. if !isValidUt8 {
  161. return NewTextErrorResponse("Response content is not valid UTF-8"), nil
  162. }
  163. contentType := resp.Header.Get("Content-Type")
  164. switch format {
  165. case "text":
  166. if strings.Contains(contentType, "text/html") {
  167. text, err := extractTextFromHTML(content)
  168. if err != nil {
  169. return NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
  170. }
  171. content = text
  172. }
  173. case "markdown":
  174. if strings.Contains(contentType, "text/html") {
  175. markdown, err := convertHTMLToMarkdown(content)
  176. if err != nil {
  177. return NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
  178. }
  179. content = markdown
  180. }
  181. content = "```\n" + content + "\n```"
  182. case "html":
  183. // return only the body of the HTML document
  184. if strings.Contains(contentType, "text/html") {
  185. doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
  186. if err != nil {
  187. return NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
  188. }
  189. body, err := doc.Find("body").Html()
  190. if err != nil {
  191. return NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
  192. }
  193. if body == "" {
  194. return NewTextErrorResponse("No body content found in HTML"), nil
  195. }
  196. content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
  197. }
  198. }
  199. // calculate byte size of content
  200. contentSize := int64(len(content))
  201. if contentSize > MaxReadSize {
  202. content = content[:MaxReadSize]
  203. content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
  204. }
  205. return NewTextResponse(content), nil
  206. }
  207. func extractTextFromHTML(html string) (string, error) {
  208. doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
  209. if err != nil {
  210. return "", err
  211. }
  212. text := doc.Find("body").Text()
  213. text = strings.Join(strings.Fields(text), " ")
  214. return text, nil
  215. }
  216. func convertHTMLToMarkdown(html string) (string, error) {
  217. converter := md.NewConverter("", true, nil)
  218. markdown, err := converter.ConvertString(html)
  219. if err != nil {
  220. return "", err
  221. }
  222. return markdown, nil
  223. }