| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416 |
- package tools
- import (
- "bufio"
- "bytes"
- "context"
- _ "embed"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "os"
- "os/exec"
- "path/filepath"
- "regexp"
- "sort"
- "strings"
- "sync"
- "time"
- "charm.land/fantasy"
- "github.com/charmbracelet/crush/internal/fsext"
- )
// regexCache memoizes compiled regular expressions by their source pattern.
// It is safe for concurrent use: lookups take a read lock and compilation
// happens under the write lock.
type regexCache struct {
	cache map[string]*regexp.Regexp
	mu    sync.RWMutex
}

// newRegexCache returns an empty regexCache that is ready to use.
func newRegexCache() *regexCache {
	return &regexCache{
		cache: make(map[string]*regexp.Regexp),
	}
}

// get returns the compiled form of pattern, compiling and storing it on first
// use. A pattern that fails to compile yields the regexp.Compile error and is
// not stored, so a later call with the same pattern compiles it again.
func (rc *regexCache) get(pattern string) (*regexp.Regexp, error) {
	// Fast path: most calls hit an already compiled pattern.
	rc.mu.RLock()
	regex, ok := rc.cache[pattern]
	rc.mu.RUnlock()
	if ok {
		return regex, nil
	}

	rc.mu.Lock()
	defer rc.mu.Unlock()

	// Another goroutine may have stored the pattern between releasing the
	// read lock and acquiring the write lock.
	if cached, ok := rc.cache[pattern]; ok {
		return cached, nil
	}

	compiled, err := regexp.Compile(pattern)
	if err != nil {
		return nil, err
	}
	rc.cache[pattern] = compiled
	return compiled, nil
}
- // Global regex cache instances
- var (
- searchRegexCache = newRegexCache()
- globRegexCache = newRegexCache()
- // Pre-compiled regex for glob conversion (used frequently)
- globBraceRegex = regexp.MustCompile(`\{([^}]+)\}`)
- )
// GrepParams is the set of arguments accepted by the grep tool.
type GrepParams struct {
	// Pattern is the expression to look for; the tool rejects an empty value.
	Pattern string `json:"pattern" description:"The regex pattern to search for in file contents"`
	// Path is the directory to search.
	Path string `json:"path,omitempty" description:"The directory to search in. Defaults to the current working directory."`
	// Include optionally restricts the search to files matching a glob.
	Include string `json:"include,omitempty" description:"File pattern to include in the search (e.g. \"*.js\", \"*.{ts,tsx}\")"`
	// LiteralText requests that Pattern be matched verbatim.
	LiteralText bool `json:"literal_text,omitempty" description:"If true, the pattern will be treated as literal text with special regex characters escaped. Default is false."`
}
// grepMatch describes a single hit produced by a search backend.
type grepMatch struct {
	path     string    // file in which the hit was found
	modTime  time.Time // modification time of that file
	lineNum  int       // line of the hit; values <= 0 mean no line information
	charNum  int       // column of the hit; values <= 0 mean no column information
	lineText string    // text of the matching line
}
// GrepResponseMetadata is the structured metadata attached to a grep tool
// response.
type GrepResponseMetadata struct {
	// NumberOfMatches is the number of matches included in the response.
	NumberOfMatches int `json:"number_of_matches"`
	// Truncated reports whether the result list was cut short.
	Truncated bool `json:"truncated"`
}
// Grep tool constants.
const (
	// maxGrepContentWidth is the longest line text, in bytes, that is shown
	// for a match before it is cut off with "...".
	maxGrepContentWidth = 500

	// GrepToolName is the name under which the grep tool is registered.
	GrepToolName = "grep"
)
- //go:embed grep.md
- var grepDescription []byte
// escapeRegexPattern returns pattern with every regular expression
// metacharacter (\ . + * ? ( ) | [ ] { } ^ $) backslash-escaped, so the result
// matches the original text literally.
//
// regexp.QuoteMeta escapes exactly this character set in a single pass, which
// replaces the previous chain of one strings.ReplaceAll call per character.
func escapeRegexPattern(pattern string) string {
	return regexp.QuoteMeta(pattern)
}
- func NewGrepTool(workingDir string) fantasy.AgentTool {
- return fantasy.NewAgentTool(
- GrepToolName,
- string(grepDescription),
- func(ctx context.Context, params GrepParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
- if params.Pattern == "" {
- return fantasy.NewTextErrorResponse("pattern is required"), nil
- }
- // If literal_text is true, escape the pattern
- searchPattern := params.Pattern
- if params.LiteralText {
- searchPattern = escapeRegexPattern(params.Pattern)
- }
- searchPath := params.Path
- if searchPath == "" {
- searchPath = workingDir
- }
- matches, truncated, err := searchFiles(ctx, searchPattern, searchPath, params.Include, 100)
- if err != nil {
- return fantasy.NewTextErrorResponse(fmt.Sprintf("error searching files: %v", err)), nil
- }
- var output strings.Builder
- if len(matches) == 0 {
- output.WriteString("No files found")
- } else {
- fmt.Fprintf(&output, "Found %d matches\n", len(matches))
- currentFile := ""
- for _, match := range matches {
- if currentFile != match.path {
- if currentFile != "" {
- output.WriteString("\n")
- }
- currentFile = match.path
- fmt.Fprintf(&output, "%s:\n", filepath.ToSlash(match.path))
- }
- if match.lineNum > 0 {
- lineText := match.lineText
- if len(lineText) > maxGrepContentWidth {
- lineText = lineText[:maxGrepContentWidth] + "..."
- }
- if match.charNum > 0 {
- fmt.Fprintf(&output, " Line %d, Char %d: %s\n", match.lineNum, match.charNum, lineText)
- } else {
- fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, lineText)
- }
- } else {
- fmt.Fprintf(&output, " %s\n", match.path)
- }
- }
- if truncated {
- output.WriteString("\n(Results are truncated. Consider using a more specific path or pattern.)")
- }
- }
- return fantasy.WithResponseMetadata(
- fantasy.NewTextResponse(output.String()),
- GrepResponseMetadata{
- NumberOfMatches: len(matches),
- Truncated: truncated,
- },
- ), nil
- })
- }
- func searchFiles(ctx context.Context, pattern, rootPath, include string, limit int) ([]grepMatch, bool, error) {
- matches, err := searchWithRipgrep(ctx, pattern, rootPath, include)
- if err != nil {
- matches, err = searchFilesWithRegex(pattern, rootPath, include)
- if err != nil {
- return nil, false, err
- }
- }
- sort.Slice(matches, func(i, j int) bool {
- return matches[i].modTime.After(matches[j].modTime)
- })
- truncated := len(matches) > limit
- if truncated {
- matches = matches[:limit]
- }
- return matches, truncated, nil
- }
- func searchWithRipgrep(ctx context.Context, pattern, path, include string) ([]grepMatch, error) {
- cmd := getRgSearchCmd(ctx, pattern, path, include)
- if cmd == nil {
- return nil, fmt.Errorf("ripgrep not found in $PATH")
- }
- // Only add ignore files if they exist
- for _, ignoreFile := range []string{".gitignore", ".crushignore"} {
- ignorePath := filepath.Join(path, ignoreFile)
- if _, err := os.Stat(ignorePath); err == nil {
- cmd.Args = append(cmd.Args, "--ignore-file", ignorePath)
- }
- }
- output, err := cmd.Output()
- if err != nil {
- if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
- return []grepMatch{}, nil
- }
- return nil, err
- }
- var matches []grepMatch
- for line := range bytes.SplitSeq(bytes.TrimSpace(output), []byte{'\n'}) {
- if len(line) == 0 {
- continue
- }
- var match ripgrepMatch
- if err := json.Unmarshal(line, &match); err != nil {
- continue
- }
- if match.Type != "match" {
- continue
- }
- for _, m := range match.Data.Submatches {
- fi, err := os.Stat(match.Data.Path.Text)
- if err != nil {
- continue // Skip files we can't access
- }
- matches = append(matches, grepMatch{
- path: match.Data.Path.Text,
- modTime: fi.ModTime(),
- lineNum: match.Data.LineNumber,
- charNum: m.Start + 1, // ensure 1-based
- lineText: strings.TrimSpace(match.Data.Lines.Text),
- })
- // only get the first match of each line
- break
- }
- }
- return matches, nil
- }
// ripgrepMatch is the subset of one ripgrep JSON output message that the grep
// tool reads. Fields not listed here are ignored when decoding.
type ripgrepMatch struct {
	// Type is the message kind.
	Type string `json:"type"`
	// Data is the message payload.
	Data struct {
		// Path identifies the file the message refers to.
		Path struct {
			Text string `json:"text"`
		} `json:"path"`
		// Lines carries the text of the reported line.
		Lines struct {
			Text string `json:"text"`
		} `json:"lines"`
		// LineNumber is the line number reported by ripgrep.
		LineNumber int `json:"line_number"`
		// Submatches lists the individual hits within the line; Start is
		// the offset at which a hit begins.
		Submatches []struct {
			Start int `json:"start"`
		} `json:"submatches"`
	} `json:"data"`
}
- func searchFilesWithRegex(pattern, rootPath, include string) ([]grepMatch, error) {
- matches := []grepMatch{}
- // Use cached regex compilation
- regex, err := searchRegexCache.get(pattern)
- if err != nil {
- return nil, fmt.Errorf("invalid regex pattern: %w", err)
- }
- var includePattern *regexp.Regexp
- if include != "" {
- regexPattern := globToRegex(include)
- includePattern, err = globRegexCache.get(regexPattern)
- if err != nil {
- return nil, fmt.Errorf("invalid include pattern: %w", err)
- }
- }
- // Create walker with gitignore and crushignore support
- walker := fsext.NewFastGlobWalker(rootPath)
- err = filepath.Walk(rootPath, func(path string, info os.FileInfo, err error) error {
- if err != nil {
- return nil // Skip errors
- }
- if info.IsDir() {
- // Check if directory should be skipped
- if walker.ShouldSkip(path) {
- return filepath.SkipDir
- }
- return nil // Continue into directory
- }
- // Use walker's shouldSkip method for files
- if walker.ShouldSkip(path) {
- return nil
- }
- // Skip hidden files (starting with a dot) to match ripgrep's default behavior
- base := filepath.Base(path)
- if base != "." && strings.HasPrefix(base, ".") {
- return nil
- }
- if includePattern != nil && !includePattern.MatchString(path) {
- return nil
- }
- match, lineNum, charNum, lineText, err := fileContainsPattern(path, regex)
- if err != nil {
- return nil // Skip files we can't read
- }
- if match {
- matches = append(matches, grepMatch{
- path: path,
- modTime: info.ModTime(),
- lineNum: lineNum,
- charNum: charNum,
- lineText: lineText,
- })
- if len(matches) >= 200 {
- return filepath.SkipAll
- }
- }
- return nil
- })
- if err != nil {
- return nil, err
- }
- return matches, nil
- }
// textSniffLen is the number of leading bytes inspected when deciding whether
// a file is text; http.DetectContentType considers at most 512 bytes.
const textSniffLen = 512

