shell.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. package shell
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "os/exec"
  8. "path/filepath"
  9. "strings"
  10. "sync"
  11. "syscall"
  12. "time"
  13. )
  14. type PersistentShell struct {
  15. cmd *exec.Cmd
  16. stdin *os.File
  17. isAlive bool
  18. cwd string
  19. mu sync.Mutex
  20. commandQueue chan *commandExecution
  21. }
  22. type commandExecution struct {
  23. command string
  24. timeout time.Duration
  25. resultChan chan commandResult
  26. ctx context.Context
  27. }
  28. type commandResult struct {
  29. stdout string
  30. stderr string
  31. exitCode int
  32. interrupted bool
  33. err error
  34. }
  35. var (
  36. shellInstance *PersistentShell
  37. shellInstanceOnce sync.Once
  38. )
  39. func GetPersistentShell(workingDir string) *PersistentShell {
  40. shellInstanceOnce.Do(func() {
  41. shellInstance = newPersistentShell(workingDir)
  42. })
  43. if shellInstance == nil {
  44. shellInstance = newPersistentShell(workingDir)
  45. } else if !shellInstance.isAlive {
  46. shellInstance = newPersistentShell(shellInstance.cwd)
  47. }
  48. return shellInstance
  49. }
  50. func newPersistentShell(cwd string) *PersistentShell {
  51. shellPath := os.Getenv("SHELL")
  52. if shellPath == "" {
  53. shellPath = "/bin/bash"
  54. }
  55. cmd := exec.Command(shellPath, "-l")
  56. cmd.Dir = cwd
  57. stdinPipe, err := cmd.StdinPipe()
  58. if err != nil {
  59. return nil
  60. }
  61. cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
  62. err = cmd.Start()
  63. if err != nil {
  64. return nil
  65. }
  66. shell := &PersistentShell{
  67. cmd: cmd,
  68. stdin: stdinPipe.(*os.File),
  69. isAlive: true,
  70. cwd: cwd,
  71. commandQueue: make(chan *commandExecution, 10),
  72. }
  73. go func() {
  74. defer func() {
  75. if r := recover(); r != nil {
  76. fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
  77. shell.isAlive = false
  78. close(shell.commandQueue)
  79. }
  80. }()
  81. shell.processCommands()
  82. }()
  83. go func() {
  84. err := cmd.Wait()
  85. if err != nil {
  86. // Log the error if needed
  87. }
  88. shell.isAlive = false
  89. close(shell.commandQueue)
  90. }()
  91. return shell
  92. }
  93. func (s *PersistentShell) processCommands() {
  94. for cmd := range s.commandQueue {
  95. result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
  96. cmd.resultChan <- result
  97. }
  98. }
  99. func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
  100. s.mu.Lock()
  101. defer s.mu.Unlock()
  102. if !s.isAlive {
  103. return commandResult{
  104. stderr: "Shell is not alive",
  105. exitCode: 1,
  106. err: errors.New("shell is not alive"),
  107. }
  108. }
  109. tempDir := os.TempDir()
  110. stdoutFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stdout-%d", time.Now().UnixNano()))
  111. stderrFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stderr-%d", time.Now().UnixNano()))
  112. statusFile := filepath.Join(tempDir, fmt.Sprintf("opencode-status-%d", time.Now().UnixNano()))
  113. cwdFile := filepath.Join(tempDir, fmt.Sprintf("opencode-cwd-%d", time.Now().UnixNano()))
  114. defer func() {
  115. os.Remove(stdoutFile)
  116. os.Remove(stderrFile)
  117. os.Remove(statusFile)
  118. os.Remove(cwdFile)
  119. }()
  120. fullCommand := fmt.Sprintf(`
  121. eval %s < /dev/null > %s 2> %s
  122. EXEC_EXIT_CODE=$?
  123. pwd > %s
  124. echo $EXEC_EXIT_CODE > %s
  125. `,
  126. shellQuote(command),
  127. shellQuote(stdoutFile),
  128. shellQuote(stderrFile),
  129. shellQuote(cwdFile),
  130. shellQuote(statusFile),
  131. )
  132. _, err := s.stdin.Write([]byte(fullCommand + "\n"))
  133. if err != nil {
  134. return commandResult{
  135. stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
  136. exitCode: 1,
  137. err: err,
  138. }
  139. }
  140. interrupted := false
  141. startTime := time.Now()
  142. done := make(chan bool)
  143. go func() {
  144. for {
  145. select {
  146. case <-ctx.Done():
  147. s.killChildren()
  148. interrupted = true
  149. done <- true
  150. return
  151. case <-time.After(10 * time.Millisecond):
  152. if fileExists(statusFile) && fileSize(statusFile) > 0 {
  153. done <- true
  154. return
  155. }
  156. if timeout > 0 {
  157. elapsed := time.Since(startTime)
  158. if elapsed > timeout {
  159. s.killChildren()
  160. interrupted = true
  161. done <- true
  162. return
  163. }
  164. }
  165. }
  166. }
  167. }()
  168. <-done
  169. stdout := readFileOrEmpty(stdoutFile)
  170. stderr := readFileOrEmpty(stderrFile)
  171. exitCodeStr := readFileOrEmpty(statusFile)
  172. newCwd := readFileOrEmpty(cwdFile)
  173. exitCode := 0
  174. if exitCodeStr != "" {
  175. fmt.Sscanf(exitCodeStr, "%d", &exitCode)
  176. } else if interrupted {
  177. exitCode = 143
  178. stderr += "\nCommand execution timed out or was interrupted"
  179. }
  180. if newCwd != "" {
  181. s.cwd = strings.TrimSpace(newCwd)
  182. }
  183. return commandResult{
  184. stdout: stdout,
  185. stderr: stderr,
  186. exitCode: exitCode,
  187. interrupted: interrupted,
  188. }
  189. }
  190. func (s *PersistentShell) killChildren() {
  191. if s.cmd == nil || s.cmd.Process == nil {
  192. return
  193. }
  194. pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
  195. output, err := pgrepCmd.Output()
  196. if err != nil {
  197. return
  198. }
  199. for pidStr := range strings.SplitSeq(string(output), "\n") {
  200. if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
  201. var pid int
  202. fmt.Sscanf(pidStr, "%d", &pid)
  203. if pid > 0 {
  204. proc, err := os.FindProcess(pid)
  205. if err == nil {
  206. proc.Signal(syscall.SIGTERM)
  207. }
  208. }
  209. }
  210. }
  211. }
  212. func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
  213. if !s.isAlive {
  214. return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
  215. }
  216. timeout := time.Duration(timeoutMs) * time.Millisecond
  217. resultChan := make(chan commandResult)
  218. s.commandQueue <- &commandExecution{
  219. command: command,
  220. timeout: timeout,
  221. resultChan: resultChan,
  222. ctx: ctx,
  223. }
  224. result := <-resultChan
  225. return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
  226. }
  227. func (s *PersistentShell) Close() {
  228. s.mu.Lock()
  229. defer s.mu.Unlock()
  230. if !s.isAlive {
  231. return
  232. }
  233. s.stdin.Write([]byte("exit\n"))
  234. s.cmd.Process.Kill()
  235. s.isAlive = false
  236. }
  237. func shellQuote(s string) string {
  238. return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
  239. }
  240. func readFileOrEmpty(path string) string {
  241. content, err := os.ReadFile(path)
  242. if err != nil {
  243. return ""
  244. }
  245. return string(content)
  246. }
  247. func fileExists(path string) bool {
  248. _, err := os.Stat(path)
  249. return err == nil
  250. }
  251. func fileSize(path string) int64 {
  252. info, err := os.Stat(path)
  253. if err != nil {
  254. return 0
  255. }
  256. return info.Size()
  257. }