| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395 |
- package tools
import (
	"bytes"
	"context"
	_ "embed"
	"encoding/json"
	"fmt"
	"html/template"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/shell"
)
// JSON payload types exchanged by the bash tool.
type (
	// BashParams is the JSON input accepted by the bash tool's Run method.
	// Timeout is expressed in milliseconds.
	BashParams struct {
		Command string `json:"command"`
		Timeout int    `json:"timeout"`
	}

	// BashPermissionsParams is attached to the permission request raised
	// before a non-safe command is executed.
	BashPermissionsParams struct {
		Command string `json:"command"`
		Timeout int    `json:"timeout"`
	}

	// BashResponseMetadata accompanies each bash tool response. StartTime and
	// EndTime are Unix timestamps in milliseconds, Output is the final
	// response text, and WorkingDirectory is the shell's directory after the
	// command ran.
	BashResponseMetadata struct {
		StartTime        int64  `json:"start_time"`
		EndTime          int64  `json:"end_time"`
		Output           string `json:"output"`
		WorkingDirectory string `json:"working_directory"`
	}
)
// bashTool implements the bash tool; instances are built by NewBashTool.
type bashTool struct {
	// permissions is asked for approval in Run before a command that is not
	// recognized as safe/read-only is executed.
	permissions permission.Service
	// workingDir selects the persistent shell via shell.GetPersistentShell.
	workingDir string
	// attribution controls the commit/PR attribution text in the tool
	// description; nil is treated as "all attribution enabled".
	attribution *config.Attribution
}
// Tool identity, timeout bounds (milliseconds) and output limits for the bash
// tool. The values are deliberately untyped constants.
const (
	// BashToolName is the tool name reported by Name and Info.
	BashToolName = "bash"

	// DefaultTimeout replaces a missing or non-positive timeout: 1 minute.
	DefaultTimeout = 60_000
	// MaxTimeout is the upper clamp for caller-supplied timeouts: 10 minutes.
	MaxTimeout = 600_000

	// MaxOutputLength is the byte length above which stdout/stderr are
	// truncated by truncateOutput.
	MaxOutputLength = 30_000

	// BashNoOutput is the response text used when a command produced no
	// output and no error text.
	BashNoOutput = "no output"
)
// bashDescription holds the raw contents of bash.md, embedded at build time;
// it is the source text parsed into bashDescriptionTpl.
//
//go:embed bash.md
var bashDescription []byte

