shell.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. // Package shell provides cross-platform shell execution capabilities.
  2. //
  3. // This package offers two main types:
  4. // - Shell: A general-purpose shell executor for one-off or managed commands
  5. // - PersistentShell: A singleton shell that maintains state across the application
  6. //
  7. // WINDOWS COMPATIBILITY:
// This implementation provides POSIX shell emulation (via mvdan.cc/sh/v3) on
// every platform, including Windows. One caveat: commands must use forward
// slashes (/) as path separators, even on Windows.
  11. package shell
  12. import (
  13. "bytes"
  14. "context"
  15. "errors"
  16. "fmt"
  17. "os"
  18. "strings"
  19. "sync"
  20. "mvdan.cc/sh/v3/expand"
  21. "mvdan.cc/sh/v3/interp"
  22. "mvdan.cc/sh/v3/syntax"
  23. )
  24. // ShellType represents the type of shell to use
  25. type ShellType int
  26. const (
  27. ShellTypePOSIX ShellType = iota
  28. ShellTypeCmd
  29. ShellTypePowerShell
  30. )
  31. // Logger interface for optional logging
  32. type Logger interface {
  33. InfoPersist(msg string, keysAndValues ...any)
  34. }
  35. // noopLogger is a logger that does nothing
  36. type noopLogger struct{}
  37. func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
  38. // BlockFunc is a function that determines if a command should be blocked
  39. type BlockFunc func(args []string) bool
  40. // Shell provides cross-platform shell execution with optional state persistence
  41. type Shell struct {
  42. env []string
  43. cwd string
  44. mu sync.Mutex
  45. logger Logger
  46. blockFuncs []BlockFunc
  47. }
  48. // Options for creating a new shell
  49. type Options struct {
  50. WorkingDir string
  51. Env []string
  52. Logger Logger
  53. BlockFuncs []BlockFunc
  54. }
  55. // NewShell creates a new shell instance with the given options
  56. func NewShell(opts *Options) *Shell {
  57. if opts == nil {
  58. opts = &Options{}
  59. }
  60. cwd := opts.WorkingDir
  61. if cwd == "" {
  62. cwd, _ = os.Getwd()
  63. }
  64. env := opts.Env
  65. if env == nil {
  66. env = os.Environ()
  67. }
  68. logger := opts.Logger
  69. if logger == nil {
  70. logger = noopLogger{}
  71. }
  72. return &Shell{
  73. cwd: cwd,
  74. env: env,
  75. logger: logger,
  76. blockFuncs: opts.BlockFuncs,
  77. }
  78. }
  79. // Exec executes a command in the shell
  80. func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
  81. s.mu.Lock()
  82. defer s.mu.Unlock()
  83. return s.execPOSIX(ctx, command)
  84. }
  85. // GetWorkingDir returns the current working directory
  86. func (s *Shell) GetWorkingDir() string {
  87. s.mu.Lock()
  88. defer s.mu.Unlock()
  89. return s.cwd
  90. }
  91. // SetWorkingDir sets the working directory
  92. func (s *Shell) SetWorkingDir(dir string) error {
  93. s.mu.Lock()
  94. defer s.mu.Unlock()
  95. // Verify the directory exists
  96. if _, err := os.Stat(dir); err != nil {
  97. return fmt.Errorf("directory does not exist: %w", err)
  98. }
  99. s.cwd = dir
  100. return nil
  101. }
  102. // GetEnv returns a copy of the environment variables
  103. func (s *Shell) GetEnv() []string {
  104. s.mu.Lock()
  105. defer s.mu.Unlock()
  106. env := make([]string, len(s.env))
  107. copy(env, s.env)
  108. return env
  109. }
  110. // SetEnv sets an environment variable
  111. func (s *Shell) SetEnv(key, value string) {
  112. s.mu.Lock()
  113. defer s.mu.Unlock()
  114. // Update or add the environment variable
  115. keyPrefix := key + "="
  116. for i, env := range s.env {
  117. if strings.HasPrefix(env, keyPrefix) {
  118. s.env[i] = keyPrefix + value
  119. return
  120. }
  121. }
  122. s.env = append(s.env, keyPrefix+value)
  123. }
  124. // SetBlockFuncs sets the command block functions for the shell
  125. func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
  126. s.mu.Lock()
  127. defer s.mu.Unlock()
  128. s.blockFuncs = blockFuncs
  129. }
  130. // CommandsBlocker creates a BlockFunc that blocks exact command matches
  131. func CommandsBlocker(bannedCommands []string) BlockFunc {
  132. bannedSet := make(map[string]bool)
  133. for _, cmd := range bannedCommands {
  134. bannedSet[cmd] = true
  135. }
  136. return func(args []string) bool {
  137. if len(args) == 0 {
  138. return false
  139. }
  140. return bannedSet[args[0]]
  141. }
  142. }
  143. // ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
  144. func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
  145. return func(args []string) bool {
  146. for _, blocked := range blockedSubCommands {
  147. if len(args) >= len(blocked) {
  148. match := true
  149. for i, part := range blocked {
  150. if args[i] != part {
  151. match = false
  152. break
  153. }
  154. }
  155. if match {
  156. return true
  157. }
  158. }
  159. }
  160. return false
  161. }
  162. }
  163. func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
  164. return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
  165. return func(ctx context.Context, args []string) error {
  166. if len(args) == 0 {
  167. return next(ctx, args)
  168. }
  169. for _, blockFunc := range s.blockFuncs {
  170. if blockFunc(args) {
  171. return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
  172. }
  173. }
  174. return next(ctx, args)
  175. }
  176. }
  177. }
  178. // execPOSIX executes commands using POSIX shell emulation (cross-platform)
  179. func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
  180. line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
  181. if err != nil {
  182. return "", "", fmt.Errorf("could not parse command: %w", err)
  183. }
  184. var stdout, stderr bytes.Buffer
  185. runner, err := interp.New(
  186. interp.StdIO(nil, &stdout, &stderr),
  187. interp.Interactive(false),
  188. interp.Env(expand.ListEnviron(s.env...)),
  189. interp.Dir(s.cwd),
  190. interp.ExecHandlers(s.blockHandler(), s.coreUtilsHandler()),
  191. )
  192. if err != nil {
  193. return "", "", fmt.Errorf("could not run command: %w", err)
  194. }
  195. err = runner.Run(ctx, line)
  196. s.cwd = runner.Dir
  197. s.env = []string{}
  198. for name, vr := range runner.Vars {
  199. s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
  200. }
  201. s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
  202. return stdout.String(), stderr.String(), err
  203. }
  204. // IsInterrupt checks if an error is due to interruption
  205. func IsInterrupt(err error) bool {
  206. return errors.Is(err, context.Canceled) ||
  207. errors.Is(err, context.DeadlineExceeded)
  208. }
  209. // ExitCode extracts the exit code from an error
  210. func ExitCode(err error) int {
  211. if err == nil {
  212. return 0
  213. }
  214. var exitErr interp.ExitStatus
  215. if errors.As(err, &exitErr) {
  216. return int(exitErr)
  217. }
  218. return 1
  219. }