fetch.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package tools
  2. import (
  3. "context"
  4. _ "embed"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "unicode/utf8"
  11. "charm.land/fantasy"
  12. md "github.com/JohannesKaufmann/html-to-markdown"
  13. "github.com/PuerkitoBio/goquery"
  14. "github.com/charmbracelet/crush/internal/permission"
  15. )
  16. const (
  17. FetchToolName = "fetch"
  18. MaxFetchSize = 1 * 1024 * 1024 // 1MB
  19. )
  20. //go:embed fetch.md
  21. var fetchDescription []byte
  22. func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
  23. if client == nil {
  24. transport := http.DefaultTransport.(*http.Transport).Clone()
  25. transport.MaxIdleConns = 100
  26. transport.MaxIdleConnsPerHost = 10
  27. transport.IdleConnTimeout = 90 * time.Second
  28. client = &http.Client{
  29. Timeout: 30 * time.Second,
  30. Transport: transport,
  31. }
  32. }
  33. return fantasy.NewParallelAgentTool(
  34. FetchToolName,
  35. FirstLineDescription(fetchDescription),
  36. func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  37. if params.URL == "" {
  38. return fantasy.NewTextErrorResponse("URL parameter is required"), nil
  39. }
  40. format := strings.ToLower(params.Format)
  41. if format != "text" && format != "markdown" && format != "html" {
  42. return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
  43. }
  44. if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
  45. return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
  46. }
  47. sessionID := GetSessionFromContext(ctx)
  48. if sessionID == "" {
  49. return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
  50. }
  51. p, err := permissions.Request(ctx,
  52. permission.CreatePermissionRequest{
  53. SessionID: sessionID,
  54. Path: workingDir,
  55. ToolCallID: call.ID,
  56. ToolName: FetchToolName,
  57. Action: "fetch",
  58. Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
  59. Params: FetchPermissionsParams(params),
  60. },
  61. )
  62. if err != nil {
  63. return fantasy.ToolResponse{}, err
  64. }
  65. if !p {
  66. return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
  67. }
  68. // maxFetchTimeoutSeconds is the maximum allowed timeout for fetch requests (2 minutes)
  69. const maxFetchTimeoutSeconds = 120
  70. // Handle timeout with context
  71. requestCtx := ctx
  72. if params.Timeout > 0 {
  73. if params.Timeout > maxFetchTimeoutSeconds {
  74. params.Timeout = maxFetchTimeoutSeconds
  75. }
  76. var cancel context.CancelFunc
  77. requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
  78. defer cancel()
  79. }
  80. req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
  81. if err != nil {
  82. return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
  83. }
  84. req.Header.Set("User-Agent", "crush/1.0")
  85. resp, err := client.Do(req)
  86. if err != nil {
  87. return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
  88. }
  89. defer resp.Body.Close()
  90. if resp.StatusCode != http.StatusOK {
  91. return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
  92. }
  93. body, err := io.ReadAll(io.LimitReader(resp.Body, MaxFetchSize))
  94. if err != nil {
  95. return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
  96. }
  97. content := string(body)
  98. validUTF8 := utf8.ValidString(content)
  99. if !validUTF8 {
  100. return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
  101. }
  102. contentType := resp.Header.Get("Content-Type")
  103. switch format {
  104. case "text":
  105. if strings.Contains(contentType, "text/html") {
  106. text, err := extractTextFromHTML(content)
  107. if err != nil {
  108. return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
  109. }
  110. content = text
  111. }
  112. case "markdown":
  113. if strings.Contains(contentType, "text/html") {
  114. markdown, err := convertHTMLToMarkdown(content)
  115. if err != nil {
  116. return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
  117. }
  118. content = markdown
  119. }
  120. content = "```\n" + content + "\n```"
  121. case "html":
  122. // return only the body of the HTML document
  123. if strings.Contains(contentType, "text/html") {
  124. doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
  125. if err != nil {
  126. return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
  127. }
  128. body, err := doc.Find("body").Html()
  129. if err != nil {
  130. return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
  131. }
  132. if body == "" {
  133. return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
  134. }
  135. content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
  136. }
  137. }
  138. // truncate content if it exceeds max read size
  139. if int64(len(content)) > MaxFetchSize {
  140. content = content[:MaxFetchSize]
  141. content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxFetchSize)
  142. }
  143. return fantasy.NewTextResponse(content), nil
  144. })
  145. }
  146. func extractTextFromHTML(html string) (string, error) {
  147. doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
  148. if err != nil {
  149. return "", err
  150. }
  151. text := doc.Find("body").Text()
  152. text = strings.Join(strings.Fields(text), " ")
  153. return text, nil
  154. }
  155. func convertHTMLToMarkdown(html string) (string, error) {
  156. converter := md.NewConverter("", true, nil)
  157. markdown, err := converter.ConvertString(html)
  158. if err != nil {
  159. return "", err
  160. }
  161. return markdown, nil
  162. }