diagnostics.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. package tools
  2. import (
  3. "context"
  4. _ "embed"
  5. "fmt"
  6. "log/slog"
  7. "sort"
  8. "strings"
  9. "time"
  10. "charm.land/fantasy"
  11. "github.com/charmbracelet/crush/internal/csync"
  12. "github.com/charmbracelet/crush/internal/lsp"
  13. "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
  14. )
  // DiagnosticsParams are the arguments accepted by the lsp_diagnostics tool.
  type DiagnosticsParams struct {
  	// FilePath optionally names a file whose diagnostics should be refreshed
  	// and reported separately from the rest of the project. When empty, no
  	// file is refreshed and every diagnostic is reported as a project
  	// diagnostic.
  	FilePath string `json:"file_path,omitempty" description:"The path to the file to get diagnostics for (leave empty for project diagnostics)"`
  }
  // DiagnosticsToolName is the name the diagnostics tool is registered under
  // when passed to fantasy.NewAgentTool.
  const DiagnosticsToolName = "lsp_diagnostics"

  // diagnosticsDescription holds the contents of diagnostics.md, embedded at
  // build time and passed to fantasy.NewAgentTool as the tool description.
  //go:embed diagnostics.md
  var diagnosticsDescription []byte
  21. func NewDiagnosticsTool(lspClients *csync.Map[string, *lsp.Client]) fantasy.AgentTool {
  22. return fantasy.NewAgentTool(
  23. DiagnosticsToolName,
  24. string(diagnosticsDescription),
  25. func(ctx context.Context, params DiagnosticsParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  26. if lspClients.Len() == 0 {
  27. return fantasy.NewTextErrorResponse("no LSP clients available"), nil
  28. }
  29. notifyLSPs(ctx, lspClients, params.FilePath)
  30. output := getDiagnostics(params.FilePath, lspClients)
  31. return fantasy.NewTextResponse(output), nil
  32. })
  33. }
  34. func notifyLSPs(ctx context.Context, lsps *csync.Map[string, *lsp.Client], filepath string) {
  35. if filepath == "" {
  36. return
  37. }
  38. for client := range lsps.Seq() {
  39. if !client.HandlesFile(filepath) {
  40. continue
  41. }
  42. _ = client.OpenFileOnDemand(ctx, filepath)
  43. _ = client.NotifyChange(ctx, filepath)
  44. client.WaitForDiagnostics(ctx, 5*time.Second)
  45. }
  46. }
  47. func getDiagnostics(filePath string, lsps *csync.Map[string, *lsp.Client]) string {
  48. fileDiagnostics := []string{}
  49. projectDiagnostics := []string{}
  50. for lspName, client := range lsps.Seq2() {
  51. for location, diags := range client.GetDiagnostics() {
  52. path, err := location.Path()
  53. if err != nil {
  54. slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
  55. continue
  56. }
  57. isCurrentFile := path == filePath
  58. for _, diag := range diags {
  59. formattedDiag := formatDiagnostic(path, diag, lspName)
  60. if isCurrentFile {
  61. fileDiagnostics = append(fileDiagnostics, formattedDiag)
  62. } else {
  63. projectDiagnostics = append(projectDiagnostics, formattedDiag)
  64. }
  65. }
  66. }
  67. }
  68. sortDiagnostics(fileDiagnostics)
  69. sortDiagnostics(projectDiagnostics)
  70. var output strings.Builder
  71. writeDiagnostics(&output, "file_diagnostics", fileDiagnostics)
  72. writeDiagnostics(&output, "project_diagnostics", projectDiagnostics)
  73. if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
  74. fileErrors := countSeverity(fileDiagnostics, "Error")
  75. fileWarnings := countSeverity(fileDiagnostics, "Warn")
  76. projectErrors := countSeverity(projectDiagnostics, "Error")
  77. projectWarnings := countSeverity(projectDiagnostics, "Warn")
  78. output.WriteString("\n<diagnostic_summary>\n")
  79. fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
  80. fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
  81. output.WriteString("</diagnostic_summary>\n")
  82. }
  83. out := output.String()
  84. slog.Debug("Diagnostics", "output", out)
  85. return out
  86. }
  // writeDiagnostics appends the entries of in to output, newline-separated and
  // wrapped in <tag>…</tag>. At most the first 10 entries are written; any
  // remainder is summarised as "... and N more diagnostics". Nothing is
  // written when in is empty.
  func writeDiagnostics(output *strings.Builder, tag string, in []string) {
  	if len(in) == 0 {
  		return
  	}
  	const maxShown = 10

  	shown := in[:min(len(in), maxShown)]
  	output.WriteString("\n<" + tag + ">\n")
  	output.WriteString(strings.Join(shown, "\n"))
  	if hidden := len(in) - len(shown); hidden > 0 {
  		fmt.Fprintf(output, "\n... and %d more diagnostics", hidden)
  	}
  	output.WriteString("\n</" + tag + ">\n")
  }
  // sortDiagnostics sorts formatted diagnostics in place by severity label —
  // Error, then Warn, then Info, then Hint, then anything unrecognised — and
  // alphabetically within the same severity. It returns in for convenience.
  //
  // Ranking by real severity (rather than "errors first, then alphabetical")
  // matters because writeDiagnostics only shows the first 10 entries: an
  // alphabetical order places "Hint"/"Info" ahead of "Warn" and could hide
  // every warning behind low-priority hints.
  func sortDiagnostics(in []string) []string {
  	rank := func(diag string) int {
  		switch {
  		case strings.HasPrefix(diag, "Error"):
  			return 0
  		case strings.HasPrefix(diag, "Warn"):
  			return 1
  		case strings.HasPrefix(diag, "Info"):
  			return 2
  		case strings.HasPrefix(diag, "Hint"):
  			return 3
  		default:
  			return 4
  		}
  	}
  	sort.Slice(in, func(i, j int) bool {
  		ri, rj := rank(in[i]), rank(in[j])
  		if ri != rj {
  			return ri < rj
  		}
  		return in[i] < in[j]
  	})
  	return in
  }
  111. func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string {
  112. severity := "Info"
  113. switch diagnostic.Severity {
  114. case protocol.SeverityError:
  115. severity = "Error"
  116. case protocol.SeverityWarning:
  117. severity = "Warn"
  118. case protocol.SeverityHint:
  119. severity = "Hint"
  120. }
  121. location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
  122. sourceInfo := ""
  123. if diagnostic.Source != "" {
  124. sourceInfo = diagnostic.Source
  125. } else if source != "" {
  126. sourceInfo = source
  127. }
  128. codeInfo := ""
  129. if diagnostic.Code != nil {
  130. codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
  131. }
  132. tagsInfo := ""
  133. if len(diagnostic.Tags) > 0 {
  134. tags := []string{}
  135. for _, tag := range diagnostic.Tags {
  136. switch tag {
  137. case protocol.Unnecessary:
  138. tags = append(tags, "unnecessary")
  139. case protocol.Deprecated:
  140. tags = append(tags, "deprecated")
  141. }
  142. }
  143. if len(tags) > 0 {
  144. tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
  145. }
  146. }
  147. return fmt.Sprintf("%s: %s [%s]%s%s %s",
  148. severity,
  149. location,
  150. sourceInfo,
  151. codeInfo,
  152. tagsInfo,
  153. diagnostic.Message)
  154. }
  // countSeverity reports how many formatted diagnostics begin with the given
  // severity label (for example "Error" or "Warn", as produced by
  // formatDiagnostic).
  func countSeverity(diagnostics []string, severity string) int {
  	n := 0
  	for i := range diagnostics {
  		if !strings.HasPrefix(diagnostics[i], severity) {
  			continue
  		}
  		n++
  	}
  	return n
  }