shell.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. package shell
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "os/exec"
  8. "path/filepath"
  9. "strings"
  10. "sync"
  11. "syscall"
  12. "time"
  13. )
  // PersistentShell is a long-lived login shell process to which commands are
// written one at a time over its stdin. Exec queues requests on
// commandQueue and a single goroutine (processCommands) executes them.
type PersistentShell struct {
	// cmd is the running shell process.
	cmd *exec.Cmd
	// stdin is the write end of the shell's standard input; command scripts
	// are written here.
	stdin *os.File
	// isAlive is cleared when the shell process exits, when Close is called,
	// or when the command processor panics.
	// NOTE(review): it is read and written from several goroutines without
	// holding mu, which is a data race under -race.
	isAlive bool
	// cwd starts as the directory the shell was launched in and is updated
	// to the shell's working directory after each command that reports one.
	cwd string
	// mu is held for the whole of each command execution and by Close, so
	// writes to stdin never interleave.
	mu sync.Mutex
	// commandQueue carries requests from Exec to processCommands; it is
	// closed when the shell dies.
	commandQueue chan *commandExecution
}
  // commandExecution is one request queued by Exec for processCommands.
type commandExecution struct {
	// command is the shell text to run.
	command string
	// timeout bounds the run; execCommand applies it only when it is > 0.
	timeout time.Duration
	// resultChan receives the single result for this request.
	resultChan chan commandResult
	// ctx is the caller's context; cancelling it interrupts the command.
	// It travels inside the request because the request crosses a channel.
	ctx context.Context
}
  // commandResult is the outcome of one execCommand call.
