references.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. package tools
  2. import (
  3. "cmp"
  4. "context"
  5. _ "embed"
  6. "errors"
  7. "fmt"
  8. "log/slog"
  9. "maps"
  10. "path/filepath"
  11. "regexp"
  12. "slices"
  13. "sort"
  14. "strings"
  15. "charm.land/fantasy"
  16. "github.com/charmbracelet/crush/internal/csync"
  17. "github.com/charmbracelet/crush/internal/lsp"
  18. "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
  19. )
// ReferencesParams is the input accepted by the handler that
// NewReferencesTool registers.
type ReferencesParams struct {
	// Symbol is the name to look up. The handler rejects an empty value with
	// a "symbol is required" error response.
	Symbol string `json:"symbol" description:"The symbol name to search for (e.g., function name, variable name, type name)"`
	// Path narrows the text search. The handler falls back to "." when it is
	// empty.
	Path   string `json:"path,omitempty" description:"The directory to search in. Use a directory/file to narrow down the symbol search. Defaults to the current working directory."`
}
// referencesTool carries the set of LSP clients, keyed by string. Its Name
// method reports ReferencesToolName.
type referencesTool struct {
	lspClients *csync.Map[string, *lsp.Client]
}
// ReferencesToolName is the name passed to fantasy.NewAgentTool for the
// references tool and returned by referencesTool.Name.
const ReferencesToolName = "lsp_references"
// referencesDescription holds the contents of references.md, embedded at
// build time. NewReferencesTool passes it, converted to a string, as the tool
// description.
//
//go:embed references.md
var referencesDescription []byte
  30. func NewReferencesTool(lspClients *csync.Map[string, *lsp.Client]) fantasy.AgentTool {
  31. return fantasy.NewAgentTool(
  32. ReferencesToolName,
  33. string(referencesDescription),
  34. func(ctx context.Context, params ReferencesParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  35. if params.Symbol == "" {
  36. return fantasy.NewTextErrorResponse("symbol is required"), nil
  37. }
  38. if lspClients.Len() == 0 {
  39. return fantasy.NewTextErrorResponse("no LSP clients available"), nil
  40. }
  41. workingDir := cmp.Or(params.Path, ".")
  42. matches, _, err := searchFiles(ctx, regexp.QuoteMeta(params.Symbol), workingDir, "", 100)
  43. if err != nil {
  44. return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to search for symbol: %s", err)), nil
  45. }
  46. if len(matches) == 0 {
  47. return fantasy.NewTextResponse(fmt.Sprintf("Symbol '%s' not found", params.Symbol)), nil
  48. }
  49. var allLocations []protocol.Location
  50. var allErrs error
  51. for _, match := range matches {
  52. locations, err := find(ctx, lspClients, params.Symbol, match)
  53. if err != nil {
  54. if strings.Contains(err.Error(), "no identifier found") {
  55. // grep probably matched a comment, string value, or something else that's irrelevant
  56. continue
  57. }
  58. slog.Error("Failed to find references", "error", err, "symbol", params.Symbol, "path", match.path, "line", match.lineNum, "char", match.charNum)
  59. allErrs = errors.Join(allErrs, err)
  60. continue
  61. }
  62. allLocations = append(allLocations, locations...)
  63. // XXX: should we break here or look for all results?
  64. }
  65. if len(allLocations) > 0 {
  66. output := formatReferences(cleanupLocations(allLocations))
  67. return fantasy.NewTextResponse(output), nil
  68. }
  69. if allErrs != nil {
  70. return fantasy.NewTextErrorResponse(allErrs.Error()), nil
  71. }
  72. return fantasy.NewTextResponse(fmt.Sprintf("No references found for symbol '%s'", params.Symbol)), nil
  73. })
  74. }
  75. func (r *referencesTool) Name() string {
  76. return ReferencesToolName
  77. }
  78. func find(ctx context.Context, lspClients *csync.Map[string, *lsp.Client], symbol string, match grepMatch) ([]protocol.Location, error) {
  79. absPath, err := filepath.Abs(match.path)
  80. if err != nil {
  81. return nil, fmt.Errorf("failed to get absolute path: %s", err)
  82. }
  83. var client *lsp.Client
  84. for c := range lspClients.Seq() {
  85. if c.HandlesFile(absPath) {
  86. client = c
  87. break
  88. }
  89. }
  90. if client == nil {
  91. slog.Warn("No LSP clients to handle", "path", match.path)
  92. return nil, nil
  93. }
  94. return client.FindReferences(
  95. ctx,
  96. absPath,
  97. match.lineNum,
  98. match.charNum+getSymbolOffset(symbol),
  99. true,
  100. )
  101. }
  102. // getSymbolOffset returns the character offset to the actual symbol name
  103. // in a qualified symbol (e.g., "Bar" in "foo.Bar" or "method" in "Class::method").
  104. func getSymbolOffset(symbol string) int {
  105. // Check for :: separator (Rust, C++, Ruby modules/classes, PHP static).
  106. if idx := strings.LastIndex(symbol, "::"); idx != -1 {
  107. return idx + 2
  108. }
  109. // Check for . separator (Go, Python, JavaScript, Java, C#, Ruby methods).
  110. if idx := strings.LastIndex(symbol, "."); idx != -1 {
  111. return idx + 1
  112. }
  113. // Check for \ separator (PHP namespaces).
  114. if idx := strings.LastIndex(symbol, "\\"); idx != -1 {
  115. return idx + 1
  116. }
  117. return 0
  118. }
  119. func cleanupLocations(locations []protocol.Location) []protocol.Location {
  120. slices.SortFunc(locations, func(a, b protocol.Location) int {
  121. if a.URI != b.URI {
  122. return strings.Compare(string(a.URI), string(b.URI))
  123. }
  124. if a.Range.Start.Line != b.Range.Start.Line {
  125. return cmp.Compare(a.Range.Start.Line, b.Range.Start.Line)
  126. }
  127. return cmp.Compare(a.Range.Start.Character, b.Range.Start.Character)
  128. })
  129. return slices.CompactFunc(locations, func(a, b protocol.Location) bool {
  130. return a.URI == b.URI &&
  131. a.Range.Start.Line == b.Range.Start.Line &&
  132. a.Range.Start.Character == b.Range.Start.Character
  133. })
  134. }
  135. func groupByFilename(locations []protocol.Location) map[string][]protocol.Location {
  136. files := make(map[string][]protocol.Location)
  137. for _, loc := range locations {
  138. path, err := loc.URI.Path()
  139. if err != nil {
  140. slog.Error("Failed to convert location URI to path", "uri", loc.URI, "error", err)
  141. continue
  142. }
  143. files[path] = append(files[path], loc)
  144. }
  145. return files
  146. }
  147. func formatReferences(locations []protocol.Location) string {
  148. fileRefs := groupByFilename(locations)
  149. files := slices.Collect(maps.Keys(fileRefs))
  150. sort.Strings(files)
  151. var output strings.Builder
  152. output.WriteString(fmt.Sprintf("Found %d reference(s) in %d file(s):\n\n", len(locations), len(files)))
  153. for _, file := range files {
  154. refs := fileRefs[file]
  155. output.WriteString(fmt.Sprintf("%s (%d reference(s)):\n", file, len(refs)))
  156. for _, ref := range refs {
  157. line := ref.Range.Start.Line + 1
  158. char := ref.Range.Start.Character + 1
  159. output.WriteString(fmt.Sprintf(" Line %d, Column %d\n", line, char))
  160. }
  161. output.WriteString("\n")
  162. }
  163. return output.String()
  164. }