init.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. // Package mcp provides functionality for managing Model Context Protocol (MCP)
  2. // clients within the Crush application.
  3. package mcp
  4. import (
  5. "cmp"
  6. "context"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "log/slog"
  11. "maps"
  12. "net/http"
  13. "os"
  14. "os/exec"
  15. "strings"
  16. "sync"
  17. "time"
  18. "github.com/charmbracelet/crush/internal/config"
  19. "github.com/charmbracelet/crush/internal/csync"
  20. "github.com/charmbracelet/crush/internal/home"
  21. "github.com/charmbracelet/crush/internal/permission"
  22. "github.com/charmbracelet/crush/internal/pubsub"
  23. "github.com/charmbracelet/crush/internal/version"
  24. "github.com/modelcontextprotocol/go-sdk/mcp"
  25. )
  26. var (
  27. sessions = csync.NewMap[string, *mcp.ClientSession]()
  28. states = csync.NewMap[string, ClientInfo]()
  29. broker = pubsub.NewBroker[Event]()
  30. )
  31. // State represents the current state of an MCP client
  32. type State int
  33. const (
  34. StateDisabled State = iota
  35. StateStarting
  36. StateConnected
  37. StateError
  38. )
  39. func (s State) String() string {
  40. switch s {
  41. case StateDisabled:
  42. return "disabled"
  43. case StateStarting:
  44. return "starting"
  45. case StateConnected:
  46. return "connected"
  47. case StateError:
  48. return "error"
  49. default:
  50. return "unknown"
  51. }
  52. }
  53. // EventType represents the type of MCP event
  54. type EventType uint
  55. const (
  56. EventStateChanged EventType = iota
  57. EventToolsListChanged
  58. EventPromptsListChanged
  59. )
  60. // Event represents an event in the MCP system
  61. type Event struct {
  62. Type EventType
  63. Name string
  64. State State
  65. Error error
  66. Counts Counts
  67. }
  68. // Counts number of available tools, prompts, etc.
  69. type Counts struct {
  70. Tools int
  71. Prompts int
  72. }
  73. // ClientInfo holds information about an MCP client's state
  74. type ClientInfo struct {
  75. Name string
  76. State State
  77. Error error
  78. Client *mcp.ClientSession
  79. Counts Counts
  80. ConnectedAt time.Time
  81. }
  82. // SubscribeEvents returns a channel for MCP events
  83. func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
  84. return broker.Subscribe(ctx)
  85. }
  86. // GetStates returns the current state of all MCP clients
  87. func GetStates() map[string]ClientInfo {
  88. return maps.Collect(states.Seq2())
  89. }
  90. // GetState returns the state of a specific MCP client
  91. func GetState(name string) (ClientInfo, bool) {
  92. return states.Get(name)
  93. }
  94. // Close closes all MCP clients. This should be called during application shutdown.
  95. func Close() error {
  96. var errs []error
  97. for name, c := range sessions.Seq2() {
  98. if err := c.Close(); err != nil &&
  99. !errors.Is(err, io.EOF) &&
  100. !errors.Is(err, context.Canceled) &&
  101. err.Error() != "signal: killed" {
  102. errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
  103. }
  104. }
  105. broker.Shutdown()
  106. return errors.Join(errs...)
  107. }
  108. // Initialize initializes MCP clients based on the provided configuration.
  109. func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
  110. var wg sync.WaitGroup
  111. // Initialize states for all configured MCPs
  112. for name, m := range cfg.MCP {
  113. if m.Disabled {
  114. updateState(name, StateDisabled, nil, nil, Counts{})
  115. slog.Debug("skipping disabled mcp", "name", name)
  116. continue
  117. }
  118. // Set initial starting state
  119. updateState(name, StateStarting, nil, nil, Counts{})
  120. wg.Add(1)
  121. go func(name string, m config.MCPConfig) {
  122. defer func() {
  123. wg.Done()
  124. if r := recover(); r != nil {
  125. var err error
  126. switch v := r.(type) {
  127. case error:
  128. err = v
  129. case string:
  130. err = fmt.Errorf("panic: %s", v)
  131. default:
  132. err = fmt.Errorf("panic: %v", v)
  133. }
  134. updateState(name, StateError, err, nil, Counts{})
  135. slog.Error("panic in mcp client initialization", "error", err, "name", name)
  136. }
  137. }()
  138. ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
  139. defer cancel()
  140. session, err := createSession(ctx, name, m, cfg.Resolver())
  141. if err != nil {
  142. return
  143. }
  144. tools, err := getTools(ctx, session)
  145. if err != nil {
  146. slog.Error("error listing tools", "error", err)
  147. updateState(name, StateError, err, nil, Counts{})
  148. session.Close()
  149. return
  150. }
  151. prompts, err := getPrompts(ctx, session)
  152. if err != nil {
  153. slog.Error("error listing prompts", "error", err)
  154. updateState(name, StateError, err, nil, Counts{})
  155. session.Close()
  156. return
  157. }
  158. updateTools(name, tools)
  159. updatePrompts(name, prompts)
  160. sessions.Set(name, session)
  161. updateState(name, StateConnected, nil, session, Counts{
  162. Tools: len(tools),
  163. Prompts: len(prompts),
  164. })
  165. }(name, m)
  166. }
  167. wg.Wait()
  168. }
  169. func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
  170. sess, ok := sessions.Get(name)
  171. if !ok {
  172. return nil, fmt.Errorf("mcp '%s' not available", name)
  173. }
  174. cfg := config.Get()
  175. m := cfg.MCP[name]
  176. state, _ := states.Get(name)
  177. timeout := mcpTimeout(m)
  178. pingCtx, cancel := context.WithTimeout(ctx, timeout)
  179. defer cancel()
  180. err := sess.Ping(pingCtx, nil)
  181. if err == nil {
  182. return sess, nil
  183. }
  184. updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
  185. sess, err = createSession(ctx, name, m, cfg.Resolver())
  186. if err != nil {
  187. return nil, err
  188. }
  189. updateState(name, StateConnected, nil, sess, state.Counts)
  190. sessions.Set(name, sess)
  191. return sess, nil
  192. }
  193. // updateState updates the state of an MCP client and publishes an event
  194. func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
  195. info := ClientInfo{
  196. Name: name,
  197. State: state,
  198. Error: err,
  199. Client: client,
  200. Counts: counts,
  201. }
  202. switch state {
  203. case StateConnected:
  204. info.ConnectedAt = time.Now()
  205. case StateError:
  206. updateTools(name, nil)
  207. sessions.Del(name)
  208. }
  209. states.Set(name, info)
  210. // Publish state change event
  211. broker.Publish(pubsub.UpdatedEvent, Event{
  212. Type: EventStateChanged,
  213. Name: name,
  214. State: state,
  215. Error: err,
  216. Counts: counts,
  217. })
  218. }
  219. func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
  220. timeout := mcpTimeout(m)
  221. mcpCtx, cancel := context.WithCancel(ctx)
  222. cancelTimer := time.AfterFunc(timeout, cancel)
  223. transport, err := createTransport(mcpCtx, m, resolver)
  224. if err != nil {
  225. updateState(name, StateError, err, nil, Counts{})
  226. slog.Error("error creating mcp client", "error", err, "name", name)
  227. cancel()
  228. cancelTimer.Stop()
  229. return nil, err
  230. }
  231. client := mcp.NewClient(
  232. &mcp.Implementation{
  233. Name: "crush",
  234. Version: version.Version,
  235. Title: "Crush",
  236. },
  237. &mcp.ClientOptions{
  238. ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
  239. broker.Publish(pubsub.UpdatedEvent, Event{
  240. Type: EventToolsListChanged,
  241. Name: name,
  242. })
  243. },
  244. PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
  245. broker.Publish(pubsub.UpdatedEvent, Event{
  246. Type: EventPromptsListChanged,
  247. Name: name,
  248. })
  249. },
  250. LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
  251. slog.Info("mcp log", "name", name, "data", req.Params.Data)
  252. },
  253. KeepAlive: time.Minute * 10,
  254. },
  255. )
  256. session, err := client.Connect(mcpCtx, transport, nil)
  257. if err != nil {
  258. err = maybeStdioErr(err, transport)
  259. updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
  260. slog.Error("error starting mcp client", "error", err, "name", name)
  261. cancel()
  262. cancelTimer.Stop()
  263. return nil, err
  264. }
  265. cancelTimer.Stop()
  266. slog.Info("Initialized mcp client", "name", name)
  267. return session, nil
  268. }
  269. // maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
  270. // to parse, and the cli will then close it, causing the EOF error.
  271. // so, if we got an EOF err, and the transport is STDIO, we try to exec it
  272. // again with a timeout and collect the output so we can add details to the
  273. // error.
  274. // this happens particularly when starting things with npx, e.g. if node can't
  275. // be found or some other error like that.
  276. func maybeStdioErr(err error, transport mcp.Transport) error {
  277. if !errors.Is(err, io.EOF) {
  278. return err
  279. }
  280. ct, ok := transport.(*mcp.CommandTransport)
  281. if !ok {
  282. return err
  283. }
  284. if err2 := stdioCheck(ct.Command); err2 != nil {
  285. err = errors.Join(err, err2)
  286. }
  287. return err
  288. }
  289. func maybeTimeoutErr(err error, timeout time.Duration) error {
  290. if errors.Is(err, context.Canceled) {
  291. return fmt.Errorf("timed out after %s", timeout)
  292. }
  293. return err
  294. }
  295. func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
  296. switch m.Type {
  297. case config.MCPStdio:
  298. command, err := resolver.ResolveValue(m.Command)
  299. if err != nil {
  300. return nil, fmt.Errorf("invalid mcp command: %w", err)
  301. }
  302. if strings.TrimSpace(command) == "" {
  303. return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
  304. }
  305. cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
  306. cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
  307. return &mcp.CommandTransport{
  308. Command: cmd,
  309. }, nil
  310. case config.MCPHttp:
  311. if strings.TrimSpace(m.URL) == "" {
  312. return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
  313. }
  314. client := &http.Client{
  315. Transport: &headerRoundTripper{
  316. headers: m.ResolvedHeaders(),
  317. },
  318. }
  319. return &mcp.StreamableClientTransport{
  320. Endpoint: m.URL,
  321. HTTPClient: client,
  322. }, nil
  323. case config.MCPSSE:
  324. if strings.TrimSpace(m.URL) == "" {
  325. return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
  326. }
  327. client := &http.Client{
  328. Transport: &headerRoundTripper{
  329. headers: m.ResolvedHeaders(),
  330. },
  331. }
  332. return &mcp.SSEClientTransport{
  333. Endpoint: m.URL,
  334. HTTPClient: client,
  335. }, nil
  336. default:
  337. return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
  338. }
  339. }
  340. type headerRoundTripper struct {
  341. headers map[string]string
  342. }
  343. func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  344. for k, v := range rt.headers {
  345. req.Header.Set(k, v)
  346. }
  347. return http.DefaultTransport.RoundTrip(req)
  348. }
  349. func mcpTimeout(m config.MCPConfig) time.Duration {
  350. return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
  351. }
  352. func stdioCheck(old *exec.Cmd) error {
  353. ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
  354. defer cancel()
  355. cmd := exec.CommandContext(ctx, old.Path, old.Args...)
  356. cmd.Env = old.Env
  357. out, err := cmd.CombinedOutput()
  358. if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
  359. return nil
  360. }
  361. return fmt.Errorf("%w: %s", err, string(out))
  362. }