| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- package agent
- import (
- "context"
- "encoding/json"
- "fmt"
- "log/slog"
- "charm.land/fantasy"
- "github.com/charmbracelet/crush/internal/agent/tools"
- "github.com/charmbracelet/crush/internal/hooks"
- "github.com/tidwall/sjson"
- )
- // hookedTool wraps a fantasy.AgentTool to run PreToolUse hooks before
- // delegating to the inner tool.
- type hookedTool struct {
- inner fantasy.AgentTool
- runner *hooks.Runner
- }
- func newHookedTool(inner fantasy.AgentTool, runner *hooks.Runner) *hookedTool {
- return &hookedTool{inner: inner, runner: runner}
- }
- func (h *hookedTool) Info() fantasy.ToolInfo {
- return h.inner.Info()
- }
- func (h *hookedTool) ProviderOptions() fantasy.ProviderOptions {
- return h.inner.ProviderOptions()
- }
- func (h *hookedTool) SetProviderOptions(opts fantasy.ProviderOptions) {
- h.inner.SetProviderOptions(opts)
- }
- func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
- sessionID := tools.GetSessionFromContext(ctx)
- result, err := h.runner.Run(ctx, hooks.EventPreToolUse, sessionID, call.Name, call.Input)
- if err != nil {
- slog.Warn("Hook execution error, proceeding with tool call",
- "tool", call.Name, "error", err)
- }
- if result.Decision == hooks.DecisionDeny {
- reason := fmt.Sprintf(
- "This tool call was blocked by a hook and must not be retried. Reason: %s",
- result.Reason,
- )
- resp := fantasy.NewTextErrorResponse(reason)
- resp.Metadata = hookMetadataJSON(result)
- return resp, nil
- }
- if result.UpdatedInput != "" {
- call.Input = result.UpdatedInput
- }
- resp, err := h.inner.Run(ctx, call)
- if err != nil {
- return resp, err
- }
- if result.Context != "" {
- if resp.Content != "" {
- resp.Content += "\n"
- }
- resp.Content += result.Context
- }
- resp.Metadata = mergeHookMetadata(resp.Metadata, result)
- return resp, nil
- }
- // buildHookMetadata creates a HookMetadata from an AggregateResult.
- func buildHookMetadata(result hooks.AggregateResult) hooks.HookMetadata {
- return hooks.HookMetadata{
- HookCount: result.HookCount,
- Decision: result.Decision.String(),
- Reason: result.Reason,
- InputRewrite: result.UpdatedInput != "",
- Hooks: result.Hooks,
- }
- }
- // hookMetadataJSON builds a JSON string containing only the hook metadata.
- func hookMetadataJSON(result hooks.AggregateResult) string {
- meta := buildHookMetadata(result)
- data, err := json.Marshal(meta)
- if err != nil {
- return ""
- }
- return `{"hook":` + string(data) + `}`
- }
- // mergeHookMetadata injects hook metadata into existing tool metadata.
- func mergeHookMetadata(existing string, result hooks.AggregateResult) string {
- if result.HookCount == 0 {
- return existing
- }
- meta := buildHookMetadata(result)
- data, err := json.Marshal(meta)
- if err != nil {
- return existing
- }
- if existing == "" {
- existing = "{}"
- }
- merged, err := sjson.SetRaw(existing, "hook", string(data))
- if err != nil {
- return existing
- }
- return merged
- }
|