type commandResult struct {
	// stdout and stderr are the command's captured output.
	stdout string
	stderr string
	// exitCode is the status the shell recorded; 143 when the command was
	// interrupted before reporting one, 1 for shell-level failures.
	exitCode int
	// interrupted reports that ctx was cancelled or the timeout elapsed.
	interrupted bool
	// err is set when the command could not be run (shell not alive, or the
	// script could not be written to the shell).
	err error
}
  35. var (
  36. shellInstance *PersistentShell
  37. shellInstanceOnce sync.Once
  38. )
  39. func GetPersistentShell(workingDir string) *PersistentShell {
  40. shellInstanceOnce.Do(func() {
  41. shellInstance = newPersistentShell(workingDir)
  42. })
  43. if !shellInstance.isAlive {
  44. shellInstance = newPersistentShell(shellInstance.cwd)
  45. }
  46. return shellInstance
  47. }
  48. func newPersistentShell(cwd string) *PersistentShell {
  49. shellPath := os.Getenv("SHELL")
  50. if shellPath == "" {
  51. shellPath = "/bin/bash"
  52. }
  53. cmd := exec.Command(shellPath, "-l")
  54. cmd.Dir = cwd
  55. stdinPipe, err := cmd.StdinPipe()
  56. if err != nil {
  57. return nil
  58. }
  59. cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
  60. err = cmd.Start()
  61. if err != nil {
  62. return nil
  63. }
  64. shell := &PersistentShell{
  65. cmd: cmd,
  66. stdin: stdinPipe.(*os.File),
  67. isAlive: true,
  68. cwd: cwd,
  69. commandQueue: make(chan *commandExecution, 10),
  70. }
  71. go func() {
  72. defer func() {
  73. if r := recover(); r != nil {
  74. fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
  75. shell.isAlive = false
  76. close(shell.commandQueue)
  77. }
  78. }()
  79. shell.processCommands()
  80. }()
  81. go func() {
  82. err := cmd.Wait()
  83. if err != nil {
  84. // Log the error if needed
  85. }
  86. shell.isAlive = false
  87. close(shell.commandQueue)
  88. }()
  89. return shell
  90. }
  91. func (s *PersistentShell) processCommands() {
  92. for cmd := range s.commandQueue {
  93. result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
  94. cmd.resultChan <- result
  95. }
  96. }
  97. func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
  98. s.mu.Lock()
  99. defer s.mu.Unlock()
  100. if !s.isAlive {
  101. return commandResult{
  102. stderr: "Shell is not alive",
  103. exitCode: 1,
  104. err: errors.New("shell is not alive"),
  105. }
  106. }
  107. tempDir := os.TempDir()
  108. stdoutFile := filepath.Join(tempDir, fmt.Sprintf("termai-stdout-%d", time.Now().UnixNano()))
  109. stderrFile := filepath.Join(tempDir, fmt.Sprintf("termai-stderr-%d", time.Now().UnixNano()))
  110. statusFile := filepath.Join(tempDir, fmt.Sprintf("termai-status-%d", time.Now().UnixNano()))
  111. cwdFile := filepath.Join(tempDir, fmt.Sprintf("termai-cwd-%d", time.Now().UnixNano()))
  112. defer func() {
  113. os.Remove(stdoutFile)
  114. os.Remove(stderrFile)
  115. os.Remove(statusFile)
  116. os.Remove(cwdFile)
  117. }()
  118. fullCommand := fmt.Sprintf(`
  119. eval %s < /dev/null > %s 2> %s
  120. EXEC_EXIT_CODE=$?
  121. pwd > %s
  122. echo $EXEC_EXIT_CODE > %s
  123. `,
  124. shellQuote(command),
  125. shellQuote(stdoutFile),
  126. shellQuote(stderrFile),
  127. shellQuote(cwdFile),
  128. shellQuote(statusFile),
  129. )
  130. _, err := s.stdin.Write([]byte(fullCommand + "\n"))
  131. if err != nil {
  132. return commandResult{
  133. stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
  134. exitCode: 1,
  135. err: err,
  136. }
  137. }
  138. interrupted := false
  139. startTime := time.Now()
  140. done := make(chan bool)
  141. go func() {
  142. for {
  143. select {
  144. case <-ctx.Done():
  145. s.killChildren()
  146. interrupted = true
  147. done <- true
  148. return
  149. case <-time.After(10 * time.Millisecond):
  150. if fileExists(statusFile) && fileSize(statusFile) > 0 {
  151. done <- true
  152. return
  153. }
  154. if timeout > 0 {
  155. elapsed := time.Since(startTime)
  156. if elapsed > timeout {
  157. s.killChildren()
  158. interrupted = true
  159. done <- true
  160. return
  161. }
  162. }
  163. }
  164. }
  165. }()
  166. <-done
  167. stdout := readFileOrEmpty(stdoutFile)
  168. stderr := readFileOrEmpty(stderrFile)
  169. exitCodeStr := readFileOrEmpty(statusFile)
  170. newCwd := readFileOrEmpty(cwdFile)
  171. exitCode := 0
  172. if exitCodeStr != "" {
  173. fmt.Sscanf(exitCodeStr, "%d", &exitCode)
  174. } else if interrupted {
  175. exitCode = 143
  176. stderr += "\nCommand execution timed out or was interrupted"
  177. }
  178. if newCwd != "" {
  179. s.cwd = strings.TrimSpace(newCwd)
  180. }
  181. return commandResult{
  182. stdout: stdout,
  183. stderr: stderr,
  184. exitCode: exitCode,
  185. interrupted: interrupted,
  186. }
  187. }
  188. func (s *PersistentShell) killChildren() {
  189. if s.cmd == nil || s.cmd.Process == nil {
  190. return
  191. }
  192. pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
  193. output, err := pgrepCmd.Output()
  194. if err != nil {
  195. return
  196. }
  197. for pidStr := range strings.SplitSeq(string(output), "\n") {
  198. if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
  199. var pid int
  200. fmt.Sscanf(pidStr, "%d", &pid)
  201. if pid > 0 {
  202. proc, err := os.FindProcess(pid)
  203. if err == nil {
  204. proc.Signal(syscall.SIGTERM)
  205. }
  206. }
  207. }
  208. }
  209. }
  210. func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
  211. if !s.isAlive {
  212. return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
  213. }
  214. timeout := time.Duration(timeoutMs) * time.Millisecond
  215. resultChan := make(chan commandResult)
  216. s.commandQueue <- &commandExecution{
  217. command: command,
  218. timeout: timeout,
  219. resultChan: resultChan,
  220. ctx: ctx,
  221. }
  222. result := <-resultChan
  223. return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
  224. }
  225. func (s *PersistentShell) Close() {
  226. s.mu.Lock()
  227. defer s.mu.Unlock()
  228. if !s.isAlive {
  229. return
  230. }
  231. s.stdin.Write([]byte("exit\n"))
  232. s.cmd.Process.Kill()
  233. s.isAlive = false
  234. }
  // shellQuote wraps s in single quotes for a POSIX shell. Each embedded single
// quote becomes '\'' (close the quote, emit an escaped quote, reopen), so
// the result is always one literal word.
func shellQuote(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('\'')
	for i := 0; i < len(s); i++ {
		if s[i] == '\'' {
			b.WriteString(`'\''`)
			continue
		}
		b.WriteByte(s[i])
	}
	b.WriteByte('\'')
	return b.String()
}
  // readFileOrEmpty returns the whole contents of the file at path, or "" if
// os.ReadFile fails for any reason.
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
  // fileExists reports whether os.Stat succeeds on path. Any Stat error, not
// only "does not exist", yields false.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
  // fileSize returns the size in bytes that os.Stat reports for path, or 0 if
// Stat fails.
func fileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}