| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511 |
- // Package mcp provides functionality for managing Model Context Protocol (MCP)
- // clients within the Crush application.
- package mcp
- import (
- "cmp"
- "context"
- "errors"
- "fmt"
- "io"
- "log/slog"
- "net/http"
- "os"
- "os/exec"
- "strings"
- "sync"
- "time"
- "github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/csync"
- "github.com/charmbracelet/crush/internal/home"
- "github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/pubsub"
- "github.com/charmbracelet/crush/internal/version"
- "github.com/modelcontextprotocol/go-sdk/mcp"
- )
- func parseLevel(level mcp.LoggingLevel) slog.Level {
- switch level {
- case "info":
- return slog.LevelInfo
- case "notice":
- return slog.LevelInfo
- case "warning":
- return slog.LevelWarn
- default:
- return slog.LevelDebug
- }
- }
- // ClientSession wraps an mcp.ClientSession with a context cancel function so
- // that the context created during session establishment is properly cleaned up
- // on close.
- type ClientSession struct {
- *mcp.ClientSession
- cancel context.CancelFunc
- }
- // Close cancels the session context and then closes the underlying session.
- func (s *ClientSession) Close() error {
- s.cancel()
- return s.ClientSession.Close()
- }
- var (
- sessions = csync.NewMap[string, *ClientSession]()
- states = csync.NewMap[string, ClientInfo]()
- broker = pubsub.NewBroker[Event]()
- initOnce sync.Once
- initDone = make(chan struct{})
- )
// State describes the lifecycle stage of an MCP client.
type State int

const (
	StateDisabled  State = iota // configured but switched off
	StateStarting               // connection attempt in progress
	StateConnected              // session established
	StateError                  // the last connection attempt or ping failed
)

