references.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package tools
  2. import (
  3. "cmp"
  4. "context"
  5. _ "embed"
  6. "errors"
  7. "fmt"
  8. "log/slog"
  9. "maps"
  10. "path/filepath"
  11. "regexp"
  12. "slices"
  13. "sort"
  14. "strings"
  15. "charm.land/fantasy"
  16. "github.com/charmbracelet/crush/internal/lsp"
  17. "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
  18. )
// ReferencesParams is the input accepted by the lsp_references tool handler
// built in NewReferencesTool.
type ReferencesParams struct {
	// Symbol is the symbol name to look up. The handler rejects an empty
	// value with a "symbol is required" error response.
	Symbol string `json:"symbol" description:"The symbol name to search for (e.g., function name, variable name, type name)"`
	// Path optionally narrows the text search to a directory or file. When
	// empty, the handler searches from ".".
	Path string `json:"path,omitempty" description:"The directory to search in. Use a directory/file to narrow down the symbol search. Defaults to the current working directory."`
}
// referencesTool carries the LSP manager for the references tool. In the
// visible part of this file it is only used as the receiver of Name.
type referencesTool struct {
	// lspManager is the manager whose clients would serve reference lookups.
	lspManager *lsp.Manager
}
// ReferencesToolName is the name under which the references tool is
// registered with fantasy.NewAgentTool and reported by referencesTool.Name.
const ReferencesToolName = "lsp_references"
// referencesDescription holds the contents of references.md, embedded at
// build time. NewReferencesTool converts it to a string and passes it as the
// second argument to fantasy.NewAgentTool.
//
//go:embed references.md
var referencesDescription []byte
  29. func NewReferencesTool(lspManager *lsp.Manager) fantasy.AgentTool {
  30. return fantasy.NewAgentTool(
  31. ReferencesToolName,
  32. string(referencesDescription),
  33. func(ctx context.Context, params ReferencesParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  34. if params.Symbol == "" {
  35. return fantasy.NewTextErrorResponse("symbol is required"), nil
  36. }
  37. if lspManager.Clients().Len() == 0 {
  38. return fantasy.NewTextErrorResponse("no LSP clients available"), nil
  39. }
  40. workingDir := cmp.Or(params.Path, ".")
  41. matches, _, err := searchFiles(ctx, regexp.QuoteMeta(params.Symbol), workingDir, "", 100)
  42. if err != nil {
  43. return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to search for symbol: %s", err)), nil
  44. }
  45. if len(matches) == 0 {
  46. return fantasy.NewTextResponse(fmt.Sprintf("Symbol '%s' not found", params.Symbol)), nil
  47. }
  48. var allLocations []protocol.Location
  49. var allErrs error
  50. for _, match := range matches {
  51. locations, err := find(ctx, lspManager, params.Symbol, match)
  52. if err != nil {
  53. if strings.Contains(err.Error(), "no identifier found") {
  54. // grep probably matched a comment, string value, or something else that's irrelevant
  55. continue
  56. }
  57. slog.Error("Failed to find references", "error", err, "symbol", params.Symbol, "path", match.path, "line", match.lineNum, "char", match.charNum)
  58. allErrs = errors.Join(allErrs, err)
  59. continue
  60. }
  61. allLocations = append(allLocations, locations...)
  62. // Once we have results, we're done - LSP returns all references
  63. // for the symbol, not just from this file.
  64. if len(locations) > 0 {
  65. break
  66. }
  67. }
  68. if len(allLocations) > 0 {
  69. output := formatReferences(cleanupLocations(allLocations))
  70. return fantasy.NewTextResponse(output), nil
  71. }
  72. if allErrs != nil {
  73. return fantasy.NewTextErrorResponse(allErrs.Error()), nil
  74. }
  75. return fantasy.NewTextResponse(fmt.Sprintf("No references found for symbol '%s'", params.Symbol)), nil
  76. })
  77. }
  78. func (r *referencesTool) Name() string {
  79. return ReferencesToolName
  80. }
  81. func find(ctx context.Context, lspManager *lsp.Manager, symbol string, match grepMatch) ([]protocol.Location, error) {
  82. absPath, err := filepath.Abs(match.path)
  83. if err != nil {
  84. return nil, fmt.Errorf("failed to get absolute path: %s", err)
  85. }
  86. var client *lsp.Client
  87. for c := range lspManager.Clients().Seq() {
  88. if c.HandlesFile(absPath) {
  89. client = c
  90. break
  91. }
  92. }
  93. if client == nil {
  94. slog.Warn("No LSP clients to handle", "path", match.path)
  95. return nil, nil
  96. }
  97. return client.FindReferences(
  98. ctx,
  99. absPath,
  100. match.lineNum,
  101. match.charNum+getSymbolOffset(symbol),
  102. true,
  103. )
  104. }
  105. // getSymbolOffset returns the character offset to the actual symbol name
  106. // in a qualified symbol (e.g., "Bar" in "foo.Bar" or "method" in "Class::method").
  107. func getSymbolOffset(symbol string) int {
  108. // Check for :: separator (Rust, C++, Ruby modules/classes, PHP static).
  109. if idx := strings.LastIndex(symbol, "::"); idx != -1 {
  110. return idx + 2
  111. }
  112. // Check for . separator (Go, Python, JavaScript, Java, C#, Ruby methods).
  113. if idx := strings.LastIndex(symbol, "."); idx != -1 {
  114. return idx + 1
  115. }
  116. // Check for \ separator (PHP namespaces).
  117. if idx := strings.LastIndex(symbol, "\\"); idx != -1 {
  118. return idx + 1
  119. }
  120. return 0
  121. }
  122. func cleanupLocations(locations []protocol.Location) []protocol.Location {
  123. slices.SortFunc(locations, func(a, b protocol.Location) int {
  124. if a.URI != b.URI {
  125. return strings.Compare(string(a.URI), string(b.URI))
  126. }
  127. if a.Range.Start.Line != b.Range.Start.Line {
  128. return cmp.Compare(a.Range.Start.Line, b.Range.Start.Line)
  129. }
  130. return cmp.Compare(a.Range.Start.Character, b.Range.Start.Character)
  131. })
  132. return slices.CompactFunc(locations, func(a, b protocol.Location) bool {
  133. return a.URI == b.URI &&
  134. a.Range.Start.Line == b.Range.Start.Line &&
  135. a.Range.Start.Character == b.Range.Start.Character
  136. })
  137. }
  138. func groupByFilename(locations []protocol.Location) map[string][]protocol.Location {
  139. files := make(map[string][]protocol.Location)
  140. for _, loc := range locations {
  141. path, err := loc.URI.Path()
  142. if err != nil {
  143. slog.Error("Failed to convert location URI to path", "uri", loc.URI, "error", err)
  144. continue
  145. }
  146. files[path] = append(files[path], loc)
  147. }
  148. return files
  149. }
  150. func formatReferences(locations []protocol.Location) string {
  151. fileRefs := groupByFilename(locations)
  152. files := slices.Collect(maps.Keys(fileRefs))
  153. sort.Strings(files)
  154. var output strings.Builder
  155. fmt.Fprintf(&output, "Found %d reference(s) in %d file(s):\n\n", len(locations), len(files))
  156. for _, file := range files {
  157. refs := fileRefs[file]
  158. fmt.Fprintf(&output, "%s (%d reference(s)):\n", file, len(refs))
  159. for _, ref := range refs {
  160. line := ref.Range.Start.Line + 1
  161. char := ref.Range.Start.Character + 1
  162. fmt.Fprintf(&output, " Line %d, Column %d\n", line, char)
  163. }
  164. output.WriteString("\n")
  165. }
  166. return output.String()
  167. }