bash.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. package tools
  2. import (
  3. "bytes"
  4. "context"
  5. _ "embed"
  6. "fmt"
  7. "html/template"
  8. "os"
  9. "path/filepath"
  10. "runtime"
  11. "strings"
  12. "time"
  13. "charm.land/fantasy"
  14. "github.com/charmbracelet/crush/internal/config"
  15. "github.com/charmbracelet/crush/internal/permission"
  16. "github.com/charmbracelet/crush/internal/shell"
  17. )
// BashParams is the argument set the model supplies when it calls the bash
// tool (it is the params type of the handler built by NewBashTool).
// NOTE(review): the `description` struct tags are presumably turned into the
// tool's parameter schema by fantasy — confirm.
type BashParams struct {
	// Command is the shell command line to run. An empty command is rejected.
	Command string `json:"command" description:"The command to execute"`
	// Description is an optional summary. It is copied into the permission
	// request and into the response metadata.
	Description string `json:"description,omitempty" description:"A brief description of what the command does"`
	// Timeout is in milliseconds. NewBashTool clamps values above MaxTimeout
	// to MaxTimeout and replaces non-positive values with DefaultTimeout.
	Timeout int `json:"timeout,omitempty" description:"Optional timeout in milliseconds (max 600000)"`
}
// BashPermissionsParams is the Params payload of the permission request that
// NewBashTool raises before running a command not classified as safe/read-only.
// NOTE(review): NewBashTool currently fills only Command and Description;
// Timeout is left at zero — confirm whether that is intended.
type BashPermissionsParams struct {
	Command     string `json:"command"`
	Description string `json:"description"`
	Timeout     int    `json:"timeout"`
}
// BashResponseMetadata is attached to the tool response whenever a command has
// run to completion (including failed and interrupted runs).
type BashResponseMetadata struct {
	// StartTime and EndTime are Unix timestamps in milliseconds.
	StartTime int64 `json:"start_time"`
	EndTime   int64 `json:"end_time"`
	// Output is the combined, truncated output, without the trailing <cwd> tag.
	Output string `json:"output"`
	// Description echoes BashParams.Description.
	Description string `json:"description"`
	// WorkingDirectory is the shell's working directory read after execution.
	WorkingDirectory string `json:"working_directory"`
}
  35. const (
  36. BashToolName = "bash"
  37. DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds
  38. MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds
  39. MaxOutputLength = 30000
  40. BashNoOutput = "no output"
  41. )
// bashDescriptionTmpl holds the raw text of bash.tpl, embedded at build time.
//
//go:embed bash.tpl
var bashDescriptionTmpl []byte

