commands.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. package commands
  2. import (
  3. "context"
  4. "io/fs"
  5. "os"
  6. "path/filepath"
  7. "regexp"
  8. "strings"
  9. "github.com/charmbracelet/crush/internal/agent/tools/mcp"
  10. "github.com/charmbracelet/crush/internal/config"
  11. "github.com/charmbracelet/crush/internal/home"
  12. )
  // namedArgPattern finds $NAME placeholders in a command body: a dollar sign,
  // one uppercase ASCII letter, then any run of uppercase letters, digits, or
  // underscores. Capture group 1 holds the name without the leading "$".
  var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)

  // userCommandPrefix is prepended to the IDs of commands found in the
  // per-user command directories (XDG config and home).
  const userCommandPrefix = "user:"

  // projectCommandPrefix is prepended to the IDs of commands found in the
  // project's data directory.
  const projectCommandPrefix = "project:"
  // Argument describes a single input that a command expects, together with
  // the metadata attached to it.
  type Argument struct {
  	// ID is the argument's name: the MCP argument name for MCP prompts, or
  	// the placeholder name (without "$") for file-based custom commands.
  	ID string
  	// Title is the label for the argument. MCP prompts fall back to the
  	// argument name when no title is supplied; custom commands use the
  	// placeholder name.
  	Title string
  	// Description is optional free-form text about the argument. It is only
  	// populated for MCP prompt arguments.
  	Description string
  	// Required reports whether a value must be supplied. It is always true
  	// for arguments extracted from custom command files.
  	Required bool
  }
  25. // MCPPrompt represents a custom command loaded from an MCP server.
  26. type MCPPrompt struct {
  27. ID string
  28. Title string
  29. Description string
  30. PromptID string
  31. ClientID string
  32. Arguments []Argument
  33. }
  34. // CustomCommand represents a user-defined custom command loaded from markdown files.
  35. type CustomCommand struct {
  36. ID string
  37. Name string
  38. Content string
  39. Arguments []Argument
  40. }
  // commandSource pairs a directory that is scanned for command files with the
  // prefix given to the IDs of the commands found there.
  type commandSource struct {
  	// path is the root directory that is walked for markdown files.
  	path string
  	// prefix is prepended to every command ID built from this source.
  	prefix string
  }
  45. // LoadCustomCommands loads custom commands from multiple sources including
  46. // XDG config directory, home directory, and project directory.
  47. func LoadCustomCommands(cfg *config.Config) ([]CustomCommand, error) {
  48. return loadAll(buildCommandSources(cfg))
  49. }
  50. // LoadMCPPrompts loads custom commands from available MCP servers.
  51. func LoadMCPPrompts() ([]MCPPrompt, error) {
  52. var commands []MCPPrompt
  53. for mcpName, prompts := range mcp.Prompts() {
  54. for _, prompt := range prompts {
  55. key := mcpName + ":" + prompt.Name
  56. var args []Argument
  57. for _, arg := range prompt.Arguments {
  58. title := arg.Title
  59. if title == "" {
  60. title = arg.Name
  61. }
  62. args = append(args, Argument{
  63. ID: arg.Name,
  64. Title: title,
  65. Description: arg.Description,
  66. Required: arg.Required,
  67. })
  68. }
  69. commands = append(commands, MCPPrompt{
  70. ID: key,
  71. Title: prompt.Title,
  72. Description: prompt.Description,
  73. PromptID: prompt.Name,
  74. ClientID: mcpName,
  75. Arguments: args,
  76. })
  77. }
  78. }
  79. return commands, nil
  80. }
  81. func buildCommandSources(cfg *config.Config) []commandSource {
  82. var sources []commandSource
  83. // XDG config directory
  84. if dir := getXDGCommandsDir(); dir != "" {
  85. sources = append(sources, commandSource{
  86. path: dir,
  87. prefix: userCommandPrefix,
  88. })
  89. }
  90. // Home directory
  91. if home := home.Dir(); home != "" {
  92. sources = append(sources, commandSource{
  93. path: filepath.Join(home, ".crush", "commands"),
  94. prefix: userCommandPrefix,
  95. })
  96. }
  97. // Project directory
  98. sources = append(sources, commandSource{
  99. path: filepath.Join(cfg.Options.DataDirectory, "commands"),
  100. prefix: projectCommandPrefix,
  101. })
  102. return sources
  103. }
  104. func loadAll(sources []commandSource) ([]CustomCommand, error) {
  105. var commands []CustomCommand
  106. for _, source := range sources {
  107. if cmds, err := loadFromSource(source); err == nil {
  108. commands = append(commands, cmds...)
  109. }
  110. }
  111. return commands, nil
  112. }
  113. func loadFromSource(source commandSource) ([]CustomCommand, error) {
  114. if err := ensureDir(source.path); err != nil {
  115. return nil, err
  116. }
  117. var commands []CustomCommand
  118. err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
  119. if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
  120. return err
  121. }
  122. cmd, err := loadCommand(path, source.path, source.prefix)
  123. if err != nil {
  124. return nil // Skip invalid files
  125. }
  126. commands = append(commands, cmd)
  127. return nil
  128. })
  129. return commands, err
  130. }
  131. func loadCommand(path, baseDir, prefix string) (CustomCommand, error) {
  132. content, err := os.ReadFile(path)
  133. if err != nil {
  134. return CustomCommand{}, err
  135. }
  136. id := buildCommandID(path, baseDir, prefix)
  137. return CustomCommand{
  138. ID: id,
  139. Name: id,
  140. Content: string(content),
  141. Arguments: extractArgNames(string(content)),
  142. }, nil
  143. }
  144. func extractArgNames(content string) []Argument {
  145. matches := namedArgPattern.FindAllStringSubmatch(content, -1)
  146. if len(matches) == 0 {
  147. return nil
  148. }
  149. seen := make(map[string]bool)
  150. var args []Argument
  151. for _, match := range matches {
  152. arg := match[1]
  153. if !seen[arg] {
  154. seen[arg] = true
  155. // for normal custom commands, all args are required
  156. args = append(args, Argument{ID: arg, Title: arg, Required: true})
  157. }
  158. }
  159. return args
  160. }
  // buildCommandID derives a command ID from a file path. The path is taken
  // relative to baseDir, the extension of the final element is removed, the
  // path separators are replaced by ":", and prefix is prepended. For example
  // "<baseDir>/git/commit.md" with prefix "user:" yields "user:git:commit".
  //
  // If path cannot be expressed relative to baseDir, the ID is built from the
  // file's base name alone rather than collapsing to the bare prefix.
  func buildCommandID(path, baseDir, prefix string) string {
  	relPath, err := filepath.Rel(baseDir, path)
  	if err != nil {
  		// Keep the ID meaningful even without a usable relative path.
  		relPath = filepath.Base(path)
  	}

  	// Split with a non-empty separator always yields at least one element.
  	parts := strings.Split(relPath, string(filepath.Separator))
  	last := len(parts) - 1
  	parts[last] = strings.TrimSuffix(parts[last], filepath.Ext(parts[last]))

  	return prefix + strings.Join(parts, ":")
  }
  171. func getXDGCommandsDir() string {
  172. xdgHome := os.Getenv("XDG_CONFIG_HOME")
  173. if xdgHome == "" {
  174. if home := home.Dir(); home != "" {
  175. xdgHome = filepath.Join(home, ".config")
  176. }
  177. }
  178. if xdgHome != "" {
  179. return filepath.Join(xdgHome, "crush", "commands")
  180. }
  181. return ""
  182. }
  // ensureDir makes sure path exists as a directory, creating it and any
  // missing parents with mode 0o755. It returns nil when path is already a
  // directory, and an error when the directory cannot be created or when path
  // exists but is not a directory.
  func ensureDir(path string) error {
  	// MkdirAll is a no-op for an existing directory, so no separate existence
  	// check is needed. A check-then-create would also hide stat errors and
  	// report success when path is a regular file.
  	return os.MkdirAll(path, 0o755)
  }
  // isMarkdownFile reports whether name ends in ".md", ignoring letter case.
  func isMarkdownFile(name string) bool {
  	const markdownExt = ".md"
  	lowered := strings.ToLower(name)
  	return strings.HasSuffix(lowered, markdownExt)
  }
  192. func GetMCPPrompt(cfg *config.ConfigStore, clientID, promptID string, args map[string]string) (string, error) {
  193. // TODO: we should pass the context down
  194. result, err := mcp.GetPromptMessages(context.Background(), cfg, clientID, promptID, args)
  195. if err != nil {
  196. return "", err
  197. }
  198. return strings.Join(result, " "), nil
  199. }