| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- package shell
- import (
- "context"
- "errors"
- "fmt"
- "os"
- "os/exec"
- "path/filepath"
- "strings"
- "sync"
- "syscall"
- "time"
- )
type (
	// PersistentShell is a handle to one long-lived login shell process.
	// Commands are written to the shell's stdin and executed one at a time.
	PersistentShell struct {
		// cmd is the running shell process.
		cmd *exec.Cmd

		// stdin is the pipe commands are written to.
		stdin *os.File

		// isAlive turns false once the shell has exited or been closed.
		isAlive bool

		// cwd is the last working directory reported by the shell.
		cwd string

		// mu serializes command execution and Close.
		mu sync.Mutex

		// commandQueue feeds pending commands to the processing goroutine.
		commandQueue chan *commandExecution
	}

	// commandExecution is a single queued request to run a command; the
	// outcome is delivered on resultChan.
	commandExecution struct {
		command    string
		timeout    time.Duration
		resultChan chan commandResult
		ctx        context.Context
	}

	// commandResult is the outcome of running one command. interrupted is set
	// when the command was cut short by cancellation or timeout; err reports
	// failures of the shell itself.
	commandResult struct {
		stdout      string
		stderr      string
		exitCode    int
		interrupted bool
		err         error
	}
)
- var (
- shellInstance *PersistentShell
- shellInstanceOnce sync.Once
- )
- func GetPersistentShell(workingDir string) *PersistentShell {
- shellInstanceOnce.Do(func() {
- shellInstance = newPersistentShell(workingDir)
- })
- if !shellInstance.isAlive {
- shellInstance = newPersistentShell(shellInstance.cwd)
- }
- return shellInstance
- }
- func newPersistentShell(cwd string) *PersistentShell {
- shellPath := os.Getenv("SHELL")
- if shellPath == "" {
- shellPath = "/bin/bash"
- }
- cmd := exec.Command(shellPath, "-l")
- cmd.Dir = cwd
- stdinPipe, err := cmd.StdinPipe()
- if err != nil {
- return nil
- }
- cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
- err = cmd.Start()
- if err != nil {
- return nil
- }
- shell := &PersistentShell{
- cmd: cmd,
- stdin: stdinPipe.(*os.File),
- isAlive: true,
- cwd: cwd,
- commandQueue: make(chan *commandExecution, 10),
- }
- go func() {
- defer func() {
- if r := recover(); r != nil {
- fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
- shell.isAlive = false
- close(shell.commandQueue)
- }
- }()
- shell.processCommands()
- }()
- go func() {
- err := cmd.Wait()
- if err != nil {
- // Log the error if needed
- }
- shell.isAlive = false
- close(shell.commandQueue)
- }()
- return shell
- }
- func (s *PersistentShell) processCommands() {
- for cmd := range s.commandQueue {
- result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
- cmd.resultChan <- result
- }
- }
- func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
- s.mu.Lock()
- defer s.mu.Unlock()
- if !s.isAlive {
- return commandResult{
- stderr: "Shell is not alive",
- exitCode: 1,
- err: errors.New("shell is not alive"),
- }
- }
- tempDir := os.TempDir()
- stdoutFile := filepath.Join(tempDir, fmt.Sprintf("termai-stdout-%d", time.Now().UnixNano()))
- stderrFile := filepath.Join(tempDir, fmt.Sprintf("termai-stderr-%d", time.Now().UnixNano()))
- statusFile := filepath.Join(tempDir, fmt.Sprintf("termai-status-%d", time.Now().UnixNano()))
- cwdFile := filepath.Join(tempDir, fmt.Sprintf("termai-cwd-%d", time.Now().UnixNano()))
- defer func() {
- os.Remove(stdoutFile)
- os.Remove(stderrFile)
- os.Remove(statusFile)
- os.Remove(cwdFile)
- }()
- fullCommand := fmt.Sprintf(`
- eval %s < /dev/null > %s 2> %s
- EXEC_EXIT_CODE=$?
- pwd > %s
- echo $EXEC_EXIT_CODE > %s
- `,
- shellQuote(command),
- shellQuote(stdoutFile),
- shellQuote(stderrFile),
- shellQuote(cwdFile),
- shellQuote(statusFile),
- )
- _, err := s.stdin.Write([]byte(fullCommand + "\n"))
- if err != nil {
- return commandResult{
- stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
- exitCode: 1,
- err: err,
- }
- }
- interrupted := false
- startTime := time.Now()
- done := make(chan bool)
- go func() {
- for {
- select {
- case <-ctx.Done():
- s.killChildren()
- interrupted = true
- done <- true
- return
- case <-time.After(10 * time.Millisecond):
- if fileExists(statusFile) && fileSize(statusFile) > 0 {
- done <- true
- return
- }
- if timeout > 0 {
- elapsed := time.Since(startTime)
- if elapsed > timeout {
- s.killChildren()
- interrupted = true
- done <- true
- return
- }
- }
- }
- }
- }()
- <-done
- stdout := readFileOrEmpty(stdoutFile)
- stderr := readFileOrEmpty(stderrFile)
- exitCodeStr := readFileOrEmpty(statusFile)
- newCwd := readFileOrEmpty(cwdFile)
- exitCode := 0
- if exitCodeStr != "" {
- fmt.Sscanf(exitCodeStr, "%d", &exitCode)
- } else if interrupted {
- exitCode = 143
- stderr += "\nCommand execution timed out or was interrupted"
- }
- if newCwd != "" {
- s.cwd = strings.TrimSpace(newCwd)
- }
- return commandResult{
- stdout: stdout,
- stderr: stderr,
- exitCode: exitCode,
- interrupted: interrupted,
- }
- }
- func (s *PersistentShell) killChildren() {
- if s.cmd == nil || s.cmd.Process == nil {
- return
- }
- pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
- output, err := pgrepCmd.Output()
- if err != nil {
- return
- }
- for pidStr := range strings.SplitSeq(string(output), "\n") {
- if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
- var pid int
- fmt.Sscanf(pidStr, "%d", &pid)
- if pid > 0 {
- proc, err := os.FindProcess(pid)
- if err == nil {
- proc.Signal(syscall.SIGTERM)
- }
- }
- }
- }
- }
- func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
- if !s.isAlive {
- return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
- }
- timeout := time.Duration(timeoutMs) * time.Millisecond
- resultChan := make(chan commandResult)
- s.commandQueue <- &commandExecution{
- command: command,
- timeout: timeout,
- resultChan: resultChan,
- ctx: ctx,
- }
- result := <-resultChan
- return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
- }
- func (s *PersistentShell) Close() {
- s.mu.Lock()
- defer s.mu.Unlock()
- if !s.isAlive {
- return
- }
- s.stdin.Write([]byte("exit\n"))
- s.cmd.Process.Kill()
- s.isAlive = false
- }
// shellQuote wraps s in single quotes for safe use as one shell word. A single
// quote cannot appear inside a single-quoted string, so each one in s is
// emitted as '\'' (close the quote, an escaped quote, reopen the quote).
func shellQuote(s string) string {
	const escapedQuote = `'\''`
	return "'" + strings.Join(strings.Split(s, "'"), escapedQuote) + "'"
}
// readFileOrEmpty returns the contents of the file at path as a string, or ""
// if the file cannot be read for any reason.
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
// fileExists reports whether os.Stat succeeds for path. Any Stat error, not
// only "does not exist", yields false.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
// fileSize returns the size in bytes reported by os.Stat for path, or 0 if
// Stat fails (for example because the file does not exist).
func fileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}
|