client.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
  1. package lsp
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "os"
  9. "os/exec"
  10. "strings"
  11. "sync"
  12. "sync/atomic"
  13. "time"
  14. "github.com/kujtimiihoxha/termai/internal/config"
  15. "github.com/kujtimiihoxha/termai/internal/logging"
  16. "github.com/kujtimiihoxha/termai/internal/lsp/protocol"
  17. )
  18. type Client struct {
  19. Cmd *exec.Cmd
  20. stdin io.WriteCloser
  21. stdout *bufio.Reader
  22. stderr io.ReadCloser
  23. // Request ID counter
  24. nextID atomic.Int32
  25. // Response handlers
  26. handlers map[int32]chan *Message
  27. handlersMu sync.RWMutex
  28. // Server request handlers
  29. serverRequestHandlers map[string]ServerRequestHandler
  30. serverHandlersMu sync.RWMutex
  31. // Notification handlers
  32. notificationHandlers map[string]NotificationHandler
  33. notificationMu sync.RWMutex
  34. // Diagnostic cache
  35. diagnostics map[protocol.DocumentUri][]protocol.Diagnostic
  36. diagnosticsMu sync.RWMutex
  37. // Files are currently opened by the LSP
  38. openFiles map[string]*OpenFileInfo
  39. openFilesMu sync.RWMutex
  40. }
  41. func NewClient(ctx context.Context, command string, args ...string) (*Client, error) {
  42. cmd := exec.CommandContext(ctx, command, args...)
  43. // Copy env
  44. cmd.Env = os.Environ()
  45. stdin, err := cmd.StdinPipe()
  46. if err != nil {
  47. return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
  48. }
  49. stdout, err := cmd.StdoutPipe()
  50. if err != nil {
  51. return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
  52. }
  53. stderr, err := cmd.StderrPipe()
  54. if err != nil {
  55. return nil, fmt.Errorf("failed to create stderr pipe: %w", err)
  56. }
  57. client := &Client{
  58. Cmd: cmd,
  59. stdin: stdin,
  60. stdout: bufio.NewReader(stdout),
  61. stderr: stderr,
  62. handlers: make(map[int32]chan *Message),
  63. notificationHandlers: make(map[string]NotificationHandler),
  64. serverRequestHandlers: make(map[string]ServerRequestHandler),
  65. diagnostics: make(map[protocol.DocumentUri][]protocol.Diagnostic),
  66. openFiles: make(map[string]*OpenFileInfo),
  67. }
  68. // Start the LSP server process
  69. if err := cmd.Start(); err != nil {
  70. return nil, fmt.Errorf("failed to start LSP server: %w", err)
  71. }
  72. // Handle stderr in a separate goroutine
  73. go func() {
  74. scanner := bufio.NewScanner(stderr)
  75. for scanner.Scan() {
  76. fmt.Fprintf(os.Stderr, "LSP Server: %s\n", scanner.Text())
  77. }
  78. if err := scanner.Err(); err != nil {
  79. fmt.Fprintf(os.Stderr, "Error reading stderr: %v\n", err)
  80. }
  81. }()
  82. // Start message handling loop
  83. go client.handleMessages()
  84. return client, nil
  85. }
  86. func (c *Client) RegisterNotificationHandler(method string, handler NotificationHandler) {
  87. c.notificationMu.Lock()
  88. defer c.notificationMu.Unlock()
  89. c.notificationHandlers[method] = handler
  90. }
  91. func (c *Client) RegisterServerRequestHandler(method string, handler ServerRequestHandler) {
  92. c.serverHandlersMu.Lock()
  93. defer c.serverHandlersMu.Unlock()
  94. c.serverRequestHandlers[method] = handler
  95. }
  96. func (c *Client) InitializeLSPClient(ctx context.Context, workspaceDir string) (*protocol.InitializeResult, error) {
  97. initParams := &protocol.InitializeParams{
  98. WorkspaceFoldersInitializeParams: protocol.WorkspaceFoldersInitializeParams{
  99. WorkspaceFolders: []protocol.WorkspaceFolder{
  100. {
  101. URI: protocol.URI("file://" + workspaceDir),
  102. Name: workspaceDir,
  103. },
  104. },
  105. },
  106. XInitializeParams: protocol.XInitializeParams{
  107. ProcessID: int32(os.Getpid()),
  108. ClientInfo: &protocol.ClientInfo{
  109. Name: "mcp-language-server",
  110. Version: "0.1.0",
  111. },
  112. RootPath: workspaceDir,
  113. RootURI: protocol.DocumentUri("file://" + workspaceDir),
  114. Capabilities: protocol.ClientCapabilities{
  115. Workspace: protocol.WorkspaceClientCapabilities{
  116. Configuration: true,
  117. DidChangeConfiguration: protocol.DidChangeConfigurationClientCapabilities{
  118. DynamicRegistration: true,
  119. },
  120. DidChangeWatchedFiles: protocol.DidChangeWatchedFilesClientCapabilities{
  121. DynamicRegistration: true,
  122. RelativePatternSupport: true,
  123. },
  124. },
  125. TextDocument: protocol.TextDocumentClientCapabilities{
  126. Synchronization: &protocol.TextDocumentSyncClientCapabilities{
  127. DynamicRegistration: true,
  128. DidSave: true,
  129. },
  130. Completion: protocol.CompletionClientCapabilities{
  131. CompletionItem: protocol.ClientCompletionItemOptions{},
  132. },
  133. CodeLens: &protocol.CodeLensClientCapabilities{
  134. DynamicRegistration: true,
  135. },
  136. DocumentSymbol: protocol.DocumentSymbolClientCapabilities{},
  137. CodeAction: protocol.CodeActionClientCapabilities{
  138. CodeActionLiteralSupport: protocol.ClientCodeActionLiteralOptions{
  139. CodeActionKind: protocol.ClientCodeActionKindOptions{
  140. ValueSet: []protocol.CodeActionKind{},
  141. },
  142. },
  143. },
  144. PublishDiagnostics: protocol.PublishDiagnosticsClientCapabilities{
  145. VersionSupport: true,
  146. },
  147. SemanticTokens: protocol.SemanticTokensClientCapabilities{
  148. Requests: protocol.ClientSemanticTokensRequestOptions{
  149. Range: &protocol.Or_ClientSemanticTokensRequestOptions_range{},
  150. Full: &protocol.Or_ClientSemanticTokensRequestOptions_full{},
  151. },
  152. TokenTypes: []string{},
  153. TokenModifiers: []string{},
  154. Formats: []protocol.TokenFormat{},
  155. },
  156. },
  157. Window: protocol.WindowClientCapabilities{},
  158. },
  159. InitializationOptions: map[string]any{
  160. "codelenses": map[string]bool{
  161. "generate": true,
  162. "regenerate_cgo": true,
  163. "test": true,
  164. "tidy": true,
  165. "upgrade_dependency": true,
  166. "vendor": true,
  167. "vulncheck": false,
  168. },
  169. },
  170. },
  171. }
  172. var result protocol.InitializeResult
  173. if err := c.Call(ctx, "initialize", initParams, &result); err != nil {
  174. return nil, fmt.Errorf("initialize failed: %w", err)
  175. }
  176. if err := c.Notify(ctx, "initialized", struct{}{}); err != nil {
  177. return nil, fmt.Errorf("initialized notification failed: %w", err)
  178. }
  179. // Register handlers
  180. c.RegisterServerRequestHandler("workspace/applyEdit", HandleApplyEdit)
  181. c.RegisterServerRequestHandler("workspace/configuration", HandleWorkspaceConfiguration)
  182. c.RegisterServerRequestHandler("client/registerCapability", HandleRegisterCapability)
  183. c.RegisterNotificationHandler("window/showMessage", HandleServerMessage)
  184. c.RegisterNotificationHandler("textDocument/publishDiagnostics",
  185. func(params json.RawMessage) { HandleDiagnostics(c, params) })
  186. // Notify the LSP server
  187. err := c.Initialized(ctx, protocol.InitializedParams{})
  188. if err != nil {
  189. return nil, fmt.Errorf("initialization failed: %w", err)
  190. }
  191. // LSP sepecific Initialization
  192. path := strings.ToLower(c.Cmd.Path)
  193. switch {
  194. case strings.Contains(path, "typescript-language-server"):
  195. // err := initializeTypescriptLanguageServer(ctx, c, workspaceDir)
  196. // if err != nil {
  197. // return nil, err
  198. // }
  199. }
  200. return &result, nil
  201. }
  202. func (c *Client) Close() error {
  203. // Try to close all open files first
  204. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  205. defer cancel()
  206. // Attempt to close files but continue shutdown regardless
  207. c.CloseAllFiles(ctx)
  208. // Close stdin to signal the server
  209. if err := c.stdin.Close(); err != nil {
  210. return fmt.Errorf("failed to close stdin: %w", err)
  211. }
  212. // Use a channel to handle the Wait with timeout
  213. done := make(chan error, 1)
  214. go func() {
  215. done <- c.Cmd.Wait()
  216. }()
  217. // Wait for process to exit with timeout
  218. select {
  219. case err := <-done:
  220. return err
  221. case <-time.After(2 * time.Second):
  222. // If we timeout, try to kill the process
  223. if err := c.Cmd.Process.Kill(); err != nil {
  224. return fmt.Errorf("failed to kill process: %w", err)
  225. }
  226. return fmt.Errorf("process killed after timeout")
  227. }
  228. }
  229. type ServerState int
  230. const (
  231. StateStarting ServerState = iota
  232. StateReady
  233. StateError
  234. )
  235. func (c *Client) WaitForServerReady(ctx context.Context) error {
  236. // TODO: wait for specific messages or poll workspace/symbol
  237. time.Sleep(time.Second * 1)
  238. return nil
  239. }
  240. type OpenFileInfo struct {
  241. Version int32
  242. URI protocol.DocumentUri
  243. }
  244. func (c *Client) OpenFile(ctx context.Context, filepath string) error {
  245. uri := fmt.Sprintf("file://%s", filepath)
  246. c.openFilesMu.Lock()
  247. if _, exists := c.openFiles[uri]; exists {
  248. c.openFilesMu.Unlock()
  249. return nil // Already open
  250. }
  251. c.openFilesMu.Unlock()
  252. // Skip files that do not exist or cannot be read
  253. content, err := os.ReadFile(filepath)
  254. if err != nil {
  255. return fmt.Errorf("error reading file: %w", err)
  256. }
  257. params := protocol.DidOpenTextDocumentParams{
  258. TextDocument: protocol.TextDocumentItem{
  259. URI: protocol.DocumentUri(uri),
  260. LanguageID: DetectLanguageID(uri),
  261. Version: 1,
  262. Text: string(content),
  263. },
  264. }
  265. if err := c.Notify(ctx, "textDocument/didOpen", params); err != nil {
  266. return err
  267. }
  268. c.openFilesMu.Lock()
  269. c.openFiles[uri] = &OpenFileInfo{
  270. Version: 1,
  271. URI: protocol.DocumentUri(uri),
  272. }
  273. c.openFilesMu.Unlock()
  274. return nil
  275. }
  276. func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
  277. uri := fmt.Sprintf("file://%s", filepath)
  278. content, err := os.ReadFile(filepath)
  279. if err != nil {
  280. return fmt.Errorf("error reading file: %w", err)
  281. }
  282. c.openFilesMu.Lock()
  283. fileInfo, isOpen := c.openFiles[uri]
  284. if !isOpen {
  285. c.openFilesMu.Unlock()
  286. return fmt.Errorf("cannot notify change for unopened file: %s", filepath)
  287. }
  288. // Increment version
  289. fileInfo.Version++
  290. version := fileInfo.Version
  291. c.openFilesMu.Unlock()
  292. params := protocol.DidChangeTextDocumentParams{
  293. TextDocument: protocol.VersionedTextDocumentIdentifier{
  294. TextDocumentIdentifier: protocol.TextDocumentIdentifier{
  295. URI: protocol.DocumentUri(uri),
  296. },
  297. Version: version,
  298. },
  299. ContentChanges: []protocol.TextDocumentContentChangeEvent{
  300. {
  301. Value: protocol.TextDocumentContentChangeWholeDocument{
  302. Text: string(content),
  303. },
  304. },
  305. },
  306. }
  307. return c.Notify(ctx, "textDocument/didChange", params)
  308. }
  309. func (c *Client) CloseFile(ctx context.Context, filepath string) error {
  310. cnf := config.Get()
  311. uri := fmt.Sprintf("file://%s", filepath)
  312. c.openFilesMu.Lock()
  313. if _, exists := c.openFiles[uri]; !exists {
  314. c.openFilesMu.Unlock()
  315. return nil // Already closed
  316. }
  317. c.openFilesMu.Unlock()
  318. params := protocol.DidCloseTextDocumentParams{
  319. TextDocument: protocol.TextDocumentIdentifier{
  320. URI: protocol.DocumentUri(uri),
  321. },
  322. }
  323. if cnf.Debug {
  324. logging.Debug("Closing file", "file", filepath)
  325. }
  326. if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
  327. return err
  328. }
  329. c.openFilesMu.Lock()
  330. delete(c.openFiles, uri)
  331. c.openFilesMu.Unlock()
  332. return nil
  333. }
  334. func (c *Client) IsFileOpen(filepath string) bool {
  335. uri := fmt.Sprintf("file://%s", filepath)
  336. c.openFilesMu.RLock()
  337. defer c.openFilesMu.RUnlock()
  338. _, exists := c.openFiles[uri]
  339. return exists
  340. }
  341. // CloseAllFiles closes all currently open files
  342. func (c *Client) CloseAllFiles(ctx context.Context) {
  343. cnf := config.Get()
  344. c.openFilesMu.Lock()
  345. filesToClose := make([]string, 0, len(c.openFiles))
  346. // First collect all URIs that need to be closed
  347. for uri := range c.openFiles {
  348. // Convert URI back to file path by trimming "file://" prefix
  349. filePath := strings.TrimPrefix(uri, "file://")
  350. filesToClose = append(filesToClose, filePath)
  351. }
  352. c.openFilesMu.Unlock()
  353. // Then close them all
  354. for _, filePath := range filesToClose {
  355. err := c.CloseFile(ctx, filePath)
  356. if err != nil && cnf.Debug {
  357. logging.Warn("Error closing file", "file", filePath, "error", err)
  358. }
  359. }
  360. if cnf.Debug {
  361. logging.Debug("Closed all files", "files", filesToClose)
  362. }
  363. }
  364. func (c *Client) GetFileDiagnostics(uri protocol.DocumentUri) []protocol.Diagnostic {
  365. c.diagnosticsMu.RLock()
  366. defer c.diagnosticsMu.RUnlock()
  367. return c.diagnostics[uri]
  368. }
  369. func (c *Client) GetDiagnostics() map[protocol.DocumentUri][]protocol.Diagnostic {
  370. return c.diagnostics
  371. }