// String returns the lowercase name of the state, or "unknown" for any value
// outside the declared constants.
func (s State) String() string {
	names := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
// EventType identifies what kind of change an Event reports.
type EventType uint

// Event kinds published on the MCP event broker; values are fixed explicitly.
const (
	EventStateChanged         EventType = 0 // a client's State changed
	EventToolsListChanged     EventType = 1 // a server notified that its tool list changed
	EventPromptsListChanged   EventType = 2 // a server notified that its prompt list changed
	EventResourcesListChanged EventType = 3 // a server notified that its resource list changed
)
// Event is the payload published on the MCP event broker.
//
// For EventStateChanged every field is populated from the state update. For
// the *ListChanged event types only Type and Name are set; State, Error and
// Counts are left at their zero values.
type Event struct {
	Type   EventType // kind of change being reported
	Name   string    // configured name of the MCP client the event concerns
	State  State     // new state (EventStateChanged only)
	Error  error     // error attached to the state change, if any
	Counts Counts    // counts passed with the state change
}
// Counts records the number of tools, prompts and resources reported for an
// MCP client.
type Counts struct {
	Tools, Prompts, Resources int
}
// ClientInfo is the snapshot of an MCP client's status stored in the states
// registry and returned by GetState and GetStates. A new value is recorded on
// every state transition.
type ClientInfo struct {
	Name        string         // configured MCP name
	State       State          // current lifecycle state
	Error       error          // failure cause; set together with StateError, nil otherwise
	Client      *ClientSession // live session when connected, nil otherwise
	Counts      Counts         // counts passed with the last state update
	ConnectedAt time.Time      // time of the transition to StateConnected; zero in other states
}
- // SubscribeEvents returns a channel for MCP events
- func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
- return broker.Subscribe(ctx)
- }
- // GetStates returns the current state of all MCP clients
- func GetStates() map[string]ClientInfo {
- return states.Copy()
- }
- // GetState returns the state of a specific MCP client
- func GetState(name string) (ClientInfo, bool) {
- return states.Get(name)
- }
- // Close closes all MCP clients. This should be called during application shutdown.
- func Close(ctx context.Context) error {
- var wg sync.WaitGroup
- for name, session := range sessions.Seq2() {
- wg.Go(func() {
- done := make(chan error, 1)
- go func() {
- done <- session.Close()
- }()
- select {
- case err := <-done:
- if err != nil &&
- !errors.Is(err, io.EOF) &&
- !errors.Is(err, context.Canceled) &&
- err.Error() != "signal: killed" {
- slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
- }
- case <-ctx.Done():
- }
- })
- }
- wg.Wait()
- broker.Shutdown()
- return nil
- }
- // Initialize initializes MCP clients based on the provided configuration.
- func Initialize(ctx context.Context, permissions permission.Service, cfg *config.ConfigStore) {
- slog.Info("Initializing MCP clients")
- var wg sync.WaitGroup
- // Initialize states for all configured MCPs
- for name, m := range cfg.Config().MCP {
- if m.Disabled {
- updateState(name, StateDisabled, nil, nil, Counts{})
- slog.Debug("Skipping disabled MCP", "name", name)
- continue
- }
- // Set initial starting state
- wg.Add(1)
- go func(name string, m config.MCPConfig) {
- defer func() {
- wg.Done()
- if r := recover(); r != nil {
- var err error
- switch v := r.(type) {
- case error:
- err = v
- case string:
- err = fmt.Errorf("panic: %s", v)
- default:
- err = fmt.Errorf("panic: %v", v)
- }
- updateState(name, StateError, err, nil, Counts{})
- slog.Error("Panic in MCP client initialization", "error", err, "name", name)
- }
- }()
- if err := initClient(ctx, cfg, name, m, cfg.Resolver()); err != nil {
- slog.Debug("failed to initialize mcp client", "name", name, "error", err)
- }
- }(name, m)
- }
- wg.Wait()
- initOnce.Do(func() { close(initDone) })
- }
- // WaitForInit blocks until MCP initialization is complete.
- // If Initialize was never called, this returns immediately.
- func WaitForInit(ctx context.Context) error {
- select {
- case <-initDone:
- return nil
- case <-ctx.Done():
- return ctx.Err()
- }
- }
- // InitializeSingle initializes a single MCP client by name.
- func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore) error {
- m, exists := cfg.Config().MCP[name]
- if !exists {
- return fmt.Errorf("mcp '%s' not found in configuration", name)
- }
- if m.Disabled {
- updateState(name, StateDisabled, nil, nil, Counts{})
- slog.Debug("skipping disabled mcp", "name", name)
- return nil
- }
- return initClient(ctx, cfg, name, m, cfg.Resolver())
- }
- // initClient initializes a single MCP client with the given configuration.
- func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m config.MCPConfig, resolver config.VariableResolver) error {
- // Set initial starting state.
- updateState(name, StateStarting, nil, nil, Counts{})
- // createSession handles its own timeout internally.
- session, err := createSession(ctx, name, m, resolver)
- if err != nil {
- return err
- }
- tools, err := getTools(ctx, session)
- if err != nil {
- slog.Error("Error listing tools", "error", err)
- updateState(name, StateError, err, nil, Counts{})
- session.Close()
- return err
- }
- prompts, err := getPrompts(ctx, session)
- if err != nil {
- slog.Error("Error listing prompts", "error", err)
- updateState(name, StateError, err, nil, Counts{})
- session.Close()
- return err
- }
- toolCount := updateTools(cfg, name, tools)
- updatePrompts(name, prompts)
- sessions.Set(name, session)
- updateState(name, StateConnected, nil, session, Counts{
- Tools: toolCount,
- Prompts: len(prompts),
- })
- return nil
- }
- // DisableSingle disables and closes a single MCP client by name.
- func DisableSingle(cfg *config.ConfigStore, name string) error {
- session, ok := sessions.Get(name)
- if ok {
- if err := session.Close(); err != nil &&
- !errors.Is(err, io.EOF) &&
- !errors.Is(err, context.Canceled) &&
- err.Error() != "signal: killed" {
- slog.Warn("error closing mcp session", "name", name, "error", err)
- }
- sessions.Del(name)
- }
- // Clear tools and prompts for this MCP.
- updateTools(cfg, name, nil)
- updatePrompts(name, nil)
- // Update state to disabled.
- updateState(name, StateDisabled, nil, nil, Counts{})
- slog.Info("Disabled mcp client", "name", name)
- return nil
- }
- func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) {
- sess, ok := sessions.Get(name)
- if !ok {
- return nil, fmt.Errorf("mcp '%s' not available", name)
- }
- m := cfg.Config().MCP[name]
- state, _ := states.Get(name)
- timeout := mcpTimeout(m)
- pingCtx, cancel := context.WithTimeout(ctx, timeout)
- defer cancel()
- err := sess.Ping(pingCtx, nil)
- if err == nil {
- return sess, nil
- }
- updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
- sess, err = createSession(ctx, name, m, cfg.Resolver())
- if err != nil {
- return nil, err
- }
- updateState(name, StateConnected, nil, sess, state.Counts)
- sessions.Set(name, sess)
- return sess, nil
- }
- // updateState updates the state of an MCP client and publishes an event
- func updateState(name string, state State, err error, client *ClientSession, counts Counts) {
- info := ClientInfo{
- Name: name,
- State: state,
- Error: err,
- Client: client,
- Counts: counts,
- }
- switch state {
- case StateConnected:
- info.ConnectedAt = time.Now()
- case StateError:
- sessions.Del(name)
- }
- states.Set(name, info)
- // Publish state change event
- broker.Publish(pubsub.UpdatedEvent, Event{
- Type: EventStateChanged,
- Name: name,
- State: state,
- Error: err,
- Counts: counts,
- })
- }
- func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*ClientSession, error) {
- timeout := mcpTimeout(m)
- mcpCtx, cancel := context.WithCancel(ctx)
- cancelTimer := time.AfterFunc(timeout, cancel)
- transport, err := createTransport(mcpCtx, m, resolver)
- if err != nil {
- updateState(name, StateError, err, nil, Counts{})
- slog.Error("Error creating MCP client", "error", err, "name", name)
- cancel()
- cancelTimer.Stop()
- return nil, err
- }
- client := mcp.NewClient(
- &mcp.Implementation{
- Name: "crush",
- Version: version.Version,
- Title: "Crush",
- },
- &mcp.ClientOptions{
- ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
- broker.Publish(pubsub.UpdatedEvent, Event{
- Type: EventToolsListChanged,
- Name: name,
- })
- },
- PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
- broker.Publish(pubsub.UpdatedEvent, Event{
- Type: EventPromptsListChanged,
- Name: name,
- })
- },
- ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
- broker.Publish(pubsub.UpdatedEvent, Event{
- Type: EventResourcesListChanged,
- Name: name,
- })
- },
- LoggingMessageHandler: func(ctx context.Context, req *mcp.LoggingMessageRequest) {
- level := parseLevel(req.Params.Level)
- slog.Log(ctx, level, "MCP log", "name", name, "logger", req.Params.Logger, "data", req.Params.Data)
- },
- },
- )
- session, err := client.Connect(mcpCtx, transport, nil)
- if err != nil {
- err = maybeStdioErr(err, transport)
- updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
- slog.Error("MCP client failed to initialize", "error", err, "name", name)
- cancel()
- cancelTimer.Stop()
- return nil, err
- }
- cancelTimer.Stop()
- slog.Debug("MCP client initialized", "name", name)
- return &ClientSession{session, cancel}, nil
- }
- // maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
- // to parse, and the cli will then close it, causing the EOF error.
- // so, if we got an EOF err, and the transport is STDIO, we try to exec it
- // again with a timeout and collect the output so we can add details to the
- // error.
- // this happens particularly when starting things with npx, e.g. if node can't
- // be found or some other error like that.
- func maybeStdioErr(err error, transport mcp.Transport) error {
- if !errors.Is(err, io.EOF) {
- return err
- }
- ct, ok := transport.(*mcp.CommandTransport)
- if !ok {
- return err
- }
- if err2 := stdioCheck(ct.Command); err2 != nil {
- err = errors.Join(err, err2)
- }
- return err
- }
// maybeTimeoutErr replaces a context cancellation or deadline error (also
// when wrapped) with a readable "timed out after <timeout>" error. Any other
// error, including nil, is returned unchanged.
//
// Both sentinels are checked because operations in this package are bounded
// in two ways: createSession cancels its context from a timer
// (context.Canceled), while getOrRenewClient pings under context.WithTimeout
// (context.DeadlineExceeded).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
- func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
- switch m.Type {
- case config.MCPStdio:
- command, err := resolver.ResolveValue(m.Command)
- if err != nil {
- return nil, fmt.Errorf("invalid mcp command: %w", err)
- }
- if strings.TrimSpace(command) == "" {
- return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
- }
- cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
- cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
- return &mcp.CommandTransport{
- Command: cmd,
- }, nil
- case config.MCPHttp:
- if strings.TrimSpace(m.URL) == "" {
- return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
- }
- client := &http.Client{
- Transport: &headerRoundTripper{
- headers: m.ResolvedHeaders(),
- },
- }
- return &mcp.StreamableClientTransport{
- Endpoint: m.URL,
- HTTPClient: client,
- }, nil
- case config.MCPSSE:
- if strings.TrimSpace(m.URL) == "" {
- return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
- }
- client := &http.Client{
- Transport: &headerRoundTripper{
- headers: m.ResolvedHeaders(),
- },
- }
- return &mcp.SSEClientTransport{
- Endpoint: m.URL,
- HTTPClient: client,
- }, nil
- default:
- return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
- }
- }
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every request before passing it on to base. Configured headers replace
// any existing values under the same name.
type headerRoundTripper struct {
	headers map[string]string
	// base performs the actual round trip; nil means http.DefaultTransport.
	base http.RoundTripper
}

// RoundTrip implements http.RoundTripper. The RoundTripper contract forbids
// modifying the caller's request, so the headers are set on a clone.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	next := rt.base
	if next == nil {
		next = http.DefaultTransport
	}
	if len(rt.headers) == 0 {
		return next.RoundTrip(req)
	}
	out := req.Clone(req.Context())
	if out.Header == nil {
		// Header.Clone of a nil header is nil; Set would panic on it.
		out.Header = make(http.Header)
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return next.RoundTrip(out)
}
- func mcpTimeout(m config.MCPConfig) time.Duration {
- return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
- }
// stdioCheck re-runs the command described by old (same path, arguments,
// environment and working directory) for at most 5 seconds. If the command
// fails, the returned error wraps its exit error together with its combined
// stdout/stderr. It returns nil when the command succeeds or is still running
// when the timeout expires, as a healthy stdio server would be.
//
// An exec.Cmd cannot be reused once started, so a fresh one is built from
// old's fields.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// old.Args[0] is the program name as given to exec.Command; the real
	// arguments start at index 1. Passing old.Args whole would repeat the
	// program name as the first argument and run a different command line.
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}
	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}
|