tools.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package mcp
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "iter"
  7. "log/slog"
  8. "strings"
  9. "github.com/charmbracelet/crush/internal/csync"
  10. "github.com/modelcontextprotocol/go-sdk/mcp"
  11. )
  12. type Tool = mcp.Tool
  13. var allTools = csync.NewMap[string, []*Tool]()
  14. // Tools returns all available MCP tools.
  15. func Tools() iter.Seq2[string, []*Tool] {
  16. return allTools.Seq2()
  17. }
  18. // RunTool runs an MCP tool with the given input parameters.
  19. func RunTool(ctx context.Context, name, toolName string, input string) (string, error) {
  20. var args map[string]any
  21. if err := json.Unmarshal([]byte(input), &args); err != nil {
  22. return "", fmt.Errorf("error parsing parameters: %s", err)
  23. }
  24. c, err := getOrRenewClient(ctx, name)
  25. if err != nil {
  26. return "", err
  27. }
  28. result, err := c.CallTool(ctx, &mcp.CallToolParams{
  29. Name: toolName,
  30. Arguments: args,
  31. })
  32. if err != nil {
  33. return "", err
  34. }
  35. output := make([]string, 0, len(result.Content))
  36. for _, v := range result.Content {
  37. if vv, ok := v.(*mcp.TextContent); ok {
  38. output = append(output, vv.Text)
  39. } else {
  40. output = append(output, fmt.Sprintf("%v", v))
  41. }
  42. }
  43. return strings.Join(output, "\n"), nil
  44. }
  45. // RefreshTools gets the updated list of tools from the MCP and updates the
  46. // global state.
  47. func RefreshTools(ctx context.Context, name string) {
  48. session, ok := sessions.Get(name)
  49. if !ok {
  50. slog.Warn("refresh tools: no session", "name", name)
  51. return
  52. }
  53. tools, err := getTools(ctx, session)
  54. if err != nil {
  55. updateState(name, StateError, err, nil, Counts{})
  56. return
  57. }
  58. updateTools(name, tools)
  59. prev, _ := states.Get(name)
  60. prev.Counts.Tools = len(tools)
  61. updateState(name, StateConnected, nil, session, prev.Counts)
  62. }
  63. func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
  64. if session.InitializeResult().Capabilities.Tools == nil {
  65. return nil, nil
  66. }
  67. result, err := session.ListTools(ctx, &mcp.ListToolsParams{})
  68. if err != nil {
  69. return nil, err
  70. }
  71. return result.Tools, nil
  72. }
  73. func updateTools(name string, tools []*Tool) {
  74. if len(tools) == 0 {
  75. allTools.Del(name)
  76. return
  77. }
  78. allTools.Set(name, tools)
  79. }