loader.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. package commands
import (
	"cmp"
	"context"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"regexp"
	"slices"
	"strings"

	tea "github.com/charmbracelet/bubbletea/v2"
	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/home"
	"github.com/charmbracelet/crush/internal/tui/components/chat"
	"github.com/charmbracelet/crush/internal/tui/util"
)
  18. const (
  19. userCommandPrefix = "user:"
  20. projectCommandPrefix = "project:"
  21. )
  22. var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`)
  23. type commandLoader struct {
  24. sources []commandSource
  25. }
  26. type commandSource struct {
  27. path string
  28. prefix string
  29. }
  30. func LoadCustomCommands() ([]Command, error) {
  31. cfg := config.Get()
  32. if cfg == nil {
  33. return nil, fmt.Errorf("config not loaded")
  34. }
  35. loader := &commandLoader{
  36. sources: buildCommandSources(cfg),
  37. }
  38. return loader.loadAll()
  39. }
  40. func buildCommandSources(cfg *config.Config) []commandSource {
  41. var sources []commandSource
  42. // XDG config directory
  43. if dir := getXDGCommandsDir(); dir != "" {
  44. sources = append(sources, commandSource{
  45. path: dir,
  46. prefix: userCommandPrefix,
  47. })
  48. }
  49. // Home directory
  50. if home := home.Dir(); home != "" {
  51. sources = append(sources, commandSource{
  52. path: filepath.Join(home, ".crush", "commands"),
  53. prefix: userCommandPrefix,
  54. })
  55. }
  56. // Project directory
  57. sources = append(sources, commandSource{
  58. path: filepath.Join(cfg.Options.DataDirectory, "commands"),
  59. prefix: projectCommandPrefix,
  60. })
  61. return sources
  62. }
  63. func getXDGCommandsDir() string {
  64. xdgHome := os.Getenv("XDG_CONFIG_HOME")
  65. if xdgHome == "" {
  66. if home := home.Dir(); home != "" {
  67. xdgHome = filepath.Join(home, ".config")
  68. }
  69. }
  70. if xdgHome != "" {
  71. return filepath.Join(xdgHome, "crush", "commands")
  72. }
  73. return ""
  74. }
  75. func (l *commandLoader) loadAll() ([]Command, error) {
  76. var commands []Command
  77. for _, source := range l.sources {
  78. if cmds, err := l.loadFromSource(source); err == nil {
  79. commands = append(commands, cmds...)
  80. }
  81. }
  82. return commands, nil
  83. }
  84. func (l *commandLoader) loadFromSource(source commandSource) ([]Command, error) {
  85. if err := ensureDir(source.path); err != nil {
  86. return nil, err
  87. }
  88. var commands []Command
  89. err := filepath.WalkDir(source.path, func(path string, d fs.DirEntry, err error) error {
  90. if err != nil || d.IsDir() || !isMarkdownFile(d.Name()) {
  91. return err
  92. }
  93. cmd, err := l.loadCommand(path, source.path, source.prefix)
  94. if err != nil {
  95. return nil // Skip invalid files
  96. }
  97. commands = append(commands, cmd)
  98. return nil
  99. })
  100. return commands, err
  101. }
  102. func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, error) {
  103. content, err := os.ReadFile(path)
  104. if err != nil {
  105. return Command{}, err
  106. }
  107. id := buildCommandID(path, baseDir, prefix)
  108. desc := fmt.Sprintf("Custom command from %s", filepath.Base(path))
  109. return Command{
  110. ID: id,
  111. Title: id,
  112. Description: desc,
  113. Handler: createCommandHandler(id, desc, string(content)),
  114. }, nil
  115. }
  116. func buildCommandID(path, baseDir, prefix string) string {
  117. relPath, _ := filepath.Rel(baseDir, path)
  118. parts := strings.Split(relPath, string(filepath.Separator))
  119. // Remove .md extension from last part
  120. if len(parts) > 0 {
  121. lastIdx := len(parts) - 1
  122. parts[lastIdx] = strings.TrimSuffix(parts[lastIdx], filepath.Ext(parts[lastIdx]))
  123. }
  124. return prefix + strings.Join(parts, ":")
  125. }
  126. func createCommandHandler(id, desc, content string) func(Command) tea.Cmd {
  127. return func(cmd Command) tea.Cmd {
  128. args := extractArgNames(content)
  129. if len(args) == 0 {
  130. return util.CmdHandler(CommandRunCustomMsg{
  131. Content: content,
  132. })
  133. }
  134. return util.CmdHandler(ShowArgumentsDialogMsg{
  135. CommandID: id,
  136. Description: desc,
  137. ArgNames: args,
  138. OnSubmit: func(args map[string]string) tea.Cmd {
  139. return execUserPrompt(content, args)
  140. },
  141. })
  142. }
  143. }
  144. func execUserPrompt(content string, args map[string]string) tea.Cmd {
  145. return func() tea.Msg {
  146. for name, value := range args {
  147. placeholder := "$" + name
  148. content = strings.ReplaceAll(content, placeholder, value)
  149. }
  150. return CommandRunCustomMsg{
  151. Content: content,
  152. }
  153. }
  154. }
  155. func extractArgNames(content string) []string {
  156. matches := namedArgPattern.FindAllStringSubmatch(content, -1)
  157. if len(matches) == 0 {
  158. return nil
  159. }
  160. seen := make(map[string]bool)
  161. var args []string
  162. for _, match := range matches {
  163. arg := match[1]
  164. if !seen[arg] {
  165. seen[arg] = true
  166. args = append(args, arg)
  167. }
  168. }
  169. return args
  170. }
  171. func ensureDir(path string) error {
  172. if _, err := os.Stat(path); os.IsNotExist(err) {
  173. return os.MkdirAll(path, 0o755)
  174. }
  175. return nil
  176. }
  177. func isMarkdownFile(name string) bool {
  178. return strings.HasSuffix(strings.ToLower(name), ".md")
  179. }
  180. type CommandRunCustomMsg struct {
  181. Content string
  182. }
  183. func loadMCPPrompts() []Command {
  184. var commands []Command
  185. for mcpName, prompts := range mcp.Prompts() {
  186. for _, prompt := range prompts {
  187. key := mcpName + ":" + prompt.Name
  188. commands = append(commands, Command{
  189. ID: key,
  190. Title: cmp.Or(prompt.Title, prompt.Name),
  191. Description: prompt.Description,
  192. Handler: createMCPPromptHandler(mcpName, prompt.Name, prompt),
  193. })
  194. }
  195. }
  196. return commands
  197. }
  198. func createMCPPromptHandler(mcpName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd {
  199. return func(cmd Command) tea.Cmd {
  200. if len(prompt.Arguments) == 0 {
  201. return execMCPPrompt(mcpName, promptName, nil)
  202. }
  203. return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{
  204. Prompt: prompt,
  205. OnSubmit: func(args map[string]string) tea.Cmd {
  206. return execMCPPrompt(mcpName, promptName, args)
  207. },
  208. })
  209. }
  210. }
  211. func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd {
  212. return func() tea.Msg {
  213. ctx := context.Background()
  214. result, err := mcp.GetPromptMessages(ctx, clientName, promptName, args)
  215. if err != nil {
  216. return util.ReportError(err)
  217. }
  218. return chat.SendMsg{
  219. Text: strings.Join(result, " "),
  220. }
  221. }
  222. }
  223. type ShowMCPPromptArgumentsDialogMsg struct {
  224. Prompt *mcp.Prompt
  225. OnSubmit func(arg map[string]string) tea.Cmd
  226. }