| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243 |
- package tools
- import (
- "context"
- _ "embed"
- "fmt"
- "log/slog"
- "sort"
- "strings"
- "sync"
- "time"
- "charm.land/fantasy"
- "github.com/charmbracelet/crush/internal/lsp"
- "github.com/charmbracelet/x/powernap/pkg/lsp/protocol"
- )
// DiagnosticsParams is the parameter payload decoded for a diagnostics tool
// call.
type DiagnosticsParams struct {
	// FilePath selects the file whose diagnostics are reported separately
	// from the rest of the project. It may be left empty.
	FilePath string `json:"file_path,omitempty" description:"The path to the file to get diagnostics for (leave empty for project diagnostics)"`
}
// DiagnosticsToolName is the name under which the diagnostics tool is
// registered.
const DiagnosticsToolName = "lsp_diagnostics"
- //go:embed diagnostics.md
- var diagnosticsDescription []byte
- func NewDiagnosticsTool(lspManager *lsp.Manager) fantasy.AgentTool {
- return fantasy.NewAgentTool(
- DiagnosticsToolName,
- string(diagnosticsDescription),
- func(ctx context.Context, params DiagnosticsParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
- if lspManager.Clients().Len() == 0 {
- return fantasy.NewTextErrorResponse("no LSP clients available"), nil
- }
- notifyLSPs(ctx, lspManager, params.FilePath)
- output := getDiagnostics(params.FilePath, lspManager)
- return fantasy.NewTextResponse(output), nil
- })
- }
- // openInLSPs ensures LSP servers are running and aware of the file, but does
- // not notify changes or wait for fresh diagnostics. Use this for read-only
- // operations like view where the file content hasn't changed.
- func openInLSPs(
- ctx context.Context,
- manager *lsp.Manager,
- filepath string,
- ) {
- if filepath == "" || manager == nil {
- return
- }
- manager.Start(ctx, filepath)
- for client := range manager.Clients().Seq() {
- if !client.HandlesFile(filepath) {
- continue
- }
- _ = client.OpenFileOnDemand(ctx, filepath)
- }
- }
- // waitForLSPDiagnostics waits briefly for diagnostics publication after a file
- // has been opened. Intended for read-only situations where viewing up-to-date
- // files matters but latency should remain low (i.e. when using the view tool).
- func waitForLSPDiagnostics(
- ctx context.Context,
- manager *lsp.Manager,
- filepath string,
- timeout time.Duration,
- ) {
- if filepath == "" || manager == nil || timeout <= 0 {
- return
- }
- var wg sync.WaitGroup
- for client := range manager.Clients().Seq() {
- if !client.HandlesFile(filepath) {
- continue
- }
- wg.Go(func() {
- client.WaitForDiagnostics(ctx, timeout)
- })
- }
- wg.Wait()
- }
- // notifyLSPs notifies LSP servers that a file has changed and waits for
- // updated diagnostics. Use this after edit/multiedit operations.
- func notifyLSPs(
- ctx context.Context,
- manager *lsp.Manager,
- filepath string,
- ) {
- if filepath == "" || manager == nil {
- return
- }
- manager.Start(ctx, filepath)
- for client := range manager.Clients().Seq() {
- if !client.HandlesFile(filepath) {
- continue
- }
- _ = client.OpenFileOnDemand(ctx, filepath)
- _ = client.NotifyChange(ctx, filepath)
- client.WaitForDiagnostics(ctx, 5*time.Second)
- }
- }
- func getDiagnostics(filePath string, manager *lsp.Manager) string {
- if manager == nil {
- return ""
- }
- var fileDiagnostics []string
- var projectDiagnostics []string
- for lspName, client := range manager.Clients().Seq2() {
- for location, diags := range client.GetDiagnostics() {
- path, err := location.Path()
- if err != nil {
- slog.Error("Failed to convert diagnostic location URI to path", "uri", location, "error", err)
- continue
- }
- isCurrentFile := path == filePath
- for _, diag := range diags {
- formattedDiag := formatDiagnostic(path, diag, lspName)
- if isCurrentFile {
- fileDiagnostics = append(fileDiagnostics, formattedDiag)
- } else {
- projectDiagnostics = append(projectDiagnostics, formattedDiag)
- }
- }
- }
- }
- sortDiagnostics(fileDiagnostics)
- sortDiagnostics(projectDiagnostics)
- var output strings.Builder
- writeDiagnostics(&output, "file_diagnostics", fileDiagnostics)
- writeDiagnostics(&output, "project_diagnostics", projectDiagnostics)
- if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
- fileErrors := countSeverity(fileDiagnostics, "Error")
- fileWarnings := countSeverity(fileDiagnostics, "Warn")
- projectErrors := countSeverity(projectDiagnostics, "Error")
- projectWarnings := countSeverity(projectDiagnostics, "Warn")
- output.WriteString("\n<diagnostic_summary>\n")
- fmt.Fprintf(&output, "Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
- fmt.Fprintf(&output, "Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
- output.WriteString("</diagnostic_summary>\n")
- }
- out := output.String()
- slog.Debug("Diagnostics", "output", out)
- return out
- }
// writeDiagnostics appends one tagged section to output, of the form
//
//	\n<tag>\n ...entries joined by newlines... \n</tag>\n
//
// At most ten entries are written; when there are more, a line of the form
// "... and N more diagnostics" follows them. Nothing is written when in is
// empty.
func writeDiagnostics(output *strings.Builder, tag string, in []string) {
	const maxShown = 10

	total := len(in)
	if total == 0 {
		return
	}

	fmt.Fprintf(output, "\n<%s>\n", tag)

	shown := in[:min(total, maxShown)]
	output.WriteString(strings.Join(shown, "\n"))
	if hidden := total - len(shown); hidden > 0 {
		fmt.Fprintf(output, "\n... and %d more diagnostics", hidden)
	}

	fmt.Fprintf(output, "\n</%s>\n", tag)
}
// sortDiagnostics orders the formatted diagnostics in place and returns the
// same slice. Entries starting with "Error" come first; within each group the
// entries are in ascending string order.
func sortDiagnostics(in []string) []string {
	const errorPrefix = "Error"

	sort.Slice(in, func(a, b int) bool {
		left, right := in[a], in[b]
		leftErr := strings.HasPrefix(left, errorPrefix)
		rightErr := strings.HasPrefix(right, errorPrefix)

		switch {
		case leftErr && !rightErr:
			return true
		case rightErr && !leftErr:
			return false
		default:
			// Same group: fall back to plain string ordering.
			return left < right
		}
	})
	return in
}
- func formatDiagnostic(pth string, diagnostic protocol.Diagnostic, source string) string {
- severity := "Info"
- switch diagnostic.Severity {
- case protocol.SeverityError:
- severity = "Error"
- case protocol.SeverityWarning:
- severity = "Warn"
- case protocol.SeverityHint:
- severity = "Hint"
- }
- location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
- sourceInfo := source
- if diagnostic.Source != "" {
- sourceInfo += " " + diagnostic.Source
- }
- codeInfo := ""
- if diagnostic.Code != nil {
- codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
- }
- tagsInfo := ""
- if len(diagnostic.Tags) > 0 {
- var tags []string
- for _, tag := range diagnostic.Tags {
- switch tag {
- case protocol.Unnecessary:
- tags = append(tags, "unnecessary")
- case protocol.Deprecated:
- tags = append(tags, "deprecated")
- }
- }
- if len(tags) > 0 {
- tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
- }
- }
- return fmt.Sprintf("%s: %s [%s]%s%s %s",
- severity,
- location,
- sourceInfo,
- codeInfo,
- tagsInfo,
- diagnostic.Message)
- }
// countSeverity returns how many of the formatted diagnostics start with the
// given severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	var matches int
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		matches++
	}
	return matches
}
|