// bashDescriptionTpl is the parsed tool-description template executed by
// bashDescription. template.Must panics during package initialization if
// bash.tpl does not parse.
// NOTE(review): this is html/template, which HTML-escapes interpolated values
// (e.g. '<', '&', quotes). The rendered text is a plain-text tool description,
// so text/template may be what was intended — confirm.
var bashDescriptionTpl = template.Must(
	template.New("bashDescription").
		Parse(string(bashDescriptionTmpl)),
)
// bashDescriptionData is the value bash.tpl is executed with by bashDescription.
type bashDescriptionData struct {
	// BannedCommands is bannedCommands joined with ", ".
	BannedCommands string
	// MaxOutputLength is the byte threshold above which command output is truncated.
	MaxOutputLength int
	// Attribution is the attribution config passed to NewBashTool.
	Attribution config.Attribution
}
// bannedCommands lists command names that are refused outright: blockFuncs
// passes it to shell.CommandsBlocker, and bashDescription joins it with ", "
// into the tool description, so element order is visible to the model.
var bannedCommands = []string{
	// Shell aliasing
	// NOTE(review): presumably banned so an alias cannot be used to rename a
	// blocked command — confirm.
	"alias",
	// Network/Download tools
	"aria2c",
	"axel",
	"chrome",
	"curl",
	"curlie",
	"firefox",
	"http-prompt",
	"httpie",
	"links",
	"lynx",
	"nc",
	"safari",
	"scp",
	"ssh",
	"telnet",
	"w3m",
	"wget",
	"xh",
	// System administration
	"doas",
	"su",
	"sudo",
	// Package managers
	"apk",
	"apt",
	"apt-cache",
	"apt-get",
	"dnf",
	"dpkg",
	"emerge",
	"home-manager",
	"makepkg",
	"opkg",
	"pacman",
	"paru",
	"pkg",
	"pkg_add",
	"pkg_delete",
	"portage",
	"rpm",
	"yay",
	"yum",
	"zypper",
	// System modification
	"at",
	"batch",
	"chkconfig",
	"crontab",
	"fdisk",
	"mkfs",
	"mount",
	"parted",
	"service",
	"systemctl",
	"umount",
	// Network configuration
	"firewall-cmd",
	"ifconfig",
	"ip",
	"iptables",
	"netstat",
	"pfctl",
	"route",
	"ufw",
}
  121. func bashDescription(attribution *config.Attribution) string {
  122. bannedCommandsStr := strings.Join(bannedCommands, ", ")
  123. var out bytes.Buffer
  124. if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
  125. BannedCommands: bannedCommandsStr,
  126. MaxOutputLength: MaxOutputLength,
  127. Attribution: *attribution,
  128. }); err != nil {
  129. // this should never happen.
  130. panic("failed to execute bash description template: " + err.Error())
  131. }
  132. return out.String()
  133. }
  134. func blockFuncs() []shell.BlockFunc {
  135. return []shell.BlockFunc{
  136. shell.CommandsBlocker(bannedCommands),
  137. // System package managers
  138. shell.ArgumentsBlocker("apk", []string{"add"}, nil),
  139. shell.ArgumentsBlocker("apt", []string{"install"}, nil),
  140. shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
  141. shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
  142. shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
  143. shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
  144. shell.ArgumentsBlocker("yum", []string{"install"}, nil),
  145. shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
  146. // Language-specific package managers
  147. shell.ArgumentsBlocker("brew", []string{"install"}, nil),
  148. shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
  149. shell.ArgumentsBlocker("gem", []string{"install"}, nil),
  150. shell.ArgumentsBlocker("go", []string{"install"}, nil),
  151. shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
  152. shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
  153. shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
  154. shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
  155. shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
  156. shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
  157. shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
  158. // `go test -exec` can run arbitrary commands
  159. shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
  160. }
  161. }
  162. func NewBashTool(permissions permission.Service, workingDir string, attribution *config.Attribution) fantasy.AgentTool {
  163. // Set up command blocking on the persistent shell
  164. persistentShell := shell.GetPersistentShell(workingDir)
  165. persistentShell.SetBlockFuncs(blockFuncs())
  166. return fantasy.NewAgentTool(
  167. BashToolName,
  168. string(bashDescription(attribution)),
  169. func(ctx context.Context, params BashParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  170. if params.Timeout > MaxTimeout {
  171. params.Timeout = MaxTimeout
  172. } else if params.Timeout <= 0 {
  173. params.Timeout = DefaultTimeout
  174. }
  175. if params.Command == "" {
  176. return fantasy.NewTextErrorResponse("missing command"), nil
  177. }
  178. isSafeReadOnly := false
  179. cmdLower := strings.ToLower(params.Command)
  180. for _, safe := range safeCommands {
  181. if strings.HasPrefix(cmdLower, safe) {
  182. if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
  183. isSafeReadOnly = true
  184. break
  185. }
  186. }
  187. }
  188. sessionID := GetSessionFromContext(ctx)
  189. if sessionID == "" {
  190. return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for executing shell command")
  191. }
  192. if !isSafeReadOnly {
  193. shell := shell.GetPersistentShell(workingDir)
  194. p := permissions.Request(
  195. permission.CreatePermissionRequest{
  196. SessionID: sessionID,
  197. Path: shell.GetWorkingDir(),
  198. ToolCallID: call.ID,
  199. ToolName: BashToolName,
  200. Action: "execute",
  201. Description: fmt.Sprintf("Execute command: %s", params.Command),
  202. Params: BashPermissionsParams{
  203. Command: params.Command,
  204. Description: params.Description,
  205. },
  206. },
  207. )
  208. if !p {
  209. return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
  210. }
  211. }
  212. startTime := time.Now()
  213. if params.Timeout > 0 {
  214. var cancel context.CancelFunc
  215. ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
  216. defer cancel()
  217. }
  218. persistentShell := shell.GetPersistentShell(workingDir)
  219. stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
  220. // Get the current working directory after command execution
  221. currentWorkingDir := persistentShell.GetWorkingDir()
  222. interrupted := shell.IsInterrupt(err)
  223. exitCode := shell.ExitCode(err)
  224. if exitCode == 0 && !interrupted && err != nil {
  225. return fantasy.ToolResponse{}, fmt.Errorf("error executing command: %w", err)
  226. }
  227. stdout = truncateOutput(stdout)
  228. stderr = truncateOutput(stderr)
  229. errorMessage := stderr
  230. if errorMessage == "" && err != nil {
  231. errorMessage = err.Error()
  232. }
  233. if interrupted {
  234. if errorMessage != "" {
  235. errorMessage += "\n"
  236. }
  237. errorMessage += "Command was aborted before completion"
  238. } else if exitCode != 0 {
  239. if errorMessage != "" {
  240. errorMessage += "\n"
  241. }
  242. errorMessage += fmt.Sprintf("Exit code %d", exitCode)
  243. }
  244. hasBothOutputs := stdout != "" && stderr != ""
  245. if hasBothOutputs {
  246. stdout += "\n"
  247. }
  248. if errorMessage != "" {
  249. stdout += "\n" + errorMessage
  250. }
  251. metadata := BashResponseMetadata{
  252. StartTime: startTime.UnixMilli(),
  253. EndTime: time.Now().UnixMilli(),
  254. Output: stdout,
  255. Description: params.Description,
  256. WorkingDirectory: currentWorkingDir,
  257. }
  258. if stdout == "" {
  259. return fantasy.WithResponseMetadata(fantasy.NewTextResponse(BashNoOutput), metadata), nil
  260. }
  261. stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", normalizeWorkingDir(currentWorkingDir))
  262. return fantasy.WithResponseMetadata(fantasy.NewTextResponse(stdout), metadata), nil
  263. })
  264. }
  265. func truncateOutput(content string) string {
  266. if len(content) <= MaxOutputLength {
  267. return content
  268. }
  269. halfLength := MaxOutputLength / 2
  270. start := content[:halfLength]
  271. end := content[len(content)-halfLength:]
  272. truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
  273. return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
  274. }
  275. func countLines(s string) int {
  276. if s == "" {
  277. return 0
  278. }
  279. return len(strings.Split(s, "\n"))
  280. }
  281. func normalizeWorkingDir(path string) string {
  282. if runtime.GOOS == "windows" {
  283. cwd, err := os.Getwd()
  284. if err != nil {
  285. cwd = "C:"
  286. }
  287. path = strings.ReplaceAll(path, filepath.VolumeName(cwd), "")
  288. }
  289. return filepath.ToSlash(path)
  290. }