fetch.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package tools
  2. import (
  3. "context"
  4. _ "embed"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "strings"
  9. "time"
  10. "unicode/utf8"
  11. "charm.land/fantasy"
  12. md "github.com/JohannesKaufmann/html-to-markdown"
  13. "github.com/PuerkitoBio/goquery"
  14. "github.com/charmbracelet/crush/internal/permission"
  15. )
// FetchToolName is the identifier the fetch tool is registered under; it is
// also reported as the tool name in permission requests.
const FetchToolName = "fetch"

// fetchDescription holds the contents of fetch.md, embedded at build time,
// and is passed (as a string) as the tool's description.
//
//go:embed fetch.md
var fetchDescription []byte
  19. func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
  20. if client == nil {
  21. client = &http.Client{
  22. Timeout: 30 * time.Second,
  23. Transport: &http.Transport{
  24. MaxIdleConns: 100,
  25. MaxIdleConnsPerHost: 10,
  26. IdleConnTimeout: 90 * time.Second,
  27. },
  28. }
  29. }
  30. return fantasy.NewAgentTool(
  31. FetchToolName,
  32. string(fetchDescription),
  33. func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  34. if params.URL == "" {
  35. return fantasy.NewTextErrorResponse("URL parameter is required"), nil
  36. }
  37. format := strings.ToLower(params.Format)
  38. if format != "text" && format != "markdown" && format != "html" {
  39. return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
  40. }
  41. if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
  42. return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
  43. }
  44. sessionID := GetSessionFromContext(ctx)
  45. if sessionID == "" {
  46. return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
  47. }
  48. p := permissions.Request(
  49. permission.CreatePermissionRequest{
  50. SessionID: sessionID,
  51. Path: workingDir,
  52. ToolCallID: call.ID,
  53. ToolName: FetchToolName,
  54. Action: "fetch",
  55. Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
  56. Params: FetchPermissionsParams(params),
  57. },
  58. )
  59. if !p {
  60. return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
  61. }
  62. // Handle timeout with context
  63. requestCtx := ctx
  64. if params.Timeout > 0 {
  65. maxTimeout := 120 // 2 minutes
  66. if params.Timeout > maxTimeout {
  67. params.Timeout = maxTimeout
  68. }
  69. var cancel context.CancelFunc
  70. requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
  71. defer cancel()
  72. }
  73. req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
  74. if err != nil {
  75. return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
  76. }
  77. req.Header.Set("User-Agent", "crush/1.0")
  78. resp, err := client.Do(req)
  79. if err != nil {
  80. return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
  81. }
  82. defer resp.Body.Close()
  83. if resp.StatusCode != http.StatusOK {
  84. return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
  85. }
  86. maxSize := int64(5 * 1024 * 1024) // 5MB
  87. body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
  88. if err != nil {
  89. return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
  90. }
  91. content := string(body)
  92. isValidUt8 := utf8.ValidString(content)
  93. if !isValidUt8 {
  94. return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
  95. }
  96. contentType := resp.Header.Get("Content-Type")
  97. switch format {
  98. case "text":
  99. if strings.Contains(contentType, "text/html") {
  100. text, err := extractTextFromHTML(content)
  101. if err != nil {
  102. return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
  103. }
  104. content = text
  105. }
  106. case "markdown":
  107. if strings.Contains(contentType, "text/html") {
  108. markdown, err := convertHTMLToMarkdown(content)
  109. if err != nil {
  110. return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
  111. }
  112. content = markdown
  113. }
  114. content = "```\n" + content + "\n```"
  115. case "html":
  116. // return only the body of the HTML document
  117. if strings.Contains(contentType, "text/html") {
  118. doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
  119. if err != nil {
  120. return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
  121. }
  122. body, err := doc.Find("body").Html()
  123. if err != nil {
  124. return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
  125. }
  126. if body == "" {
  127. return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
  128. }
  129. content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
  130. }
  131. }
  132. // calculate byte size of content
  133. contentSize := int64(len(content))
  134. if contentSize > MaxReadSize {
  135. content = content[:MaxReadSize]
  136. content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
  137. }
  138. return fantasy.NewTextResponse(content), nil
  139. })
  140. }
  141. func extractTextFromHTML(html string) (string, error) {
  142. doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
  143. if err != nil {
  144. return "", err
  145. }
  146. text := doc.Find("body").Text()
  147. text = strings.Join(strings.Fields(text), " ")
  148. return text, nil
  149. }
  150. func convertHTMLToMarkdown(html string) (string, error) {
  151. converter := md.NewConverter("", true, nil)
  152. markdown, err := converter.ConvertString(html)
  153. if err != nil {
  154. return "", err
  155. }
  156. return markdown, nil
  157. }