// bashDescriptionTpl is bash.md parsed once at package initialization
// (template.Must panics during init if parsing fails). It is executed with a
// bashDescriptionData value to render the bash tool description.
//
// NOTE(review): this is html/template, which contextually HTML-escapes every
// interpolated value. The attribution strings passed in contain `<example>`,
// double and single quotes, and a `<...>` e-mail trailer, so wherever they
// land in HTML text context they come out as entities (&lt; &#34; &#39; ...),
// which is wrong for a plain-text tool description. The calls used here are
// API-identical in text/template — TODO switch the import.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescription)),
)
// bashDescriptionData is the value bashDescriptionTpl is executed with when
// rendering the bash tool description.
type bashDescriptionData struct {
	BannedCommands  string // bannedCommands joined with ", "
	MaxOutputLength int    // output truncation threshold in bytes

	// Commit-step instruction, example commit command and PR attribution
	// line, chosen from the tool's attribution settings.
	AttributionStep, AttributionExample, PRAttribution string
}
// bannedCommands lists executables the bash tool blocks outright. The slice
// is passed to shell.CommandsBlocker in blockFuncs and also rendered, joined
// with ", " in this exact order, into the tool description.
var bannedCommands = []string{
	// Shell builtin that can redefine other commands
	"alias",

	// Network/download tools and browsers
	"aria2c", "axel", "chrome", "curl", "curlie", "firefox", "http-prompt",
	"httpie", "links", "lynx", "nc", "safari", "scp", "ssh", "telnet", "w3m",
	"wget", "xh",

	// Privilege escalation
	"doas", "su", "sudo",

	// Package managers
	"apk", "apt", "apt-cache", "apt-get", "dnf", "dpkg", "emerge",
	"home-manager", "makepkg", "opkg", "pacman", "paru", "pkg", "pkg_add",
	"pkg_delete", "portage", "rpm", "yay", "yum", "zypper",

	// System modification: scheduling, services, disks and mounts
	"at", "batch", "chkconfig", "crontab", "fdisk", "mkfs", "mount", "parted",
	"service", "systemctl", "umount",

	// Network configuration
	"firewall-cmd", "ifconfig", "ip", "iptables", "netstat", "pfctl", "route",
	"ufw",
}
- func (b *bashTool) bashDescription() string {
- bannedCommandsStr := strings.Join(bannedCommands, ", ")
- // Build attribution text based on settings
- var attributionStep, attributionExample, prAttribution string
- // Default to true if attribution is nil (backward compatibility)
- generatedWith := b.attribution == nil || b.attribution.GeneratedWith
- coAuthoredBy := b.attribution == nil || b.attribution.CoAuthoredBy
- // Build PR attribution
- if generatedWith {
- prAttribution = "💘 Generated with Crush"
- }
- if generatedWith || coAuthoredBy {
- var attributionParts []string
- if generatedWith {
- attributionParts = append(attributionParts, "💘 Generated with Crush")
- }
- if coAuthoredBy {
- attributionParts = append(attributionParts, "Co-Authored-By: Crush <[email protected]>")
- }
- if len(attributionParts) > 0 {
- attributionStep = fmt.Sprintf("4. Create the commit with a message ending with:\n%s", strings.Join(attributionParts, "\n"))
- attributionText := strings.Join(attributionParts, "\n ")
- attributionExample = fmt.Sprintf(`<example>
- git commit -m "$(cat <<'EOF'
- Commit message here.
- %s
- EOF
- )"</example>`, attributionText)
- }
- }
- if attributionStep == "" {
- attributionStep = "4. Create the commit with your commit message."
- attributionExample = `<example>
- git commit -m "$(cat <<'EOF'
- Commit message here.
- EOF
- )"</example>`
- }
- var out bytes.Buffer
- if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
- BannedCommands: bannedCommandsStr,
- MaxOutputLength: MaxOutputLength,
- AttributionStep: attributionStep,
- AttributionExample: attributionExample,
- PRAttribution: prAttribution,
- }); err != nil {
- // this should never happen.
- panic("failed to execute bash description template: " + err.Error())
- }
- return out.String()
- }
- func blockFuncs() []shell.BlockFunc {
- return []shell.BlockFunc{
- shell.CommandsBlocker(bannedCommands),
- // System package managers
- shell.ArgumentsBlocker("apk", []string{"add"}, nil),
- shell.ArgumentsBlocker("apt", []string{"install"}, nil),
- shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
- shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
- shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
- shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
- shell.ArgumentsBlocker("yum", []string{"install"}, nil),
- shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
- // Language-specific package managers
- shell.ArgumentsBlocker("brew", []string{"install"}, nil),
- shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
- shell.ArgumentsBlocker("gem", []string{"install"}, nil),
- shell.ArgumentsBlocker("go", []string{"install"}, nil),
- shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
- shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
- shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
- shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
- shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
- shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
- shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
- // `go test -exec` can run arbitrary commands
- shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
- }
- }
- func NewBashTool(permission permission.Service, workingDir string, attribution *config.Attribution) BaseTool {
- // Set up command blocking on the persistent shell
- persistentShell := shell.GetPersistentShell(workingDir)
- persistentShell.SetBlockFuncs(blockFuncs())
- return &bashTool{
- permissions: permission,
- workingDir: workingDir,
- attribution: attribution,
- }
- }
- func (b *bashTool) Name() string {
- return BashToolName
- }
- func (b *bashTool) Info() ToolInfo {
- return ToolInfo{
- Name: BashToolName,
- Description: b.bashDescription(),
- Parameters: map[string]any{
- "command": map[string]any{
- "type": "string",
- "description": "The command to execute",
- },
- "timeout": map[string]any{
- "type": "number",
- "description": "Optional timeout in milliseconds (max 600000)",
- },
- },
- Required: []string{"command"},
- }
- }
- func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
- var params BashParams
- if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
- return NewTextErrorResponse("invalid parameters"), nil
- }
- if params.Timeout > MaxTimeout {
- params.Timeout = MaxTimeout
- } else if params.Timeout <= 0 {
- params.Timeout = DefaultTimeout
- }
- if params.Command == "" {
- return NewTextErrorResponse("missing command"), nil
- }
- isSafeReadOnly := false
- cmdLower := strings.ToLower(params.Command)
- for _, safe := range safeCommands {
- if strings.HasPrefix(cmdLower, safe) {
- if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
- isSafeReadOnly = true
- break
- }
- }
- }
- sessionID, messageID := GetContextValues(ctx)
- if sessionID == "" || messageID == "" {
- return ToolResponse{}, fmt.Errorf("session ID and message ID are required for executing shell command")
- }
- if !isSafeReadOnly {
- shell := shell.GetPersistentShell(b.workingDir)
- p := b.permissions.Request(
- permission.CreatePermissionRequest{
- SessionID: sessionID,
- Path: shell.GetWorkingDir(),
- ToolCallID: call.ID,
- ToolName: BashToolName,
- Action: "execute",
- Description: fmt.Sprintf("Execute command: %s", params.Command),
- Params: BashPermissionsParams{
- Command: params.Command,
- },
- },
- )
- if !p {
- return ToolResponse{}, permission.ErrorPermissionDenied
- }
- }
- startTime := time.Now()
- if params.Timeout > 0 {
- var cancel context.CancelFunc
- ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
- defer cancel()
- }
- persistentShell := shell.GetPersistentShell(b.workingDir)
- stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
- // Get the current working directory after command execution
- currentWorkingDir := persistentShell.GetWorkingDir()
- interrupted := shell.IsInterrupt(err)
- exitCode := shell.ExitCode(err)
- if exitCode == 0 && !interrupted && err != nil {
- return ToolResponse{}, fmt.Errorf("error executing command: %w", err)
- }
- stdout = truncateOutput(stdout)
- stderr = truncateOutput(stderr)
- errorMessage := stderr
- if errorMessage == "" && err != nil {
- errorMessage = err.Error()
- }
- if interrupted {
- if errorMessage != "" {
- errorMessage += "\n"
- }
- errorMessage += "Command was aborted before completion"
- } else if exitCode != 0 {
- if errorMessage != "" {
- errorMessage += "\n"
- }
- errorMessage += fmt.Sprintf("Exit code %d", exitCode)
- }
- hasBothOutputs := stdout != "" && stderr != ""
- if hasBothOutputs {
- stdout += "\n"
- }
- if errorMessage != "" {
- stdout += "\n" + errorMessage
- }
- metadata := BashResponseMetadata{
- StartTime: startTime.UnixMilli(),
- EndTime: time.Now().UnixMilli(),
- Output: stdout,
- WorkingDirectory: currentWorkingDir,
- }
- if stdout == "" {
- return WithResponseMetadata(NewTextResponse(BashNoOutput), metadata), nil
- }
- stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", currentWorkingDir)
- return WithResponseMetadata(NewTextResponse(stdout), metadata), nil
- }
- func truncateOutput(content string) string {
- if len(content) <= MaxOutputLength {
- return content
- }
- halfLength := MaxOutputLength / 2
- start := content[:halfLength]
- end := content[len(content)-halfLength:]
- truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
- return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
- }
- func countLines(s string) int {
- if s == "" {
- return 0
- }
- return len(strings.Split(s, "\n"))
- }
|