|
|
@@ -1,94 +1,32 @@
|
|
|
package tools
|
|
|
|
|
|
import (
|
|
|
- "cmp"
|
|
|
"context"
|
|
|
- "encoding/json"
|
|
|
- "errors"
|
|
|
"fmt"
|
|
|
- "io"
|
|
|
- "log/slog"
|
|
|
- "maps"
|
|
|
- "net/http"
|
|
|
- "os"
|
|
|
- "os/exec"
|
|
|
- "slices"
|
|
|
- "strings"
|
|
|
- "sync"
|
|
|
- "time"
|
|
|
|
|
|
"charm.land/fantasy"
|
|
|
- "github.com/charmbracelet/crush/internal/config"
|
|
|
- "github.com/charmbracelet/crush/internal/csync"
|
|
|
- "github.com/charmbracelet/crush/internal/home"
|
|
|
+ "github.com/charmbracelet/crush/internal/agent/tools/mcp"
|
|
|
"github.com/charmbracelet/crush/internal/permission"
|
|
|
- "github.com/charmbracelet/crush/internal/pubsub"
|
|
|
- "github.com/charmbracelet/crush/internal/version"
|
|
|
- "github.com/modelcontextprotocol/go-sdk/mcp"
|
|
|
)
|
|
|
|
|
|
-// MCPState represents the current state of an MCP client
|
|
|
-type MCPState int
|
|
|
-
|
|
|
-const (
|
|
|
- MCPStateDisabled MCPState = iota
|
|
|
- MCPStateStarting
|
|
|
- MCPStateConnected
|
|
|
- MCPStateError
|
|
|
-)
|
|
|
-
|
|
|
-func (s MCPState) String() string {
|
|
|
- switch s {
|
|
|
- case MCPStateDisabled:
|
|
|
- return "disabled"
|
|
|
- case MCPStateStarting:
|
|
|
- return "starting"
|
|
|
- case MCPStateConnected:
|
|
|
- return "connected"
|
|
|
- case MCPStateError:
|
|
|
- return "error"
|
|
|
- default:
|
|
|
- return "unknown"
|
|
|
+// GetMCPTools gets all the currently available MCP tools.
|
|
|
+func GetMCPTools(permissions permission.Service, wd string) []*Tool {
|
|
|
+ var result []*Tool
|
|
|
+ for mcpName, tools := range mcp.Tools() {
|
|
|
+ for _, tool := range tools {
|
|
|
+ result = append(result, &Tool{
|
|
|
+ mcpName: mcpName,
|
|
|
+ tool: tool,
|
|
|
+ permissions: permissions,
|
|
|
+ workingDir: wd,
|
|
|
+ })
|
|
|
+ }
|
|
|
}
|
|
|
+ return result
|
|
|
}
|
|
|
|
|
|
-// MCPEventType represents the type of MCP event
|
|
|
-type MCPEventType string
|
|
|
-
|
|
|
-const (
|
|
|
- MCPEventStateChanged MCPEventType = "state_changed"
|
|
|
- MCPEventToolsListChanged MCPEventType = "tools_list_changed"
|
|
|
-)
|
|
|
-
|
|
|
-// MCPEvent represents an event in the MCP system
|
|
|
-type MCPEvent struct {
|
|
|
- Type MCPEventType
|
|
|
- Name string
|
|
|
- State MCPState
|
|
|
- Error error
|
|
|
- ToolCount int
|
|
|
-}
|
|
|
-
|
|
|
-// MCPClientInfo holds information about an MCP client's state
|
|
|
-type MCPClientInfo struct {
|
|
|
- Name string
|
|
|
- State MCPState
|
|
|
- Error error
|
|
|
- Client *mcp.ClientSession
|
|
|
- ToolCount int
|
|
|
- ConnectedAt time.Time
|
|
|
-}
|
|
|
-
|
|
|
-var (
|
|
|
- mcpToolsOnce sync.Once
|
|
|
- mcpTools = csync.NewMap[string, *McpTool]()
|
|
|
- mcpClient2Tools = csync.NewMap[string, []*McpTool]()
|
|
|
- mcpClients = csync.NewMap[string, *mcp.ClientSession]()
|
|
|
- mcpStates = csync.NewMap[string, MCPClientInfo]()
|
|
|
- mcpBroker = pubsub.NewBroker[MCPEvent]()
|
|
|
-)
|
|
|
-
|
|
|
-type McpTool struct {
|
|
|
+// Tool is a tool from a MCP.
|
|
|
+type Tool struct {
|
|
|
mcpName string
|
|
|
tool *mcp.Tool
|
|
|
permissions permission.Service
|
|
|
@@ -96,31 +34,31 @@ type McpTool struct {
|
|
|
providerOptions fantasy.ProviderOptions
|
|
|
}
|
|
|
|
|
|
-func (m *McpTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
|
|
+func (m *Tool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
|
|
m.providerOptions = opts
|
|
|
}
|
|
|
|
|
|
-func (m *McpTool) ProviderOptions() fantasy.ProviderOptions {
|
|
|
+func (m *Tool) ProviderOptions() fantasy.ProviderOptions {
|
|
|
return m.providerOptions
|
|
|
}
|
|
|
|
|
|
-func (m *McpTool) Name() string {
|
|
|
+func (m *Tool) Name() string {
|
|
|
return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
|
|
|
}
|
|
|
|
|
|
-func (m *McpTool) MCP() string {
|
|
|
+func (m *Tool) MCP() string {
|
|
|
return m.mcpName
|
|
|
}
|
|
|
|
|
|
-func (m *McpTool) MCPToolName() string {
|
|
|
+func (m *Tool) MCPToolName() string {
|
|
|
return m.tool.Name
|
|
|
}
|
|
|
|
|
|
-func (b *McpTool) Info() fantasy.ToolInfo {
|
|
|
+func (m *Tool) Info() fantasy.ToolInfo {
|
|
|
parameters := make(map[string]any)
|
|
|
required := make([]string, 0)
|
|
|
|
|
|
- if input, ok := b.tool.InputSchema.(map[string]any); ok {
|
|
|
+ if input, ok := m.tool.InputSchema.(map[string]any); ok {
|
|
|
if props, ok := input["properties"].(map[string]any); ok {
|
|
|
parameters = props
|
|
|
}
|
|
|
@@ -138,72 +76,14 @@ func (b *McpTool) Info() fantasy.ToolInfo {
|
|
|
}
|
|
|
|
|
|
return fantasy.ToolInfo{
|
|
|
- Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
|
|
|
- Description: b.tool.Description,
|
|
|
+ Name: m.Name(),
|
|
|
+ Description: m.tool.Description,
|
|
|
Parameters: parameters,
|
|
|
Required: required,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) {
|
|
|
- var args map[string]any
|
|
|
- if err := json.Unmarshal([]byte(input), &args); err != nil {
|
|
|
- return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
|
|
|
- }
|
|
|
-
|
|
|
- c, err := getOrRenewClient(ctx, name)
|
|
|
- if err != nil {
|
|
|
- return fantasy.NewTextErrorResponse(err.Error()), nil
|
|
|
- }
|
|
|
- result, err := c.CallTool(ctx, &mcp.CallToolParams{
|
|
|
- Name: toolName,
|
|
|
- Arguments: args,
|
|
|
- })
|
|
|
- if err != nil {
|
|
|
- return fantasy.NewTextErrorResponse(err.Error()), nil
|
|
|
- }
|
|
|
-
|
|
|
- output := make([]string, 0, len(result.Content))
|
|
|
- for _, v := range result.Content {
|
|
|
- if vv, ok := v.(*mcp.TextContent); ok {
|
|
|
- output = append(output, vv.Text)
|
|
|
- } else {
|
|
|
- output = append(output, fmt.Sprintf("%v", v))
|
|
|
- }
|
|
|
- }
|
|
|
- return fantasy.NewTextResponse(strings.Join(output, "\n")), nil
|
|
|
-}
|
|
|
-
|
|
|
-func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
|
|
|
- sess, ok := mcpClients.Get(name)
|
|
|
- if !ok {
|
|
|
- return nil, fmt.Errorf("mcp '%s' not available", name)
|
|
|
- }
|
|
|
-
|
|
|
- cfg := config.Get()
|
|
|
- m := cfg.MCP[name]
|
|
|
- state, _ := mcpStates.Get(name)
|
|
|
-
|
|
|
- timeout := mcpTimeout(m)
|
|
|
- pingCtx, cancel := context.WithTimeout(ctx, timeout)
|
|
|
- defer cancel()
|
|
|
- err := sess.Ping(pingCtx, nil)
|
|
|
- if err == nil {
|
|
|
- return sess, nil
|
|
|
- }
|
|
|
- updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
|
|
|
-
|
|
|
- sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
-
|
|
|
- updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
|
|
|
- mcpClients.Set(name, sess)
|
|
|
- return sess, nil
|
|
|
-}
|
|
|
-
|
|
|
-func (m *McpTool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
|
|
+func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
|
|
sessionID := GetSessionFromContext(ctx)
|
|
|
if sessionID == "" {
|
|
|
return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
|
|
|
@@ -224,302 +104,9 @@ func (m *McpTool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.Too
|
|
|
return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
|
|
|
}
|
|
|
|
|
|
- return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
|
|
|
-}
|
|
|
-
|
|
|
-func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*McpTool, error) {
|
|
|
- result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
|
|
|
+ content, err := mcp.RunTool(ctx, m.mcpName, m.tool.Name, params.Input)
|
|
|
if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- mcpTools := make([]*McpTool, 0, len(result.Tools))
|
|
|
- for _, tool := range result.Tools {
|
|
|
- mcpTools = append(mcpTools, &McpTool{
|
|
|
- mcpName: name,
|
|
|
- tool: tool,
|
|
|
- permissions: permissions,
|
|
|
- workingDir: workingDir,
|
|
|
- })
|
|
|
- }
|
|
|
- return mcpTools, nil
|
|
|
-}
|
|
|
-
|
|
|
-// SubscribeMCPEvents returns a channel for MCP events
|
|
|
-func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
|
|
|
- return mcpBroker.Subscribe(ctx)
|
|
|
-}
|
|
|
-
|
|
|
-// GetMCPStates returns the current state of all MCP clients
|
|
|
-func GetMCPStates() map[string]MCPClientInfo {
|
|
|
- return maps.Collect(mcpStates.Seq2())
|
|
|
-}
|
|
|
-
|
|
|
-// GetMCPState returns the state of a specific MCP client
|
|
|
-func GetMCPState(name string) (MCPClientInfo, bool) {
|
|
|
- return mcpStates.Get(name)
|
|
|
-}
|
|
|
-
|
|
|
-// updateMCPState updates the state of an MCP client and publishes an event
|
|
|
-func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
|
|
|
- info := MCPClientInfo{
|
|
|
- Name: name,
|
|
|
- State: state,
|
|
|
- Error: err,
|
|
|
- Client: client,
|
|
|
- ToolCount: toolCount,
|
|
|
- }
|
|
|
- switch state {
|
|
|
- case MCPStateConnected:
|
|
|
- info.ConnectedAt = time.Now()
|
|
|
- case MCPStateError:
|
|
|
- updateMcpTools(name, nil)
|
|
|
- mcpClients.Del(name)
|
|
|
- }
|
|
|
- mcpStates.Set(name, info)
|
|
|
-
|
|
|
- // Publish state change event
|
|
|
- mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
|
|
|
- Type: MCPEventStateChanged,
|
|
|
- Name: name,
|
|
|
- State: state,
|
|
|
- Error: err,
|
|
|
- ToolCount: toolCount,
|
|
|
- })
|
|
|
-}
|
|
|
-
|
|
|
-// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
|
|
|
-func CloseMCPClients() error {
|
|
|
- var errs []error
|
|
|
- for name, c := range mcpClients.Seq2() {
|
|
|
- if err := c.Close(); err != nil &&
|
|
|
- !errors.Is(err, io.EOF) &&
|
|
|
- !errors.Is(err, context.Canceled) &&
|
|
|
- err.Error() != "signal: killed" {
|
|
|
- errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
|
|
|
- }
|
|
|
- }
|
|
|
- mcpBroker.Shutdown()
|
|
|
- return errors.Join(errs...)
|
|
|
-}
|
|
|
-
|
|
|
-func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []*McpTool {
|
|
|
- mcpToolsOnce.Do(func() {
|
|
|
- var wg sync.WaitGroup
|
|
|
- // Initialize states for all configured MCPs
|
|
|
- for name, m := range cfg.MCP {
|
|
|
- if m.Disabled {
|
|
|
- updateMCPState(name, MCPStateDisabled, nil, nil, 0)
|
|
|
- slog.Debug("skipping disabled mcp", "name", name)
|
|
|
- continue
|
|
|
- }
|
|
|
-
|
|
|
- // Set initial starting state
|
|
|
- updateMCPState(name, MCPStateStarting, nil, nil, 0)
|
|
|
-
|
|
|
- wg.Add(1)
|
|
|
- go func(name string, m config.MCPConfig) {
|
|
|
- defer func() {
|
|
|
- wg.Done()
|
|
|
- if r := recover(); r != nil {
|
|
|
- var err error
|
|
|
- switch v := r.(type) {
|
|
|
- case error:
|
|
|
- err = v
|
|
|
- case string:
|
|
|
- err = fmt.Errorf("panic: %s", v)
|
|
|
- default:
|
|
|
- err = fmt.Errorf("panic: %v", v)
|
|
|
- }
|
|
|
- updateMCPState(name, MCPStateError, err, nil, 0)
|
|
|
- slog.Error("panic in mcp client initialization", "error", err, "name", name)
|
|
|
- }
|
|
|
- }()
|
|
|
-
|
|
|
- ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
|
|
|
- defer cancel()
|
|
|
-
|
|
|
- c, err := createMCPSession(ctx, name, m, cfg.Resolver())
|
|
|
- if err != nil {
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- mcpClients.Set(name, c)
|
|
|
-
|
|
|
- tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
|
|
|
- if err != nil {
|
|
|
- slog.Error("error listing tools", "error", err)
|
|
|
- updateMCPState(name, MCPStateError, err, nil, 0)
|
|
|
- c.Close()
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- updateMcpTools(name, tools)
|
|
|
- mcpClients.Set(name, c)
|
|
|
- updateMCPState(name, MCPStateConnected, nil, c, len(tools))
|
|
|
- }(name, m)
|
|
|
- }
|
|
|
- wg.Wait()
|
|
|
- })
|
|
|
- return slices.Collect(mcpTools.Seq())
|
|
|
-}
|
|
|
-
|
|
|
-// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
|
|
|
-func updateMcpTools(mcpName string, tools []*McpTool) {
|
|
|
- if len(tools) == 0 {
|
|
|
- mcpClient2Tools.Del(mcpName)
|
|
|
- } else {
|
|
|
- mcpClient2Tools.Set(mcpName, tools)
|
|
|
- }
|
|
|
- for _, tools := range mcpClient2Tools.Seq2() {
|
|
|
- for _, t := range tools {
|
|
|
- mcpTools.Set(t.Name(), t)
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
|
|
|
- timeout := mcpTimeout(m)
|
|
|
- mcpCtx, cancel := context.WithCancel(ctx)
|
|
|
- cancelTimer := time.AfterFunc(timeout, cancel)
|
|
|
-
|
|
|
- transport, err := createMCPTransport(mcpCtx, m, resolver)
|
|
|
- if err != nil {
|
|
|
- updateMCPState(name, MCPStateError, err, nil, 0)
|
|
|
- slog.Error("error creating mcp client", "error", err, "name", name)
|
|
|
- cancel()
|
|
|
- cancelTimer.Stop()
|
|
|
- return nil, err
|
|
|
- }
|
|
|
-
|
|
|
- client := mcp.NewClient(
|
|
|
- &mcp.Implementation{
|
|
|
- Name: "crush",
|
|
|
- Version: version.Version,
|
|
|
- Title: "Crush",
|
|
|
- },
|
|
|
- &mcp.ClientOptions{
|
|
|
- ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
|
|
|
- mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
|
|
|
- Type: MCPEventToolsListChanged,
|
|
|
- Name: name,
|
|
|
- })
|
|
|
- },
|
|
|
- KeepAlive: time.Minute * 10,
|
|
|
- },
|
|
|
- )
|
|
|
-
|
|
|
- session, err := client.Connect(mcpCtx, transport, nil)
|
|
|
- if err != nil {
|
|
|
- err = maybeStdioErr(err, transport)
|
|
|
- updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
|
|
|
- slog.Error("error starting mcp client", "error", err, "name", name)
|
|
|
- cancel()
|
|
|
- cancelTimer.Stop()
|
|
|
- return nil, err
|
|
|
- }
|
|
|
-
|
|
|
- cancelTimer.Stop()
|
|
|
- slog.Info("Initialized mcp client", "name", name)
|
|
|
- return session, nil
|
|
|
-}
|
|
|
-
|
|
|
-// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
|
|
|
-// to parse, and the cli will then close it, causing the EOF error.
|
|
|
-// so, if we got an EOF err, and the transport is STDIO, we try to exec it
|
|
|
-// again with a timeout and collect the output so we can add details to the
|
|
|
-// error.
|
|
|
-// this happens particularly when starting things with npx, e.g. if node can't
|
|
|
-// be found or some other error like that.
|
|
|
-func maybeStdioErr(err error, transport mcp.Transport) error {
|
|
|
- if !errors.Is(err, io.EOF) {
|
|
|
- return err
|
|
|
- }
|
|
|
- ct, ok := transport.(*mcp.CommandTransport)
|
|
|
- if !ok {
|
|
|
- return err
|
|
|
- }
|
|
|
- if err2 := stdioMCPCheck(ct.Command); err2 != nil {
|
|
|
- err = errors.Join(err, err2)
|
|
|
- }
|
|
|
- return err
|
|
|
-}
|
|
|
-
|
|
|
-func maybeTimeoutErr(err error, timeout time.Duration) error {
|
|
|
- if errors.Is(err, context.Canceled) {
|
|
|
- return fmt.Errorf("timed out after %s", timeout)
|
|
|
- }
|
|
|
- return err
|
|
|
-}
|
|
|
-
|
|
|
-func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
|
|
|
- switch m.Type {
|
|
|
- case config.MCPStdio:
|
|
|
- command, err := resolver.ResolveValue(m.Command)
|
|
|
- if err != nil {
|
|
|
- return nil, fmt.Errorf("invalid mcp command: %w", err)
|
|
|
- }
|
|
|
- if strings.TrimSpace(command) == "" {
|
|
|
- return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
|
|
|
- }
|
|
|
- cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
|
|
|
- cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
|
|
|
- return &mcp.CommandTransport{
|
|
|
- Command: cmd,
|
|
|
- }, nil
|
|
|
- case config.MCPHttp:
|
|
|
- if strings.TrimSpace(m.URL) == "" {
|
|
|
- return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
|
|
|
- }
|
|
|
- client := &http.Client{
|
|
|
- Transport: &headerRoundTripper{
|
|
|
- headers: m.ResolvedHeaders(),
|
|
|
- },
|
|
|
- }
|
|
|
- return &mcp.StreamableClientTransport{
|
|
|
- Endpoint: m.URL,
|
|
|
- HTTPClient: client,
|
|
|
- }, nil
|
|
|
- case config.MCPSSE:
|
|
|
- if strings.TrimSpace(m.URL) == "" {
|
|
|
- return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
|
|
|
- }
|
|
|
- client := &http.Client{
|
|
|
- Transport: &headerRoundTripper{
|
|
|
- headers: m.ResolvedHeaders(),
|
|
|
- },
|
|
|
- }
|
|
|
- return &mcp.SSEClientTransport{
|
|
|
- Endpoint: m.URL,
|
|
|
- HTTPClient: client,
|
|
|
- }, nil
|
|
|
- default:
|
|
|
- return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-type headerRoundTripper struct {
|
|
|
- headers map[string]string
|
|
|
-}
|
|
|
-
|
|
|
-func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
- for k, v := range rt.headers {
|
|
|
- req.Header.Set(k, v)
|
|
|
- }
|
|
|
- return http.DefaultTransport.RoundTrip(req)
|
|
|
-}
|
|
|
-
|
|
|
-func mcpTimeout(m config.MCPConfig) time.Duration {
|
|
|
- return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
|
|
|
-}
|
|
|
-
|
|
|
-func stdioMCPCheck(old *exec.Cmd) error {
|
|
|
- ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
|
- defer cancel()
|
|
|
- cmd := exec.CommandContext(ctx, old.Path, old.Args...)
|
|
|
- cmd.Env = old.Env
|
|
|
- out, err := cmd.CombinedOutput()
|
|
|
- if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
|
|
- return nil
|
|
|
+ return fantasy.NewTextErrorResponse(err.Error()), nil
|
|
|
}
|
|
|
- return fmt.Errorf("%w: %s", err, string(out))
|
|
|
+ return fantasy.NewTextResponse(content), nil
|
|
|
}
|