shell.go 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package shell
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "os"
  8. "strings"
  9. "sync"
  10. "github.com/charmbracelet/crush/internal/logging"
  11. "mvdan.cc/sh/v3/expand"
  12. "mvdan.cc/sh/v3/interp"
  13. "mvdan.cc/sh/v3/syntax"
  14. )
  // PersistentShell carries shell state from one command to the next: the
  // environment and the working directory that the next command starts with.
  type PersistentShell struct {
  	// env holds the environment as "NAME=value" entries.
  	env []string
  	// cwd is the directory the next command runs in.
  	cwd string
  	// mu is locked by Exec for the duration of each command.
  	mu sync.Mutex
  }
  20. var (
  21. once sync.Once
  22. shellInstance *PersistentShell
  23. )
  24. func GetPersistentShell(cwd string) *PersistentShell {
  25. once.Do(func() {
  26. shellInstance = newPersistentShell(cwd)
  27. })
  28. return shellInstance
  29. }
  30. func newPersistentShell(cwd string) *PersistentShell {
  31. return &PersistentShell{
  32. cwd: cwd,
  33. env: os.Environ(),
  34. }
  35. }
  36. func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
  37. s.mu.Lock()
  38. defer s.mu.Unlock()
  39. line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
  40. if err != nil {
  41. return "", "", fmt.Errorf("could not parse command: %w", err)
  42. }
  43. var stdout, stderr bytes.Buffer
  44. runner, err := interp.New(
  45. interp.StdIO(nil, &stdout, &stderr),
  46. interp.Interactive(false),
  47. interp.Env(expand.ListEnviron(s.env...)),
  48. interp.Dir(s.cwd),
  49. )
  50. if err != nil {
  51. return "", "", fmt.Errorf("could not run command: %w", err)
  52. }
  53. err = runner.Run(ctx, line)
  54. s.cwd = runner.Dir
  55. s.env = []string{}
  56. for name, vr := range runner.Vars {
  57. s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
  58. }
  59. logging.InfoPersist("Command finished", "command", command, "err", err)
  60. return stdout.String(), stderr.String(), err
  61. }
  // IsInterrupt reports whether err is, or wraps, a context cancellation or a
  // context deadline expiry.
  func IsInterrupt(err error) bool {
  	switch {
  	case errors.Is(err, context.Canceled):
  		return true
  	case errors.Is(err, context.DeadlineExceeded):
  		return true
  	default:
  		return false
  	}
  }
  66. func ExitCode(err error) int {
  67. if err == nil {
  68. return 0
  69. }
  70. status, ok := interp.IsExitStatus(err)
  71. if ok {
  72. return int(status)
  73. }
  74. return 1
  75. }