shell.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. // Package shell provides cross-platform shell execution capabilities.
  2. //
  3. // This package offers two main types:
  4. // - Shell: A general-purpose shell executor for one-off or managed commands
  5. // - PersistentShell: A singleton shell that maintains state across the application
  6. //
  7. // WINDOWS COMPATIBILITY:
  8. // This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and
  9. // native Windows shell support (cmd.exe/PowerShell) for optimal compatibility.
  10. package shell
  11. import (
  12. "bytes"
  13. "context"
  14. "errors"
  15. "fmt"
  16. "os"
  17. "os/exec"
  18. "runtime"
  19. "strings"
  20. "sync"
  21. "mvdan.cc/sh/v3/expand"
  22. "mvdan.cc/sh/v3/interp"
  23. "mvdan.cc/sh/v3/syntax"
  24. )
  25. // ShellType represents the type of shell to use
  26. type ShellType int
  27. const (
  28. ShellTypePOSIX ShellType = iota
  29. ShellTypeCmd
  30. ShellTypePowerShell
  31. )
  32. // Logger interface for optional logging
  33. type Logger interface {
  34. InfoPersist(msg string, keysAndValues ...interface{})
  35. }
  36. // noopLogger is a logger that does nothing
  37. type noopLogger struct{}
  38. func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
  39. // Shell provides cross-platform shell execution with optional state persistence
  40. type Shell struct {
  41. env []string
  42. cwd string
  43. mu sync.Mutex
  44. logger Logger
  45. }
  46. // Options for creating a new shell
  47. type Options struct {
  48. WorkingDir string
  49. Env []string
  50. Logger Logger
  51. }
  52. // NewShell creates a new shell instance with the given options
  53. func NewShell(opts *Options) *Shell {
  54. if opts == nil {
  55. opts = &Options{}
  56. }
  57. cwd := opts.WorkingDir
  58. if cwd == "" {
  59. cwd, _ = os.Getwd()
  60. }
  61. env := opts.Env
  62. if env == nil {
  63. env = os.Environ()
  64. }
  65. logger := opts.Logger
  66. if logger == nil {
  67. logger = noopLogger{}
  68. }
  69. return &Shell{
  70. cwd: cwd,
  71. env: env,
  72. logger: logger,
  73. }
  74. }
  75. // Exec executes a command in the shell
  76. func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
  77. s.mu.Lock()
  78. defer s.mu.Unlock()
  79. // Determine which shell to use based on platform and command
  80. shellType := s.determineShellType(command)
  81. switch shellType {
  82. case ShellTypeCmd:
  83. return s.execWindows(ctx, command, "cmd")
  84. case ShellTypePowerShell:
  85. return s.execWindows(ctx, command, "powershell")
  86. default:
  87. return s.execPOSIX(ctx, command)
  88. }
  89. }
  90. // GetWorkingDir returns the current working directory
  91. func (s *Shell) GetWorkingDir() string {
  92. s.mu.Lock()
  93. defer s.mu.Unlock()
  94. return s.cwd
  95. }
  96. // SetWorkingDir sets the working directory
  97. func (s *Shell) SetWorkingDir(dir string) error {
  98. s.mu.Lock()
  99. defer s.mu.Unlock()
  100. // Verify the directory exists
  101. if _, err := os.Stat(dir); err != nil {
  102. return fmt.Errorf("directory does not exist: %w", err)
  103. }
  104. s.cwd = dir
  105. return nil
  106. }
  107. // GetEnv returns a copy of the environment variables
  108. func (s *Shell) GetEnv() []string {
  109. s.mu.Lock()
  110. defer s.mu.Unlock()
  111. env := make([]string, len(s.env))
  112. copy(env, s.env)
  113. return env
  114. }
  115. // SetEnv sets an environment variable
  116. func (s *Shell) SetEnv(key, value string) {
  117. s.mu.Lock()
  118. defer s.mu.Unlock()
  119. // Update or add the environment variable
  120. keyPrefix := key + "="
  121. for i, env := range s.env {
  122. if strings.HasPrefix(env, keyPrefix) {
  123. s.env[i] = keyPrefix + value
  124. return
  125. }
  126. }
  127. s.env = append(s.env, keyPrefix+value)
  128. }
  129. // Windows-specific commands that should use native shell
  130. var windowsNativeCommands = map[string]bool{
  131. "dir": true,
  132. "type": true,
  133. "copy": true,
  134. "move": true,
  135. "del": true,
  136. "md": true,
  137. "mkdir": true,
  138. "rd": true,
  139. "rmdir": true,
  140. "cls": true,
  141. "where": true,
  142. "tasklist": true,
  143. "taskkill": true,
  144. "net": true,
  145. "sc": true,
  146. "reg": true,
  147. "wmic": true,
  148. }
  149. // determineShellType decides which shell to use based on platform and command
  150. func (s *Shell) determineShellType(command string) ShellType {
  151. if runtime.GOOS != "windows" {
  152. return ShellTypePOSIX
  153. }
  154. // Extract the first command from the command line
  155. parts := strings.Fields(command)
  156. if len(parts) == 0 {
  157. return ShellTypePOSIX
  158. }
  159. firstCmd := strings.ToLower(parts[0])
  160. // Check if it's a Windows-specific command
  161. if windowsNativeCommands[firstCmd] {
  162. return ShellTypeCmd
  163. }
  164. // Check for PowerShell-specific syntax
  165. if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") ||
  166. strings.Contains(command, "New-") || strings.Contains(command, "$_") ||
  167. strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") {
  168. return ShellTypePowerShell
  169. }
  170. // Default to POSIX emulation for cross-platform compatibility
  171. return ShellTypePOSIX
  172. }
  173. // execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
  174. func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
  175. var cmd *exec.Cmd
  176. // Handle directory changes specially to maintain persistent shell behavior
  177. if strings.HasPrefix(strings.TrimSpace(command), "cd ") {
  178. return s.handleWindowsCD(command)
  179. }
  180. switch shell {
  181. case "cmd":
  182. // Use cmd.exe for Windows commands
  183. // Add current directory context to maintain state
  184. fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command)
  185. cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand)
  186. case "powershell":
  187. // Use PowerShell for PowerShell commands
  188. // Add current directory context to maintain state
  189. fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command)
  190. cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand)
  191. default:
  192. return "", "", fmt.Errorf("unsupported Windows shell: %s", shell)
  193. }
  194. // Set environment variables
  195. cmd.Env = s.env
  196. var stdout, stderr bytes.Buffer
  197. cmd.Stdout = &stdout
  198. cmd.Stderr = &stderr
  199. err := cmd.Run()
  200. s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
  201. return stdout.String(), stderr.String(), err
  202. }
  203. // handleWindowsCD handles directory changes for Windows shells
  204. func (s *Shell) handleWindowsCD(command string) (string, string, error) {
  205. // Extract the target directory from the cd command
  206. parts := strings.Fields(command)
  207. if len(parts) < 2 {
  208. return "", "cd: missing directory argument", fmt.Errorf("missing directory argument")
  209. }
  210. targetDir := parts[1]
  211. // Handle relative paths
  212. if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") {
  213. // Relative path - resolve against current directory
  214. if targetDir == ".." {
  215. // Go up one directory
  216. if len(s.cwd) > 3 { // Don't go above drive root (C:\)
  217. lastSlash := strings.LastIndex(s.cwd, "\\")
  218. if lastSlash > 2 { // Keep drive letter
  219. s.cwd = s.cwd[:lastSlash]
  220. }
  221. }
  222. } else if targetDir != "." {
  223. // Go to subdirectory
  224. s.cwd = s.cwd + "\\" + targetDir
  225. }
  226. } else {
  227. // Absolute path
  228. s.cwd = targetDir
  229. }
  230. // Verify the directory exists
  231. if _, err := os.Stat(s.cwd); err != nil {
  232. return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err
  233. }
  234. return "", "", nil
  235. }
  236. // execPOSIX executes commands using POSIX shell emulation (cross-platform)
  237. func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
  238. line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
  239. if err != nil {
  240. return "", "", fmt.Errorf("could not parse command: %w", err)
  241. }
  242. var stdout, stderr bytes.Buffer
  243. runner, err := interp.New(
  244. interp.StdIO(nil, &stdout, &stderr),
  245. interp.Interactive(false),
  246. interp.Env(expand.ListEnviron(s.env...)),
  247. interp.Dir(s.cwd),
  248. )
  249. if err != nil {
  250. return "", "", fmt.Errorf("could not run command: %w", err)
  251. }
  252. err = runner.Run(ctx, line)
  253. s.cwd = runner.Dir
  254. s.env = []string{}
  255. for name, vr := range runner.Vars {
  256. s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
  257. }
  258. s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
  259. return stdout.String(), stderr.String(), err
  260. }
  261. // IsInterrupt checks if an error is due to interruption
  262. func IsInterrupt(err error) bool {
  263. return errors.Is(err, context.Canceled) ||
  264. errors.Is(err, context.DeadlineExceeded)
  265. }
  266. // ExitCode extracts the exit code from an error
  267. func ExitCode(err error) int {
  268. if err == nil {
  269. return 0
  270. }
  271. status, ok := interp.IsExitStatus(err)
  272. if ok {
  273. return int(status)
  274. }
  275. return 1
  276. }