shell.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. package shell
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "os/exec"
  8. "path/filepath"
  9. "strings"
  10. "sync"
  11. "syscall"
  12. "time"
  13. )
// PersistentShell wraps one long-lived shell process. Commands are eval'd
// inside that same process one after another, so shell state such as the
// working directory and shell variables carries over between commands.
type PersistentShell struct {
	cmd   *exec.Cmd // the running shell process
	stdin *os.File  // write end of the shell's stdin; command scripts are written here

	// isAlive is true from creation until the shell process exits or Close
	// is called.
	// NOTE(review): it is written by the cmd.Wait goroutine and read by Exec
	// and GetPersistentShell without holding mu, which is a data race under
	// the Go memory model; consider an atomic.Bool.
	isAlive bool

	cwd          string                 // start directory, then the shell's pwd after each command
	mu           sync.Mutex             // held by execCommand and Close for their full duration
	commandQueue chan *commandExecution // pending requests for processCommands; closed when the shell exits
}
// commandExecution is one request queued by Exec on commandQueue and run by
// processCommands. It carries the caller's context across the queue so the
// worker can honour cancellation.
type commandExecution struct {
	command    string             // command text, handed to the shell's eval
	timeout    time.Duration      // run-time limit; values <= 0 mean no limit
	resultChan chan commandResult // receives exactly one result for this request
	ctx        context.Context    // caller's context; cancellation interrupts the command
}
  28. type commandResult struct {
  29. stdout string
  30. stderr string
  31. exitCode int
  32. interrupted bool
  33. err error
  34. }
  35. var (
  36. shellInstance *PersistentShell
  37. shellInstanceOnce sync.Once
  38. )
  39. func GetPersistentShell(workingDir string) *PersistentShell {
  40. shellInstanceOnce.Do(func() {
  41. shellInstance = newPersistentShell(workingDir)
  42. })
  43. if !shellInstance.isAlive {
  44. shellInstance = newPersistentShell(shellInstance.cwd)
  45. }
  46. return shellInstance
  47. }
  48. func newPersistentShell(cwd string) *PersistentShell {
  49. shellPath := os.Getenv("SHELL")
  50. if shellPath == "" {
  51. shellPath = "/bin/bash"
  52. }
  53. cmd := exec.Command(shellPath, "-l")
  54. cmd.Dir = cwd
  55. stdinPipe, err := cmd.StdinPipe()
  56. if err != nil {
  57. return nil
  58. }
  59. cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
  60. err = cmd.Start()
  61. if err != nil {
  62. return nil
  63. }
  64. shell := &PersistentShell{
  65. cmd: cmd,
  66. stdin: stdinPipe.(*os.File),
  67. isAlive: true,
  68. cwd: cwd,
  69. commandQueue: make(chan *commandExecution, 10),
  70. }
  71. go shell.processCommands()
  72. go func() {
  73. err := cmd.Wait()
  74. if err != nil {
  75. }
  76. shell.isAlive = false
  77. close(shell.commandQueue)
  78. }()
  79. return shell
  80. }
  81. func (s *PersistentShell) processCommands() {
  82. for cmd := range s.commandQueue {
  83. result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
  84. cmd.resultChan <- result
  85. }
  86. }
  87. func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
  88. s.mu.Lock()
  89. defer s.mu.Unlock()
  90. if !s.isAlive {
  91. return commandResult{
  92. stderr: "Shell is not alive",
  93. exitCode: 1,
  94. err: errors.New("shell is not alive"),
  95. }
  96. }
  97. tempDir := os.TempDir()
  98. stdoutFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-stdout-%d", time.Now().UnixNano()))
  99. stderrFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-stderr-%d", time.Now().UnixNano()))
  100. statusFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-status-%d", time.Now().UnixNano()))
  101. cwdFile := filepath.Join(tempDir, fmt.Sprintf("orbitowl-cwd-%d", time.Now().UnixNano()))
  102. defer func() {
  103. os.Remove(stdoutFile)
  104. os.Remove(stderrFile)
  105. os.Remove(statusFile)
  106. os.Remove(cwdFile)
  107. }()
  108. fullCommand := fmt.Sprintf(`
  109. eval %s < /dev/null > %s 2> %s
  110. EXEC_EXIT_CODE=$?
  111. pwd > %s
  112. echo $EXEC_EXIT_CODE > %s
  113. `,
  114. shellQuote(command),
  115. shellQuote(stdoutFile),
  116. shellQuote(stderrFile),
  117. shellQuote(cwdFile),
  118. shellQuote(statusFile),
  119. )
  120. _, err := s.stdin.Write([]byte(fullCommand + "\n"))
  121. if err != nil {
  122. return commandResult{
  123. stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
  124. exitCode: 1,
  125. err: err,
  126. }
  127. }
  128. interrupted := false
  129. startTime := time.Now()
  130. done := make(chan bool)
  131. go func() {
  132. for {
  133. select {
  134. case <-ctx.Done():
  135. s.killChildren()
  136. interrupted = true
  137. done <- true
  138. return
  139. case <-time.After(10 * time.Millisecond):
  140. if fileExists(statusFile) && fileSize(statusFile) > 0 {
  141. done <- true
  142. return
  143. }
  144. if timeout > 0 {
  145. elapsed := time.Since(startTime)
  146. if elapsed > timeout {
  147. s.killChildren()
  148. interrupted = true
  149. done <- true
  150. return
  151. }
  152. }
  153. }
  154. }
  155. }()
  156. <-done
  157. stdout := readFileOrEmpty(stdoutFile)
  158. stderr := readFileOrEmpty(stderrFile)
  159. exitCodeStr := readFileOrEmpty(statusFile)
  160. newCwd := readFileOrEmpty(cwdFile)
  161. exitCode := 0
  162. if exitCodeStr != "" {
  163. fmt.Sscanf(exitCodeStr, "%d", &exitCode)
  164. } else if interrupted {
  165. exitCode = 143
  166. stderr += "\nCommand execution timed out or was interrupted"
  167. }
  168. if newCwd != "" {
  169. s.cwd = strings.TrimSpace(newCwd)
  170. }
  171. return commandResult{
  172. stdout: stdout,
  173. stderr: stderr,
  174. exitCode: exitCode,
  175. interrupted: interrupted,
  176. }
  177. }
  178. func (s *PersistentShell) killChildren() {
  179. if s.cmd == nil || s.cmd.Process == nil {
  180. return
  181. }
  182. pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
  183. output, err := pgrepCmd.Output()
  184. if err != nil {
  185. return
  186. }
  187. for _, pidStr := range strings.Split(string(output), "\n") {
  188. if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
  189. var pid int
  190. fmt.Sscanf(pidStr, "%d", &pid)
  191. if pid > 0 {
  192. proc, err := os.FindProcess(pid)
  193. if err == nil {
  194. proc.Signal(syscall.SIGTERM)
  195. }
  196. }
  197. }
  198. }
  199. }
  200. func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
  201. if !s.isAlive {
  202. return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
  203. }
  204. timeout := time.Duration(timeoutMs) * time.Millisecond
  205. resultChan := make(chan commandResult)
  206. s.commandQueue <- &commandExecution{
  207. command: command,
  208. timeout: timeout,
  209. resultChan: resultChan,
  210. ctx: ctx,
  211. }
  212. result := <-resultChan
  213. return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
  214. }
  215. func (s *PersistentShell) Close() {
  216. s.mu.Lock()
  217. defer s.mu.Unlock()
  218. if !s.isAlive {
  219. return
  220. }
  221. s.stdin.Write([]byte("exit\n"))
  222. s.cmd.Process.Kill()
  223. s.isAlive = false
  224. }
  225. func shellQuote(s string) string {
  226. return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
  227. }
  228. func readFileOrEmpty(path string) string {
  229. content, err := os.ReadFile(path)
  230. if err != nil {
  231. return ""
  232. }
  233. return string(content)
  234. }
  235. func fileExists(path string) bool {
  236. _, err := os.Stat(path)
  237. return err == nil
  238. }
  239. func fileSize(path string) int64 {
  240. info, err := os.Stat(path)
  241. if err != nil {
  242. return 0
  243. }
  244. return info.Size()
  245. }