shell.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. // Package shell provides cross-platform shell execution capabilities.
  2. //
  3. // This package provides Shell instances, each with its own working directory and
  4. // environment. Instances are independent; one instance keeps state across commands.
  5. //
  6. // WINDOWS COMPATIBILITY:
  7. // This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
  8. // Windows. Commands should use forward slashes (/) as path separators to work
  9. // correctly on all platforms.
  10. package shell
  11. import (
  12. "bytes"
  13. "context"
  14. "errors"
  15. "fmt"
  16. "io"
  17. "os"
  18. "slices"
  19. "strings"
  20. "sync"
  21. "github.com/charmbracelet/x/exp/slice"
  22. "mvdan.cc/sh/moreinterp/coreutils"
  23. "mvdan.cc/sh/v3/expand"
  24. "mvdan.cc/sh/v3/interp"
  25. "mvdan.cc/sh/v3/syntax"
  26. )
  27. // ShellType represents the type of shell to use
  28. type ShellType int
  29. const (
  30. ShellTypePOSIX ShellType = iota
  31. ShellTypeCmd
  32. ShellTypePowerShell
  33. )
  34. // Logger interface for optional logging
  35. type Logger interface {
  36. InfoPersist(msg string, keysAndValues ...any)
  37. }
  38. // noopLogger is a logger that does nothing
  39. type noopLogger struct{}
  40. func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
  41. // BlockFunc is a function that determines if a command should be blocked
  42. type BlockFunc func(args []string) bool
  43. // Shell provides cross-platform shell execution with optional state persistence
  44. type Shell struct {
  45. env []string
  46. cwd string
  47. mu sync.Mutex
  48. logger Logger
  49. blockFuncs []BlockFunc
  50. }
  51. // Options for creating a new shell
  52. type Options struct {
  53. WorkingDir string
  54. Env []string
  55. Logger Logger
  56. BlockFuncs []BlockFunc
  57. }
  58. // NewShell creates a new shell instance with the given options
  59. func NewShell(opts *Options) *Shell {
  60. if opts == nil {
  61. opts = &Options{}
  62. }
  63. cwd := opts.WorkingDir
  64. if cwd == "" {
  65. cwd, _ = os.Getwd()
  66. }
  67. env := opts.Env
  68. if env == nil {
  69. env = os.Environ()
  70. }
  71. // Allow tools to detect execution by Crush.
  72. env = append(
  73. env,
  74. "CRUSH=1",
  75. "AGENT=crush",
  76. "AI_AGENT=crush",
  77. )
  78. logger := opts.Logger
  79. if logger == nil {
  80. logger = noopLogger{}
  81. }
  82. return &Shell{
  83. cwd: cwd,
  84. env: env,
  85. logger: logger,
  86. blockFuncs: opts.BlockFuncs,
  87. }
  88. }
  89. // Exec executes a command in the shell
  90. func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
  91. s.mu.Lock()
  92. defer s.mu.Unlock()
  93. return s.exec(ctx, command)
  94. }
  95. // ExecStream executes a command in the shell with streaming output to provided writers
  96. func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
  97. s.mu.Lock()
  98. defer s.mu.Unlock()
  99. return s.execStream(ctx, command, stdout, stderr)
  100. }
  101. // GetWorkingDir returns the current working directory
  102. func (s *Shell) GetWorkingDir() string {
  103. s.mu.Lock()
  104. defer s.mu.Unlock()
  105. return s.cwd
  106. }
  107. // SetWorkingDir sets the working directory
  108. func (s *Shell) SetWorkingDir(dir string) error {
  109. s.mu.Lock()
  110. defer s.mu.Unlock()
  111. // Verify the directory exists
  112. if _, err := os.Stat(dir); err != nil {
  113. return fmt.Errorf("directory does not exist: %w", err)
  114. }
  115. s.cwd = dir
  116. return nil
  117. }
  118. // GetEnv returns a copy of the environment variables
  119. func (s *Shell) GetEnv() []string {
  120. s.mu.Lock()
  121. defer s.mu.Unlock()
  122. env := make([]string, len(s.env))
  123. copy(env, s.env)
  124. return env
  125. }
  126. // SetEnv sets an environment variable
  127. func (s *Shell) SetEnv(key, value string) {
  128. s.mu.Lock()
  129. defer s.mu.Unlock()
  130. // Update or add the environment variable
  131. keyPrefix := key + "="
  132. for i, env := range s.env {
  133. if strings.HasPrefix(env, keyPrefix) {
  134. s.env[i] = keyPrefix + value
  135. return
  136. }
  137. }
  138. s.env = append(s.env, keyPrefix+value)
  139. }
  140. // SetBlockFuncs sets the command block functions for the shell
  141. func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
  142. s.mu.Lock()
  143. defer s.mu.Unlock()
  144. s.blockFuncs = blockFuncs
  145. }
  146. // CommandsBlocker creates a BlockFunc that blocks exact command matches
  147. func CommandsBlocker(cmds []string) BlockFunc {
  148. bannedSet := make(map[string]struct{})
  149. for _, cmd := range cmds {
  150. bannedSet[cmd] = struct{}{}
  151. }
  152. return func(args []string) bool {
  153. if len(args) == 0 {
  154. return false
  155. }
  156. _, ok := bannedSet[args[0]]
  157. return ok
  158. }
  159. }
  160. // ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
  161. func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
  162. return func(parts []string) bool {
  163. if len(parts) == 0 || parts[0] != cmd {
  164. return false
  165. }
  166. argParts, flagParts := splitArgsFlags(parts[1:])
  167. if len(argParts) < len(args) || len(flagParts) < len(flags) {
  168. return false
  169. }
  170. argsMatch := slices.Equal(argParts[:len(args)], args)
  171. flagsMatch := slice.IsSubset(flags, flagParts)
  172. return argsMatch && flagsMatch
  173. }
  174. }
  175. func splitArgsFlags(parts []string) (args []string, flags []string) {
  176. args = make([]string, 0, len(parts))
  177. flags = make([]string, 0, len(parts))
  178. for _, part := range parts {
  179. if strings.HasPrefix(part, "-") {
  180. // Extract flag name before '=' if present
  181. flag := part
  182. if before, _, ok := strings.Cut(part, "="); ok {
  183. flag = before
  184. }
  185. flags = append(flags, flag)
  186. } else {
  187. args = append(args, part)
  188. }
  189. }
  190. return args, flags
  191. }
  192. func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
  193. return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
  194. return func(ctx context.Context, args []string) error {
  195. if len(args) == 0 {
  196. return next(ctx, args)
  197. }
  198. for _, blockFunc := range s.blockFuncs {
  199. if blockFunc(args) {
  200. return fmt.Errorf("command is not allowed for security reasons: %q", args[0])
  201. }
  202. }
  203. return next(ctx, args)
  204. }
  205. }
  206. }
  207. // newInterp creates a new interpreter with the current shell state
  208. func (s *Shell) newInterp(stdout, stderr io.Writer) (*interp.Runner, error) {
  209. return interp.New(
  210. interp.StdIO(nil, stdout, stderr),
  211. interp.Interactive(false),
  212. interp.Env(expand.ListEnviron(s.env...)),
  213. interp.Dir(s.cwd),
  214. interp.ExecHandlers(s.execHandlers()...),
  215. )
  216. }
  217. // updateShellFromRunner updates the shell from the interpreter after execution.
  218. func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
  219. s.cwd = runner.Dir
  220. s.env = s.env[:0]
  221. for name, vr := range runner.Vars {
  222. if vr.Exported {
  223. s.env = append(s.env, name+"="+vr.Str)
  224. }
  225. }
  226. }
  227. // execCommon is the shared implementation for executing commands
  228. func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr io.Writer) (err error) {
  229. var runner *interp.Runner
  230. defer func() {
  231. if r := recover(); r != nil {
  232. err = fmt.Errorf("command execution panic: %v", r)
  233. }
  234. if runner != nil {
  235. s.updateShellFromRunner(runner)
  236. }
  237. s.logger.InfoPersist("command finished", "command", command, "err", err)
  238. }()
  239. line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
  240. if err != nil {
  241. return fmt.Errorf("could not parse command: %w", err)
  242. }
  243. runner, err = s.newInterp(stdout, stderr)
  244. if err != nil {
  245. return fmt.Errorf("could not run command: %w", err)
  246. }
  247. err = runner.Run(ctx, line)
  248. return err
  249. }
  250. // exec executes commands using a cross-platform shell interpreter.
  251. func (s *Shell) exec(ctx context.Context, command string) (string, string, error) {
  252. var stdout, stderr bytes.Buffer
  253. err := s.execCommon(ctx, command, &stdout, &stderr)
  254. return stdout.String(), stderr.String(), err
  255. }
  256. // execStream executes commands using POSIX shell emulation with streaming output
  257. func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
  258. return s.execCommon(ctx, command, stdout, stderr)
  259. }
  260. func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
  261. handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
  262. s.blockHandler(),
  263. }
  264. if useGoCoreUtils {
  265. handlers = append(handlers, coreutils.ExecHandler)
  266. }
  267. return handlers
  268. }
  269. // IsInterrupt checks if an error is due to interruption
  270. func IsInterrupt(err error) bool {
  271. return errors.Is(err, context.Canceled) ||
  272. errors.Is(err, context.DeadlineExceeded)
  273. }
  274. // ExitCode extracts the exit code from an error
  275. func ExitCode(err error) int {
  276. if err == nil {
  277. return 0
  278. }
  279. var exitErr interp.ExitStatus
  280. if errors.As(err, &exitErr) {
  281. return int(exitErr)
  282. }
  283. return 1
  284. }