init.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. // Package mcp provides functionality for managing Model Context Protocol (MCP)
  2. // clients within the Crush application.
  3. package mcp
  4. import (
  5. "cmp"
  6. "context"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "net/http"
  12. "os"
  13. "os/exec"
  14. "strings"
  15. "sync"
  16. "time"
  17. "github.com/charmbracelet/crush/internal/config"
  18. "github.com/charmbracelet/crush/internal/csync"
  19. "github.com/charmbracelet/crush/internal/home"
  20. "github.com/charmbracelet/crush/internal/permission"
  21. "github.com/charmbracelet/crush/internal/pubsub"
  22. "github.com/charmbracelet/crush/internal/version"
  23. "github.com/modelcontextprotocol/go-sdk/mcp"
  24. )
  25. func parseLevel(level mcp.LoggingLevel) slog.Level {
  26. switch level {
  27. case "info":
  28. return slog.LevelInfo
  29. case "notice":
  30. return slog.LevelInfo
  31. case "warning":
  32. return slog.LevelWarn
  33. default:
  34. return slog.LevelDebug
  35. }
  36. }
  37. // ClientSession wraps an mcp.ClientSession with a context cancel function so
  38. // that the context created during session establishment is properly cleaned up
  39. // on close.
  40. type ClientSession struct {
  41. *mcp.ClientSession
  42. cancel context.CancelFunc
  43. }
  44. // Close cancels the session context and then closes the underlying session.
  45. func (s *ClientSession) Close() error {
  46. s.cancel()
  47. return s.ClientSession.Close()
  48. }
  49. var (
  50. sessions = csync.NewMap[string, *ClientSession]()
  51. states = csync.NewMap[string, ClientInfo]()
  52. broker = pubsub.NewBroker[Event]()
  53. initOnce sync.Once
  54. initDone = make(chan struct{})
  55. )
  56. // State represents the current state of an MCP client
  57. type State int
  58. const (
  59. StateDisabled State = iota
  60. StateStarting
  61. StateConnected
  62. StateError
  63. )
  64. func (s State) String() string {
  65. switch s {
  66. case StateDisabled:
  67. return "disabled"
  68. case StateStarting:
  69. return "starting"
  70. case StateConnected:
  71. return "connected"
  72. case StateError:
  73. return "error"
  74. default:
  75. return "unknown"
  76. }
  77. }
  78. // EventType represents the type of MCP event
  79. type EventType uint
  80. const (
  81. EventStateChanged EventType = iota
  82. EventToolsListChanged
  83. EventPromptsListChanged
  84. EventResourcesListChanged
  85. )
  86. // Event represents an event in the MCP system
  87. type Event struct {
  88. Type EventType
  89. Name string
  90. State State
  91. Error error
  92. Counts Counts
  93. }
  94. // Counts number of available tools, prompts, etc.
  95. type Counts struct {
  96. Tools int
  97. Prompts int
  98. Resources int
  99. }
  100. // ClientInfo holds information about an MCP client's state
  101. type ClientInfo struct {
  102. Name string
  103. State State
  104. Error error
  105. Client *ClientSession
  106. Counts Counts
  107. ConnectedAt time.Time
  108. }
  109. // SubscribeEvents returns a channel for MCP events
  110. func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
  111. return broker.Subscribe(ctx)
  112. }
  113. // GetStates returns the current state of all MCP clients
  114. func GetStates() map[string]ClientInfo {
  115. return states.Copy()
  116. }
  117. // GetState returns the state of a specific MCP client
  118. func GetState(name string) (ClientInfo, bool) {
  119. return states.Get(name)
  120. }
  121. // Close closes all MCP clients. This should be called during application shutdown.
  122. func Close(ctx context.Context) error {
  123. var wg sync.WaitGroup
  124. for name, session := range sessions.Seq2() {
  125. wg.Go(func() {
  126. done := make(chan error, 1)
  127. go func() {
  128. done <- session.Close()
  129. }()
  130. select {
  131. case err := <-done:
  132. if err != nil &&
  133. !errors.Is(err, io.EOF) &&
  134. !errors.Is(err, context.Canceled) &&
  135. err.Error() != "signal: killed" {
  136. slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
  137. }
  138. case <-ctx.Done():
  139. }
  140. })
  141. }
  142. wg.Wait()
  143. broker.Shutdown()
  144. return nil
  145. }
  146. // Initialize initializes MCP clients based on the provided configuration.
  147. func Initialize(ctx context.Context, permissions permission.Service, cfg *config.ConfigStore) {
  148. slog.Info("Initializing MCP clients")
  149. var wg sync.WaitGroup
  150. // Initialize states for all configured MCPs
  151. for name, m := range cfg.Config().MCP {
  152. if m.Disabled {
  153. updateState(name, StateDisabled, nil, nil, Counts{})
  154. slog.Debug("Skipping disabled MCP", "name", name)
  155. continue
  156. }
  157. // Set initial starting state
  158. wg.Add(1)
  159. go func(name string, m config.MCPConfig) {
  160. defer func() {
  161. wg.Done()
  162. if r := recover(); r != nil {
  163. var err error
  164. switch v := r.(type) {
  165. case error:
  166. err = v
  167. case string:
  168. err = fmt.Errorf("panic: %s", v)
  169. default:
  170. err = fmt.Errorf("panic: %v", v)
  171. }
  172. updateState(name, StateError, err, nil, Counts{})
  173. slog.Error("Panic in MCP client initialization", "error", err, "name", name)
  174. }
  175. }()
  176. if err := initClient(ctx, cfg, name, m, cfg.Resolver()); err != nil {
  177. slog.Debug("failed to initialize mcp client", "name", name, "error", err)
  178. }
  179. }(name, m)
  180. }
  181. wg.Wait()
  182. initOnce.Do(func() { close(initDone) })
  183. }
  184. // WaitForInit blocks until MCP initialization is complete.
  185. // If Initialize was never called, this returns immediately.
  186. func WaitForInit(ctx context.Context) error {
  187. select {
  188. case <-initDone:
  189. return nil
  190. case <-ctx.Done():
  191. return ctx.Err()
  192. }
  193. }
  194. // InitializeSingle initializes a single MCP client by name.
  195. func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore) error {
  196. m, exists := cfg.Config().MCP[name]
  197. if !exists {
  198. return fmt.Errorf("mcp '%s' not found in configuration", name)
  199. }
  200. if m.Disabled {
  201. updateState(name, StateDisabled, nil, nil, Counts{})
  202. slog.Debug("skipping disabled mcp", "name", name)
  203. return nil
  204. }
  205. return initClient(ctx, cfg, name, m, cfg.Resolver())
  206. }
  207. // initClient initializes a single MCP client with the given configuration.
  208. func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m config.MCPConfig, resolver config.VariableResolver) error {
  209. // Set initial starting state.
  210. updateState(name, StateStarting, nil, nil, Counts{})
  211. // createSession handles its own timeout internally.
  212. session, err := createSession(ctx, name, m, resolver)
  213. if err != nil {
  214. return err
  215. }
  216. tools, err := getTools(ctx, session)
  217. if err != nil {
  218. slog.Error("Error listing tools", "error", err)
  219. updateState(name, StateError, err, nil, Counts{})
  220. session.Close()
  221. return err
  222. }
  223. prompts, err := getPrompts(ctx, session)
  224. if err != nil {
  225. slog.Error("Error listing prompts", "error", err)
  226. updateState(name, StateError, err, nil, Counts{})
  227. session.Close()
  228. return err
  229. }
  230. toolCount := updateTools(cfg, name, tools)
  231. updatePrompts(name, prompts)
  232. sessions.Set(name, session)
  233. updateState(name, StateConnected, nil, session, Counts{
  234. Tools: toolCount,
  235. Prompts: len(prompts),
  236. })
  237. return nil
  238. }
  239. // DisableSingle disables and closes a single MCP client by name.
  240. func DisableSingle(cfg *config.ConfigStore, name string) error {
  241. session, ok := sessions.Get(name)
  242. if ok {
  243. if err := session.Close(); err != nil &&
  244. !errors.Is(err, io.EOF) &&
  245. !errors.Is(err, context.Canceled) &&
  246. err.Error() != "signal: killed" {
  247. slog.Warn("error closing mcp session", "name", name, "error", err)
  248. }
  249. sessions.Del(name)
  250. }
  251. // Clear tools and prompts for this MCP.
  252. updateTools(cfg, name, nil)
  253. updatePrompts(name, nil)
  254. // Update state to disabled.
  255. updateState(name, StateDisabled, nil, nil, Counts{})
  256. slog.Info("Disabled mcp client", "name", name)
  257. return nil
  258. }
  259. func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) {
  260. sess, ok := sessions.Get(name)
  261. if !ok {
  262. return nil, fmt.Errorf("mcp '%s' not available", name)
  263. }
  264. m := cfg.Config().MCP[name]
  265. state, _ := states.Get(name)
  266. timeout := mcpTimeout(m)
  267. pingCtx, cancel := context.WithTimeout(ctx, timeout)
  268. defer cancel()
  269. err := sess.Ping(pingCtx, nil)
  270. if err == nil {
  271. return sess, nil
  272. }
  273. updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
  274. sess, err = createSession(ctx, name, m, cfg.Resolver())
  275. if err != nil {
  276. return nil, err
  277. }
  278. updateState(name, StateConnected, nil, sess, state.Counts)
  279. sessions.Set(name, sess)
  280. return sess, nil
  281. }
  282. // updateState updates the state of an MCP client and publishes an event
  283. func updateState(name string, state State, err error, client *ClientSession, counts Counts) {
  284. info := ClientInfo{
  285. Name: name,
  286. State: state,
  287. Error: err,
  288. Client: client,
  289. Counts: counts,
  290. }
  291. switch state {
  292. case StateConnected:
  293. info.ConnectedAt = time.Now()
  294. case StateError:
  295. sessions.Del(name)
  296. }
  297. states.Set(name, info)
  298. // Publish state change event
  299. broker.Publish(pubsub.UpdatedEvent, Event{
  300. Type: EventStateChanged,
  301. Name: name,
  302. State: state,
  303. Error: err,
  304. Counts: counts,
  305. })
  306. }
  307. func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*ClientSession, error) {
  308. timeout := mcpTimeout(m)
  309. mcpCtx, cancel := context.WithCancel(ctx)
  310. cancelTimer := time.AfterFunc(timeout, cancel)
  311. transport, err := createTransport(mcpCtx, m, resolver)
  312. if err != nil {
  313. updateState(name, StateError, err, nil, Counts{})
  314. slog.Error("Error creating MCP client", "error", err, "name", name)
  315. cancel()
  316. cancelTimer.Stop()
  317. return nil, err
  318. }
  319. client := mcp.NewClient(
  320. &mcp.Implementation{
  321. Name: "crush",
  322. Version: version.Version,
  323. Title: "Crush",
  324. },
  325. &mcp.ClientOptions{
  326. ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
  327. broker.Publish(pubsub.UpdatedEvent, Event{
  328. Type: EventToolsListChanged,
  329. Name: name,
  330. })
  331. },
  332. PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
  333. broker.Publish(pubsub.UpdatedEvent, Event{
  334. Type: EventPromptsListChanged,
  335. Name: name,
  336. })
  337. },
  338. ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
  339. broker.Publish(pubsub.UpdatedEvent, Event{
  340. Type: EventResourcesListChanged,
  341. Name: name,
  342. })
  343. },
  344. LoggingMessageHandler: func(ctx context.Context, req *mcp.LoggingMessageRequest) {
  345. level := parseLevel(req.Params.Level)
  346. slog.Log(ctx, level, "MCP log", "name", name, "logger", req.Params.Logger, "data", req.Params.Data)
  347. },
  348. },
  349. )
  350. session, err := client.Connect(mcpCtx, transport, nil)
  351. if err != nil {
  352. err = maybeStdioErr(err, transport)
  353. updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
  354. slog.Error("MCP client failed to initialize", "error", err, "name", name)
  355. cancel()
  356. cancelTimer.Stop()
  357. return nil, err
  358. }
  359. cancelTimer.Stop()
  360. slog.Debug("MCP client initialized", "name", name)
  361. return &ClientSession{session, cancel}, nil
  362. }
  363. // maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
  364. // to parse, and the cli will then close it, causing the EOF error.
  365. // so, if we got an EOF err, and the transport is STDIO, we try to exec it
  366. // again with a timeout and collect the output so we can add details to the
  367. // error.
  368. // this happens particularly when starting things with npx, e.g. if node can't
  369. // be found or some other error like that.
  370. func maybeStdioErr(err error, transport mcp.Transport) error {
  371. if !errors.Is(err, io.EOF) {
  372. return err
  373. }
  374. ct, ok := transport.(*mcp.CommandTransport)
  375. if !ok {
  376. return err
  377. }
  378. if err2 := stdioCheck(ct.Command); err2 != nil {
  379. err = errors.Join(err, err2)
  380. }
  381. return err
  382. }
  383. func maybeTimeoutErr(err error, timeout time.Duration) error {
  384. if errors.Is(err, context.Canceled) {
  385. return fmt.Errorf("timed out after %s", timeout)
  386. }
  387. return err
  388. }
  389. func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
  390. switch m.Type {
  391. case config.MCPStdio:
  392. command, err := resolver.ResolveValue(m.Command)
  393. if err != nil {
  394. return nil, fmt.Errorf("invalid mcp command: %w", err)
  395. }
  396. if strings.TrimSpace(command) == "" {
  397. return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
  398. }
  399. cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
  400. cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
  401. return &mcp.CommandTransport{
  402. Command: cmd,
  403. }, nil
  404. case config.MCPHttp:
  405. if strings.TrimSpace(m.URL) == "" {
  406. return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
  407. }
  408. client := &http.Client{
  409. Transport: &headerRoundTripper{
  410. headers: m.ResolvedHeaders(),
  411. },
  412. }
  413. return &mcp.StreamableClientTransport{
  414. Endpoint: m.URL,
  415. HTTPClient: client,
  416. }, nil
  417. case config.MCPSSE:
  418. if strings.TrimSpace(m.URL) == "" {
  419. return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
  420. }
  421. client := &http.Client{
  422. Transport: &headerRoundTripper{
  423. headers: m.ResolvedHeaders(),
  424. },
  425. }
  426. return &mcp.SSEClientTransport{
  427. Endpoint: m.URL,
  428. HTTPClient: client,
  429. }, nil
  430. default:
  431. return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
  432. }
  433. }
  434. type headerRoundTripper struct {
  435. headers map[string]string
  436. }
  437. func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  438. for k, v := range rt.headers {
  439. req.Header.Set(k, v)
  440. }
  441. return http.DefaultTransport.RoundTrip(req)
  442. }
  443. func mcpTimeout(m config.MCPConfig) time.Duration {
  444. return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
  445. }
  446. func stdioCheck(old *exec.Cmd) error {
  447. ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
  448. defer cancel()
  449. cmd := exec.CommandContext(ctx, old.Path, old.Args...)
  450. cmd.Env = old.Env
  451. out, err := cmd.CombinedOutput()
  452. if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
  453. return nil
  454. }
  455. return fmt.Errorf("%w: %s", err, string(out))
  456. }