diagnostics.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. package tools
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "maps"
  7. "sort"
  8. "strings"
  9. "time"
  10. "github.com/kujtimiihoxha/termai/internal/lsp"
  11. "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
  12. )
  13. type diagnosticsTool struct {
  14. lspClients map[string]*lsp.Client
  15. }
  16. const (
  17. DiagnosticsToolName = "diagnostics"
  18. )
  19. type DiagnosticsParams struct {
  20. FilePath string `json:"file_path"`
  21. }
  22. func (b *diagnosticsTool) Info() ToolInfo {
  23. return ToolInfo{
  24. Name: DiagnosticsToolName,
  25. Description: "Get diagnostics for a file and/or project.",
  26. Parameters: map[string]any{
  27. "file_path": map[string]any{
  28. "type": "string",
  29. "description": "The path to the file to get diagnostics for (leave w empty for project diagnostics)",
  30. },
  31. },
  32. Required: []string{},
  33. }
  34. }
  35. func (b *diagnosticsTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
  36. var params DiagnosticsParams
  37. if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
  38. return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
  39. }
  40. lsps := b.lspClients
  41. if len(lsps) == 0 {
  42. return NewTextErrorResponse("no LSP clients available"), nil
  43. }
  44. if params.FilePath != "" {
  45. notifyLspOpenFile(ctx, params.FilePath, lsps)
  46. waitForLspDiagnostics(ctx, params.FilePath, lsps)
  47. }
  48. output := appendDiagnostics(params.FilePath, lsps)
  49. return NewTextResponse(output), nil
  50. }
  51. func notifyLspOpenFile(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
  52. for _, client := range lsps {
  53. // Open the file
  54. err := client.OpenFile(ctx, filePath)
  55. if err != nil {
  56. // If there's an error opening the file, continue to the next client
  57. continue
  58. }
  59. }
  60. }
  61. // waitForLspDiagnostics opens a file in LSP clients and waits for diagnostics to be published
  62. func waitForLspDiagnostics(ctx context.Context, filePath string, lsps map[string]*lsp.Client) {
  63. if len(lsps) == 0 {
  64. return
  65. }
  66. // Create a channel to receive diagnostic notifications
  67. diagChan := make(chan struct{}, 1)
  68. // Register a temporary diagnostic handler for each client
  69. for _, client := range lsps {
  70. // Store the original diagnostics map to detect changes
  71. originalDiags := make(map[protocol.DocumentUri][]protocol.Diagnostic)
  72. maps.Copy(originalDiags, client.GetDiagnostics())
  73. // Create a notification handler that will signal when diagnostics are received
  74. handler := func(params json.RawMessage) {
  75. lsp.HandleDiagnostics(client, params)
  76. var diagParams protocol.PublishDiagnosticsParams
  77. if err := json.Unmarshal(params, &diagParams); err != nil {
  78. return
  79. }
  80. // If this is for our file or we've received any new diagnostics, signal completion
  81. if diagParams.URI.Path() == filePath || hasDiagnosticsChanged(client.GetDiagnostics(), originalDiags) {
  82. select {
  83. case diagChan <- struct{}{}:
  84. // Signal sent
  85. default:
  86. // Channel already has a value, no need to send again
  87. }
  88. }
  89. }
  90. // Register our temporary handler
  91. client.RegisterNotificationHandler("textDocument/publishDiagnostics", handler)
  92. // Notify change if the file is already open
  93. if client.IsFileOpen(filePath) {
  94. err := client.NotifyChange(ctx, filePath)
  95. if err != nil {
  96. continue
  97. }
  98. } else {
  99. // Open the file if it's not already open
  100. err := client.OpenFile(ctx, filePath)
  101. if err != nil {
  102. continue
  103. }
  104. }
  105. }
  106. // Wait for diagnostics with a reasonable timeout
  107. select {
  108. case <-diagChan:
  109. // Diagnostics received
  110. case <-time.After(5 * time.Second):
  111. // Timeout after 5 seconds - this is a fallback in case no diagnostics are published
  112. case <-ctx.Done():
  113. // Context cancelled
  114. }
  115. // Note: We're not unregistering our handler because the Client.RegisterNotificationHandler
  116. // replaces any existing handler, and we'll be replaced by the original handler when
  117. // the LSP client is reinitialized or when a new handler is registered.
  118. }
  119. // hasDiagnosticsChanged checks if there are any new diagnostics compared to the original set
  120. func hasDiagnosticsChanged(current, original map[protocol.DocumentUri][]protocol.Diagnostic) bool {
  121. for uri, diags := range current {
  122. origDiags, exists := original[uri]
  123. if !exists || len(diags) != len(origDiags) {
  124. return true
  125. }
  126. }
  127. return false
  128. }
  129. func appendDiagnostics(filePath string, lsps map[string]*lsp.Client) string {
  130. fileDiagnostics := []string{}
  131. projectDiagnostics := []string{}
  132. // Enhanced format function that includes more diagnostic information
  133. formatDiagnostic := func(pth string, diagnostic protocol.Diagnostic, source string) string {
  134. // Base components
  135. severity := "Info"
  136. switch diagnostic.Severity {
  137. case protocol.SeverityError:
  138. severity = "Error"
  139. case protocol.SeverityWarning:
  140. severity = "Warn"
  141. case protocol.SeverityHint:
  142. severity = "Hint"
  143. }
  144. // Location information
  145. location := fmt.Sprintf("%s:%d:%d", pth, diagnostic.Range.Start.Line+1, diagnostic.Range.Start.Character+1)
  146. // Source information (LSP name)
  147. sourceInfo := ""
  148. if diagnostic.Source != "" {
  149. sourceInfo = diagnostic.Source
  150. } else if source != "" {
  151. sourceInfo = source
  152. }
  153. // Code information
  154. codeInfo := ""
  155. if diagnostic.Code != nil {
  156. codeInfo = fmt.Sprintf("[%v]", diagnostic.Code)
  157. }
  158. // Tags information
  159. tagsInfo := ""
  160. if len(diagnostic.Tags) > 0 {
  161. tags := []string{}
  162. for _, tag := range diagnostic.Tags {
  163. switch tag {
  164. case protocol.Unnecessary:
  165. tags = append(tags, "unnecessary")
  166. case protocol.Deprecated:
  167. tags = append(tags, "deprecated")
  168. }
  169. }
  170. if len(tags) > 0 {
  171. tagsInfo = fmt.Sprintf(" (%s)", strings.Join(tags, ", "))
  172. }
  173. }
  174. // Assemble the full diagnostic message
  175. return fmt.Sprintf("%s: %s [%s]%s%s %s",
  176. severity,
  177. location,
  178. sourceInfo,
  179. codeInfo,
  180. tagsInfo,
  181. diagnostic.Message)
  182. }
  183. for lspName, client := range lsps {
  184. diagnostics := client.GetDiagnostics()
  185. if len(diagnostics) > 0 {
  186. for location, diags := range diagnostics {
  187. isCurrentFile := location.Path() == filePath
  188. // Group diagnostics by severity for better organization
  189. for _, diag := range diags {
  190. formattedDiag := formatDiagnostic(location.Path(), diag, lspName)
  191. if isCurrentFile {
  192. fileDiagnostics = append(fileDiagnostics, formattedDiag)
  193. } else {
  194. projectDiagnostics = append(projectDiagnostics, formattedDiag)
  195. }
  196. }
  197. }
  198. }
  199. }
  200. // Sort diagnostics by severity (errors first) and then by location
  201. sort.Slice(fileDiagnostics, func(i, j int) bool {
  202. iIsError := strings.HasPrefix(fileDiagnostics[i], "Error")
  203. jIsError := strings.HasPrefix(fileDiagnostics[j], "Error")
  204. if iIsError != jIsError {
  205. return iIsError // Errors come first
  206. }
  207. return fileDiagnostics[i] < fileDiagnostics[j] // Then alphabetically
  208. })
  209. sort.Slice(projectDiagnostics, func(i, j int) bool {
  210. iIsError := strings.HasPrefix(projectDiagnostics[i], "Error")
  211. jIsError := strings.HasPrefix(projectDiagnostics[j], "Error")
  212. if iIsError != jIsError {
  213. return iIsError
  214. }
  215. return projectDiagnostics[i] < projectDiagnostics[j]
  216. })
  217. output := ""
  218. if len(fileDiagnostics) > 0 {
  219. output += "\n<file_diagnostics>\n"
  220. if len(fileDiagnostics) > 10 {
  221. output += strings.Join(fileDiagnostics[:10], "\n")
  222. output += fmt.Sprintf("\n... and %d more diagnostics", len(fileDiagnostics)-10)
  223. } else {
  224. output += strings.Join(fileDiagnostics, "\n")
  225. }
  226. output += "\n</file_diagnostics>\n"
  227. }
  228. if len(projectDiagnostics) > 0 {
  229. output += "\n<project_diagnostics>\n"
  230. if len(projectDiagnostics) > 10 {
  231. output += strings.Join(projectDiagnostics[:10], "\n")
  232. output += fmt.Sprintf("\n... and %d more diagnostics", len(projectDiagnostics)-10)
  233. } else {
  234. output += strings.Join(projectDiagnostics, "\n")
  235. }
  236. output += "\n</project_diagnostics>\n"
  237. }
  238. // Add summary counts
  239. if len(fileDiagnostics) > 0 || len(projectDiagnostics) > 0 {
  240. fileErrors := countSeverity(fileDiagnostics, "Error")
  241. fileWarnings := countSeverity(fileDiagnostics, "Warn")
  242. projectErrors := countSeverity(projectDiagnostics, "Error")
  243. projectWarnings := countSeverity(projectDiagnostics, "Warn")
  244. output += "\n<diagnostic_summary>\n"
  245. output += fmt.Sprintf("Current file: %d errors, %d warnings\n", fileErrors, fileWarnings)
  246. output += fmt.Sprintf("Project: %d errors, %d warnings\n", projectErrors, projectWarnings)
  247. output += "</diagnostic_summary>\n"
  248. }
  249. return output
  250. }
  // countSeverity returns how many formatted diagnostic lines in diagnostics
// start with the given severity label (for example "Error" or "Warn").
func countSeverity(diagnostics []string, severity string) int {
	total := 0
	for i := range diagnostics {
		if !strings.HasPrefix(diagnostics[i], severity) {
			continue
		}
		total++
	}
	return total
}
  261. func NewDiagnosticsTool(lspClients map[string]*lsp.Client) BaseTool {
  262. return &diagnosticsTool{
  263. lspClients,
  264. }
  265. }