hooked_tool.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. package agent
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "log/slog"
  7. "charm.land/fantasy"
  8. "github.com/charmbracelet/crush/internal/agent/tools"
  9. "github.com/charmbracelet/crush/internal/hooks"
  10. "github.com/tidwall/sjson"
  11. )
  12. // hookedTool wraps a fantasy.AgentTool to run PreToolUse hooks before
  13. // delegating to the inner tool.
  14. type hookedTool struct {
  15. inner fantasy.AgentTool
  16. runner *hooks.Runner
  17. }
  18. func newHookedTool(inner fantasy.AgentTool, runner *hooks.Runner) *hookedTool {
  19. return &hookedTool{inner: inner, runner: runner}
  20. }
  21. func (h *hookedTool) Info() fantasy.ToolInfo {
  22. return h.inner.Info()
  23. }
  24. func (h *hookedTool) ProviderOptions() fantasy.ProviderOptions {
  25. return h.inner.ProviderOptions()
  26. }
  27. func (h *hookedTool) SetProviderOptions(opts fantasy.ProviderOptions) {
  28. h.inner.SetProviderOptions(opts)
  29. }
  30. func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  31. sessionID := tools.GetSessionFromContext(ctx)
  32. result, err := h.runner.Run(ctx, hooks.EventPreToolUse, sessionID, call.Name, call.Input)
  33. if err != nil {
  34. slog.Warn("Hook execution error, proceeding with tool call",
  35. "tool", call.Name, "error", err)
  36. }
  37. if result.Decision == hooks.DecisionDeny {
  38. reason := fmt.Sprintf(
  39. "This tool call was blocked by a hook and must not be retried. Reason: %s",
  40. result.Reason,
  41. )
  42. resp := fantasy.NewTextErrorResponse(reason)
  43. resp.Metadata = hookMetadataJSON(result)
  44. return resp, nil
  45. }
  46. if result.UpdatedInput != "" {
  47. call.Input = result.UpdatedInput
  48. }
  49. resp, err := h.inner.Run(ctx, call)
  50. if err != nil {
  51. return resp, err
  52. }
  53. if result.Context != "" {
  54. if resp.Content != "" {
  55. resp.Content += "\n"
  56. }
  57. resp.Content += result.Context
  58. }
  59. resp.Metadata = mergeHookMetadata(resp.Metadata, result)
  60. return resp, nil
  61. }
  62. // buildHookMetadata creates a HookMetadata from an AggregateResult.
  63. func buildHookMetadata(result hooks.AggregateResult) hooks.HookMetadata {
  64. return hooks.HookMetadata{
  65. HookCount: result.HookCount,
  66. Decision: result.Decision.String(),
  67. Reason: result.Reason,
  68. InputRewrite: result.UpdatedInput != "",
  69. Hooks: result.Hooks,
  70. }
  71. }
  72. // hookMetadataJSON builds a JSON string containing only the hook metadata.
  73. func hookMetadataJSON(result hooks.AggregateResult) string {
  74. meta := buildHookMetadata(result)
  75. data, err := json.Marshal(meta)
  76. if err != nil {
  77. return ""
  78. }
  79. return `{"hook":` + string(data) + `}`
  80. }
  81. // mergeHookMetadata injects hook metadata into existing tool metadata.
  82. func mergeHookMetadata(existing string, result hooks.AggregateResult) string {
  83. if result.HookCount == 0 {
  84. return existing
  85. }
  86. meta := buildHookMetadata(result)
  87. data, err := json.Marshal(meta)
  88. if err != nil {
  89. return existing
  90. }
  91. if existing == "" {
  92. existing = "{}"
  93. }
  94. merged, err := sjson.SetRaw(existing, "hook", string(data))
  95. if err != nil {
  96. return existing
  97. }
  98. return merged
  99. }