// maxGrepLineBytes is the longest single line the regex fallback will buffer.
// bufio.Scanner's default limit of 64 KiB makes any file containing a longer
// line (minified code, generated data) fail with bufio.ErrTooLong.
const maxGrepLineBytes = 1 << 20

// fileContainsPattern reports whether the text file at filePath has a line
// matching pattern. On a hit it returns true together with the 1-based line
// number, the 1-based byte column of the match, and the full line text.
//
// Files whose leading bytes do not look like text yield (false, nil). The file
// is opened once: the sniffed prefix is replayed in front of the remaining
// content for scanning. An error is returned when the file cannot be opened
// or read, or when a line exceeds maxGrepLineBytes.
func fileContainsPattern(filePath string, pattern *regexp.Regexp) (bool, int, int, string, error) {
	file, err := os.Open(filePath)
	if err != nil {
		return false, 0, 0, "", err
	}
	defer file.Close()

	head, err := readSniffPrefix(file)
	if err != nil {
		return false, 0, 0, "", err
	}
	// Only search text files.
	if !isTextContent(head) {
		return false, 0, 0, "", nil
	}

	scanner := bufio.NewScanner(io.MultiReader(bytes.NewReader(head), file))
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxGrepLineBytes)

	lineNum := 0
	for scanner.Scan() {
		lineNum++
		line := scanner.Text()
		if loc := pattern.FindStringIndex(line); loc != nil {
			return true, lineNum, loc[0] + 1, line, nil
		}
	}
	return false, 0, 0, "", scanner.Err()
}

