runner.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. package hooks
  import (
  	"bytes"
  	"context"
  	"errors"
  	"log/slog"
  	"os/exec"
  	"path/filepath"
  	"strings"
  	"sync"
  	"time"

  	"github.com/charmbracelet/crush/internal/config"
  )
  12. // Runner executes hook commands and aggregates their results.
  13. type Runner struct {
  14. hooks []config.HookConfig
  15. cwd string
  16. projectDir string
  17. }
  18. // NewRunner creates a Runner from the given hook configs.
  19. func NewRunner(hooks []config.HookConfig, cwd, projectDir string) *Runner {
  20. return &Runner{
  21. hooks: hooks,
  22. cwd: cwd,
  23. projectDir: projectDir,
  24. }
  25. }
  26. // Run executes all matching hooks for the given event and tool, returning
  27. // an aggregated result.
  28. func (r *Runner) Run(ctx context.Context, eventName, sessionID, toolName, toolInputJSON string) (AggregateResult, error) {
  29. matching := r.matchingHooks(toolName)
  30. if len(matching) == 0 {
  31. return AggregateResult{Decision: DecisionNone}, nil
  32. }
  33. // Deduplicate by command string.
  34. seen := make(map[string]bool, len(matching))
  35. var deduped []config.HookConfig
  36. for _, h := range matching {
  37. if seen[h.Command] {
  38. continue
  39. }
  40. seen[h.Command] = true
  41. deduped = append(deduped, h)
  42. }
  43. envVars := BuildEnv(eventName, toolName, sessionID, r.cwd, r.projectDir, toolInputJSON)
  44. payload := BuildPayload(eventName, sessionID, r.cwd, toolName, toolInputJSON)
  45. results := make([]HookResult, len(deduped))
  46. var wg sync.WaitGroup
  47. wg.Add(len(deduped))
  48. for i, h := range deduped {
  49. go func(idx int, hook config.HookConfig) {
  50. defer wg.Done()
  51. results[idx] = r.runOne(ctx, hook, envVars, payload)
  52. }(i, h)
  53. }
  54. wg.Wait()
  55. agg := aggregate(results)
  56. agg.Hooks = make([]HookInfo, len(deduped))
  57. for i, h := range deduped {
  58. agg.Hooks[i] = HookInfo{
  59. Name: filepath.Base(h.Command),
  60. Matcher: h.Matcher,
  61. Decision: results[i].Decision.String(),
  62. Reason: results[i].Reason,
  63. InputRewrite: results[i].UpdatedInput != "",
  64. }
  65. }
  66. slog.Info("Hook completed",
  67. "event", eventName,
  68. "tool", toolName,
  69. "hooks", len(deduped),
  70. "decision", agg.Decision.String(),
  71. )
  72. return agg, nil
  73. }
  74. // matchingHooks returns hooks whose matcher matches the tool name (or has
  75. // no matcher, which matches everything).
  76. func (r *Runner) matchingHooks(toolName string) []config.HookConfig {
  77. var matched []config.HookConfig
  78. for _, h := range r.hooks {
  79. re := h.MatcherRegex()
  80. if re == nil || re.MatchString(toolName) {
  81. matched = append(matched, h)
  82. }
  83. }
  84. return matched
  85. }
  86. // runOne executes a single hook command and returns its result.
  87. func (r *Runner) runOne(parentCtx context.Context, hook config.HookConfig, envVars []string, payload []byte) HookResult {
  88. timeout := hook.TimeoutDuration()
  89. ctx, cancel := context.WithTimeout(parentCtx, timeout)
  90. defer cancel()
  91. cmd := exec.CommandContext(ctx, "sh", "-c", hook.Command)
  92. cmd.Env = envVars
  93. cmd.Dir = r.cwd
  94. cmd.Stdin = bytes.NewReader(payload)
  95. var stdout, stderr bytes.Buffer
  96. cmd.Stdout = &stdout
  97. cmd.Stderr = &stderr
  98. err := cmd.Run()
  99. if ctx.Err() != nil {
  100. // Distinguish timeout from parent cancellation.
  101. if parentCtx.Err() != nil {
  102. slog.Debug("Hook cancelled by parent context", "command", hook.Command)
  103. } else {
  104. slog.Warn("Hook timed out", "command", hook.Command, "timeout", timeout)
  105. }
  106. return HookResult{Decision: DecisionNone}
  107. }
  108. if err != nil {
  109. exitCode := cmd.ProcessState.ExitCode()
  110. switch exitCode {
  111. case 2:
  112. // Exit code 2 = block. Stderr is the reason.
  113. reason := strings.TrimSpace(stderr.String())
  114. if reason == "" {
  115. reason = "blocked by hook"
  116. }
  117. return HookResult{
  118. Decision: DecisionDeny,
  119. Reason: reason,
  120. }
  121. default:
  122. // Other non-zero exits are non-blocking errors.
  123. slog.Warn("Hook failed with non-blocking error",
  124. "command", hook.Command,
  125. "exit_code", exitCode,
  126. "stderr", strings.TrimSpace(stderr.String()),
  127. )
  128. return HookResult{Decision: DecisionNone}
  129. }
  130. }
  131. // Exit code 0 — parse stdout JSON.
  132. result := parseStdout(stdout.String())
  133. slog.Debug("Hook executed",
  134. "command", hook.Command,
  135. "decision", result.Decision.String(),
  136. )
  137. return result
  138. }