| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
- package tools
- import (
- "context"
- _ "embed"
- "fmt"
- "io"
- "net/http"
- "strings"
- "time"
- "unicode/utf8"
- "charm.land/fantasy"
- md "github.com/JohannesKaufmann/html-to-markdown"
- "github.com/PuerkitoBio/goquery"
- "github.com/charmbracelet/crush/internal/permission"
- )
// Identity and size limit of the fetch tool.
const (
	// FetchToolName is the name the fetch tool is registered under; it is
	// also reported as the tool name in permission requests.
	FetchToolName = "fetch"

	// MaxFetchSize is the limit, in bytes (1 MiB), applied to the response
	// body that is read and to the content handed back to the caller. It is
	// deliberately left untyped so it can be used as both int and int64.
	MaxFetchSize = 1 << 20
)
- //go:embed fetch.md
- var fetchDescription []byte
- func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
- if client == nil {
- transport := http.DefaultTransport.(*http.Transport).Clone()
- transport.MaxIdleConns = 100
- transport.MaxIdleConnsPerHost = 10
- transport.IdleConnTimeout = 90 * time.Second
- client = &http.Client{
- Timeout: 30 * time.Second,
- Transport: transport,
- }
- }
- return fantasy.NewParallelAgentTool(
- FetchToolName,
- FirstLineDescription(fetchDescription),
- func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
- if params.URL == "" {
- return fantasy.NewTextErrorResponse("URL parameter is required"), nil
- }
- format := strings.ToLower(params.Format)
- if format != "text" && format != "markdown" && format != "html" {
- return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
- }
- if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
- return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
- }
- sessionID := GetSessionFromContext(ctx)
- if sessionID == "" {
- return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
- }
- p, err := permissions.Request(ctx,
- permission.CreatePermissionRequest{
- SessionID: sessionID,
- Path: workingDir,
- ToolCallID: call.ID,
- ToolName: FetchToolName,
- Action: "fetch",
- Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
- Params: FetchPermissionsParams(params),
- },
- )
- if err != nil {
- return fantasy.ToolResponse{}, err
- }
- if !p {
- return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
- }
- // maxFetchTimeoutSeconds is the maximum allowed timeout for fetch requests (2 minutes)
- const maxFetchTimeoutSeconds = 120
- // Handle timeout with context
- requestCtx := ctx
- if params.Timeout > 0 {
- if params.Timeout > maxFetchTimeoutSeconds {
- params.Timeout = maxFetchTimeoutSeconds
- }
- var cancel context.CancelFunc
- requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
- defer cancel()
- }
- req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
- }
- req.Header.Set("User-Agent", "crush/1.0")
- resp, err := client.Do(req)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
- }
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
- }
- body, err := io.ReadAll(io.LimitReader(resp.Body, MaxFetchSize))
- if err != nil {
- return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
- }
- content := string(body)
- validUTF8 := utf8.ValidString(content)
- if !validUTF8 {
- return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
- }
- contentType := resp.Header.Get("Content-Type")
- switch format {
- case "text":
- if strings.Contains(contentType, "text/html") {
- text, err := extractTextFromHTML(content)
- if err != nil {
- return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
- }
- content = text
- }
- case "markdown":
- if strings.Contains(contentType, "text/html") {
- markdown, err := convertHTMLToMarkdown(content)
- if err != nil {
- return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
- }
- content = markdown
- }
- content = "```\n" + content + "\n```"
- case "html":
- // return only the body of the HTML document
- if strings.Contains(contentType, "text/html") {
- doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
- if err != nil {
- return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
- }
- body, err := doc.Find("body").Html()
- if err != nil {
- return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
- }
- if body == "" {
- return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
- }
- content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
- }
- }
- // truncate content if it exceeds max read size
- if int64(len(content)) > MaxFetchSize {
- content = content[:MaxFetchSize]
- content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxFetchSize)
- }
- return fantasy.NewTextResponse(content), nil
- })
- }
- func extractTextFromHTML(html string) (string, error) {
- doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
- if err != nil {
- return "", err
- }
- text := doc.Find("body").Text()
- text = strings.Join(strings.Fields(text), " ")
- return text, nil
- }
- func convertHTMLToMarkdown(html string) (string, error) {
- converter := md.NewConverter("", true, nil)
- markdown, err := converter.ConvertString(html)
- if err != nil {
- return "", err
- }
- return markdown, nil
- }
|