// isTextFile reports whether the file at filePath looks like text, judged by
// the MIME type sniffed from its first bytes. Files that cannot be opened or
// read are reported as not text.
func isTextFile(filePath string) bool {
	file, err := os.Open(filePath)
	if err != nil {
		return false
	}
	defer file.Close()

	head, err := readSniffPrefix(file)
	if err != nil {
		return false
	}
	return isTextContent(head)
}

// readSniffPrefix reads up to textSniffLen bytes from r. Reaching the end of
// the input early is not an error; the bytes read so far are returned.
func readSniffPrefix(r io.Reader) ([]byte, error) {
	buf := make([]byte, textSniffLen)
	n, err := io.ReadFull(r, buf)
	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
		return nil, err
	}
	return buf[:n], nil
}

// isTextContent reports whether the MIME type sniffed from sample is one of
// the types treated as searchable text.
func isTextContent(sample []byte) bool {
	contentType := http.DetectContentType(sample)
	if strings.HasPrefix(contentType, "text/") {
		return true
	}
	switch contentType {
	case "application/json", "application/xml", "application/javascript", "application/x-sh":
		return true
	}
	return false
}
// globToRegex translates a file glob into regular expression source.
//
// Translation rules:
//   - "*" becomes ".*" and "?" becomes ".".
//   - A brace group "{a,b}" becomes the alternation "(a|b)". A "{" without a
//     later "}" (or with nothing between them) and a stray "}" are escaped and
//     match literally.
//   - The regex metacharacters . + ( ) | $ are escaped so they match
//     literally; previously only "." was, so a glob such as "*.c++" produced
//     an invalid expression.
//   - Everything else, including "[", "]", "^" and "\", is copied unchanged.
//
// The result is not anchored; callers decide how to anchor it.
func globToRegex(glob string) string {
	var sb strings.Builder
	sb.Grow(len(glob) + 8)

	inBrace := false
	for i := 0; i < len(glob); i++ {
		switch c := glob[i]; c {
		case '*':
			sb.WriteString(".*")
		case '?':
			sb.WriteByte('.')
		case '.', '+', '(', ')', '|', '$':
			sb.WriteByte('\\')
			sb.WriteByte(c)
		case '{':
			// Open a group only when a closing brace follows with at
			// least one character in between.
			if !inBrace && strings.IndexByte(glob[i+1:], '}') > 0 {
				inBrace = true
				sb.WriteByte('(')
			} else {
				sb.WriteString(`\{`)
			}
		case '}':
			if inBrace {
				inBrace = false
				sb.WriteByte(')')
			} else {
				sb.WriteString(`\}`)
			}
		case ',':
			// Commas separate alternatives only inside a brace group.
			if inBrace {
				sb.WriteByte('|')
			} else {
				sb.WriteByte(',')
			}
		default:
			sb.WriteByte(c)
		}
	}
	return sb.String()
}
|