shell.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. package shell
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "os"
  7. "os/exec"
  8. "path/filepath"
  9. "strings"
  10. "sync"
  11. "syscall"
  12. "time"
  13. "github.com/sst/opencode/internal/config"
  14. "github.com/sst/opencode/internal/status"
  15. )
  16. type PersistentShell struct {
  17. cmd *exec.Cmd
  18. stdin *os.File
  19. isAlive bool
  20. cwd string
  21. mu sync.Mutex
  22. commandQueue chan *commandExecution
  23. }
  24. type commandExecution struct {
  25. command string
  26. timeout time.Duration
  27. resultChan chan commandResult
  28. ctx context.Context
  29. }
  // commandResult is the outcome of one command run by the persistent shell.
  type commandResult struct {
  	stdout, stderr string // captured output of the command
  	exitCode       int    // exit status; 143 when interrupted before one was recorded
  	interrupted    bool   // stopped by timeout or context cancellation
  	err            error  // non-nil when the shell could not run the command
  }
  37. var (
  38. shellInstance *PersistentShell
  39. shellInstanceOnce sync.Once
  40. )
  41. func GetPersistentShell(workingDir string) *PersistentShell {
  42. shellInstanceOnce.Do(func() {
  43. shellInstance = newPersistentShell(workingDir)
  44. })
  45. if shellInstance == nil {
  46. shellInstance = newPersistentShell(workingDir)
  47. } else if !shellInstance.isAlive {
  48. shellInstance = newPersistentShell(shellInstance.cwd)
  49. }
  50. return shellInstance
  51. }
  52. func newPersistentShell(cwd string) *PersistentShell {
  53. cfg := config.Get()
  54. // Use shell from config if specified
  55. shellPath := ""
  56. shellArgs := []string{"-l"}
  57. if cfg != nil && cfg.Shell.Path != "" {
  58. shellPath = cfg.Shell.Path
  59. if len(cfg.Shell.Args) > 0 {
  60. shellArgs = cfg.Shell.Args
  61. }
  62. } else {
  63. // Fall back to environment variable
  64. shellPath = os.Getenv("SHELL")
  65. if shellPath == "" {
  66. // Default to bash if neither config nor environment variable is set
  67. shellPath = "/bin/bash"
  68. }
  69. }
  70. cmd := exec.Command(shellPath, shellArgs...)
  71. cmd.Dir = cwd
  72. stdinPipe, err := cmd.StdinPipe()
  73. if err != nil {
  74. return nil
  75. }
  76. cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
  77. err = cmd.Start()
  78. if err != nil {
  79. return nil
  80. }
  81. shell := &PersistentShell{
  82. cmd: cmd,
  83. stdin: stdinPipe.(*os.File),
  84. isAlive: true,
  85. cwd: cwd,
  86. commandQueue: make(chan *commandExecution, 10),
  87. }
  88. go func() {
  89. defer func() {
  90. if r := recover(); r != nil {
  91. fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
  92. shell.isAlive = false
  93. close(shell.commandQueue)
  94. }
  95. }()
  96. shell.processCommands()
  97. }()
  98. go func() {
  99. err := cmd.Wait()
  100. if err != nil {
  101. status.Error(fmt.Sprintf("Shell process exited with error: %v", err))
  102. }
  103. shell.isAlive = false
  104. close(shell.commandQueue)
  105. }()
  106. return shell
  107. }
  108. func (s *PersistentShell) processCommands() {
  109. for cmd := range s.commandQueue {
  110. result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
  111. cmd.resultChan <- result
  112. }
  113. }
  114. func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
  115. s.mu.Lock()
  116. defer s.mu.Unlock()
  117. if !s.isAlive {
  118. return commandResult{
  119. stderr: "Shell is not alive",
  120. exitCode: 1,
  121. err: errors.New("shell is not alive"),
  122. }
  123. }
  124. tempDir := os.TempDir()
  125. stdoutFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stdout-%d", time.Now().UnixNano()))
  126. stderrFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stderr-%d", time.Now().UnixNano()))
  127. statusFile := filepath.Join(tempDir, fmt.Sprintf("opencode-status-%d", time.Now().UnixNano()))
  128. cwdFile := filepath.Join(tempDir, fmt.Sprintf("opencode-cwd-%d", time.Now().UnixNano()))
  129. defer func() {
  130. os.Remove(stdoutFile)
  131. os.Remove(stderrFile)
  132. os.Remove(statusFile)
  133. os.Remove(cwdFile)
  134. }()
  135. fullCommand := fmt.Sprintf(`
  136. eval %s < /dev/null > %s 2> %s
  137. EXEC_EXIT_CODE=$?
  138. pwd > %s
  139. echo $EXEC_EXIT_CODE > %s
  140. `,
  141. shellQuote(command),
  142. shellQuote(stdoutFile),
  143. shellQuote(stderrFile),
  144. shellQuote(cwdFile),
  145. shellQuote(statusFile),
  146. )
  147. _, err := s.stdin.Write([]byte(fullCommand + "\n"))
  148. if err != nil {
  149. return commandResult{
  150. stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
  151. exitCode: 1,
  152. err: err,
  153. }
  154. }
  155. interrupted := false
  156. startTime := time.Now()
  157. done := make(chan bool)
  158. go func() {
  159. for {
  160. select {
  161. case <-ctx.Done():
  162. s.killChildren()
  163. interrupted = true
  164. done <- true
  165. return
  166. case <-time.After(10 * time.Millisecond):
  167. if fileExists(statusFile) && fileSize(statusFile) > 0 {
  168. done <- true
  169. return
  170. }
  171. if timeout > 0 {
  172. elapsed := time.Since(startTime)
  173. if elapsed > timeout {
  174. s.killChildren()
  175. interrupted = true
  176. done <- true
  177. return
  178. }
  179. }
  180. }
  181. }
  182. }()
  183. <-done
  184. stdout := readFileOrEmpty(stdoutFile)
  185. stderr := readFileOrEmpty(stderrFile)
  186. exitCodeStr := readFileOrEmpty(statusFile)
  187. newCwd := readFileOrEmpty(cwdFile)
  188. exitCode := 0
  189. if exitCodeStr != "" {
  190. fmt.Sscanf(exitCodeStr, "%d", &exitCode)
  191. } else if interrupted {
  192. exitCode = 143
  193. stderr += "\nCommand execution timed out or was interrupted"
  194. }
  195. if newCwd != "" {
  196. s.cwd = strings.TrimSpace(newCwd)
  197. }
  198. return commandResult{
  199. stdout: stdout,
  200. stderr: stderr,
  201. exitCode: exitCode,
  202. interrupted: interrupted,
  203. }
  204. }
  205. func (s *PersistentShell) killChildren() {
  206. if s.cmd == nil || s.cmd.Process == nil {
  207. return
  208. }
  209. pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
  210. output, err := pgrepCmd.Output()
  211. if err != nil {
  212. return
  213. }
  214. for pidStr := range strings.SplitSeq(string(output), "\n") {
  215. if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
  216. var pid int
  217. fmt.Sscanf(pidStr, "%d", &pid)
  218. if pid > 0 {
  219. proc, err := os.FindProcess(pid)
  220. if err == nil {
  221. proc.Signal(syscall.SIGTERM)
  222. }
  223. }
  224. }
  225. }
  226. }
  227. func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
  228. if !s.isAlive {
  229. return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
  230. }
  231. timeout := time.Duration(timeoutMs) * time.Millisecond
  232. resultChan := make(chan commandResult)
  233. s.commandQueue <- &commandExecution{
  234. command: command,
  235. timeout: timeout,
  236. resultChan: resultChan,
  237. ctx: ctx,
  238. }
  239. result := <-resultChan
  240. return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
  241. }
  242. func (s *PersistentShell) Close() {
  243. s.mu.Lock()
  244. defer s.mu.Unlock()
  245. if !s.isAlive {
  246. return
  247. }
  248. s.stdin.Write([]byte("exit\n"))
  249. s.cmd.Process.Kill()
  250. s.isAlive = false
  251. }
  // shellQuote wraps s in single quotes for a POSIX shell. Inside single
  // quotes nothing is special except the quote character itself, so each
  // embedded ' is written as '\'' (close quote, escaped quote, reopen quote).
  func shellQuote(s string) string {
  	const escapedQuote = `'\''`
  	return "'" + strings.Join(strings.Split(s, "'"), escapedQuote) + "'"
  }
  // readFileOrEmpty returns the whole contents of the file at path as a
  // string. Any read error, including a missing file, yields "".
  func readFileOrEmpty(path string) string {
  	if data, err := os.ReadFile(path); err == nil {
  		return string(data)
  	}
  	return ""
  }
  // fileExists reports whether os.Stat succeeds for path. Any error, not only
  // "does not exist" (for example a permission error), yields false.
  func fileExists(path string) bool {
  	if _, err := os.Stat(path); err != nil {
  		return false
  	}
  	return true
  }
  // fileSize returns the size in bytes reported by os.Stat for path, or 0
  // when the stat fails (for example because the file does not exist yet).
  func fileSize(path string) int64 {
  	if info, err := os.Stat(path); err == nil {
  		return info.Size()
  	}
  	return 0
  }