2
0

bash.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. package tools
import (
	"bytes"
	"context"
	_ "embed"
	"encoding/json"
	"fmt"
	"html/template"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/shell"
)
  15. type BashParams struct {
  16. Command string `json:"command"`
  17. Timeout int `json:"timeout"`
  18. }
  19. type BashPermissionsParams struct {
  20. Command string `json:"command"`
  21. Timeout int `json:"timeout"`
  22. }
  23. type BashResponseMetadata struct {
  24. StartTime int64 `json:"start_time"`
  25. EndTime int64 `json:"end_time"`
  26. Output string `json:"output"`
  27. WorkingDirectory string `json:"working_directory"`
  28. }
  29. type bashTool struct {
  30. permissions permission.Service
  31. workingDir string
  32. attribution *config.Attribution
  33. }
  34. const (
  35. BashToolName = "bash"
  36. DefaultTimeout = 1 * 60 * 1000 // 1 minutes in milliseconds
  37. MaxTimeout = 10 * 60 * 1000 // 10 minutes in milliseconds
  38. MaxOutputLength = 30000
  39. BashNoOutput = "no output"
  40. )
  41. //go:embed bash.md
  42. var bashDescription []byte
  43. var bashDescriptionTpl = template.Must(
  44. template.New("bashDescription").
  45. Parse(string(bashDescription)),
  46. )
  47. type bashDescriptionData struct {
  48. BannedCommands string
  49. MaxOutputLength int
  50. AttributionStep string
  51. AttributionExample string
  52. PRAttribution string
  53. }
  54. var bannedCommands = []string{
  55. // Network/Download tools
  56. "alias",
  57. "aria2c",
  58. "axel",
  59. "chrome",
  60. "curl",
  61. "curlie",
  62. "firefox",
  63. "http-prompt",
  64. "httpie",
  65. "links",
  66. "lynx",
  67. "nc",
  68. "safari",
  69. "scp",
  70. "ssh",
  71. "telnet",
  72. "w3m",
  73. "wget",
  74. "xh",
  75. // System administration
  76. "doas",
  77. "su",
  78. "sudo",
  79. // Package managers
  80. "apk",
  81. "apt",
  82. "apt-cache",
  83. "apt-get",
  84. "dnf",
  85. "dpkg",
  86. "emerge",
  87. "home-manager",
  88. "makepkg",
  89. "opkg",
  90. "pacman",
  91. "paru",
  92. "pkg",
  93. "pkg_add",
  94. "pkg_delete",
  95. "portage",
  96. "rpm",
  97. "yay",
  98. "yum",
  99. "zypper",
  100. // System modification
  101. "at",
  102. "batch",
  103. "chkconfig",
  104. "crontab",
  105. "fdisk",
  106. "mkfs",
  107. "mount",
  108. "parted",
  109. "service",
  110. "systemctl",
  111. "umount",
  112. // Network configuration
  113. "firewall-cmd",
  114. "ifconfig",
  115. "ip",
  116. "iptables",
  117. "netstat",
  118. "pfctl",
  119. "route",
  120. "ufw",
  121. }
  122. func (b *bashTool) bashDescription() string {
  123. bannedCommandsStr := strings.Join(bannedCommands, ", ")
  124. // Build attribution text based on settings
  125. var attributionStep, attributionExample, prAttribution string
  126. // Default to true if attribution is nil (backward compatibility)
  127. generatedWith := b.attribution == nil || b.attribution.GeneratedWith
  128. coAuthoredBy := b.attribution == nil || b.attribution.CoAuthoredBy
  129. // Build PR attribution
  130. if generatedWith {
  131. prAttribution = "💘 Generated with Crush"
  132. }
  133. if generatedWith || coAuthoredBy {
  134. var attributionParts []string
  135. if generatedWith {
  136. attributionParts = append(attributionParts, "💘 Generated with Crush")
  137. }
  138. if coAuthoredBy {
  139. attributionParts = append(attributionParts, "Co-Authored-By: Crush <[email protected]>")
  140. }
  141. if len(attributionParts) > 0 {
  142. attributionStep = fmt.Sprintf("4. Create the commit with a message ending with:\n%s", strings.Join(attributionParts, "\n"))
  143. attributionText := strings.Join(attributionParts, "\n ")
  144. attributionExample = fmt.Sprintf(`<example>
  145. git commit -m "$(cat <<'EOF'
  146. Commit message here.
  147. %s
  148. EOF
  149. )"</example>`, attributionText)
  150. }
  151. }
  152. if attributionStep == "" {
  153. attributionStep = "4. Create the commit with your commit message."
  154. attributionExample = `<example>
  155. git commit -m "$(cat <<'EOF'
  156. Commit message here.
  157. EOF
  158. )"</example>`
  159. }
  160. var out bytes.Buffer
  161. if err := bashDescriptionTpl.Execute(&out, bashDescriptionData{
  162. BannedCommands: bannedCommandsStr,
  163. MaxOutputLength: MaxOutputLength,
  164. AttributionStep: attributionStep,
  165. AttributionExample: attributionExample,
  166. PRAttribution: prAttribution,
  167. }); err != nil {
  168. // this should never happen.
  169. panic("failed to execute bash description template: " + err.Error())
  170. }
  171. return out.String()
  172. }
  173. func blockFuncs() []shell.BlockFunc {
  174. return []shell.BlockFunc{
  175. shell.CommandsBlocker(bannedCommands),
  176. // System package managers
  177. shell.ArgumentsBlocker("apk", []string{"add"}, nil),
  178. shell.ArgumentsBlocker("apt", []string{"install"}, nil),
  179. shell.ArgumentsBlocker("apt-get", []string{"install"}, nil),
  180. shell.ArgumentsBlocker("dnf", []string{"install"}, nil),
  181. shell.ArgumentsBlocker("pacman", nil, []string{"-S"}),
  182. shell.ArgumentsBlocker("pkg", []string{"install"}, nil),
  183. shell.ArgumentsBlocker("yum", []string{"install"}, nil),
  184. shell.ArgumentsBlocker("zypper", []string{"install"}, nil),
  185. // Language-specific package managers
  186. shell.ArgumentsBlocker("brew", []string{"install"}, nil),
  187. shell.ArgumentsBlocker("cargo", []string{"install"}, nil),
  188. shell.ArgumentsBlocker("gem", []string{"install"}, nil),
  189. shell.ArgumentsBlocker("go", []string{"install"}, nil),
  190. shell.ArgumentsBlocker("npm", []string{"install"}, []string{"--global"}),
  191. shell.ArgumentsBlocker("npm", []string{"install"}, []string{"-g"}),
  192. shell.ArgumentsBlocker("pip", []string{"install"}, []string{"--user"}),
  193. shell.ArgumentsBlocker("pip3", []string{"install"}, []string{"--user"}),
  194. shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"--global"}),
  195. shell.ArgumentsBlocker("pnpm", []string{"add"}, []string{"-g"}),
  196. shell.ArgumentsBlocker("yarn", []string{"global", "add"}, nil),
  197. // `go test -exec` can run arbitrary commands
  198. shell.ArgumentsBlocker("go", []string{"test"}, []string{"-exec"}),
  199. }
  200. }
  201. func NewBashTool(permission permission.Service, workingDir string, attribution *config.Attribution) BaseTool {
  202. // Set up command blocking on the persistent shell
  203. persistentShell := shell.GetPersistentShell(workingDir)
  204. persistentShell.SetBlockFuncs(blockFuncs())
  205. return &bashTool{
  206. permissions: permission,
  207. workingDir: workingDir,
  208. attribution: attribution,
  209. }
  210. }
  211. func (b *bashTool) Name() string {
  212. return BashToolName
  213. }
  214. func (b *bashTool) Info() ToolInfo {
  215. return ToolInfo{
  216. Name: BashToolName,
  217. Description: b.bashDescription(),
  218. Parameters: map[string]any{
  219. "command": map[string]any{
  220. "type": "string",
  221. "description": "The command to execute",
  222. },
  223. "timeout": map[string]any{
  224. "type": "number",
  225. "description": "Optional timeout in milliseconds (max 600000)",
  226. },
  227. },
  228. Required: []string{"command"},
  229. }
  230. }
  231. func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
  232. var params BashParams
  233. if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
  234. return NewTextErrorResponse("invalid parameters"), nil
  235. }
  236. if params.Timeout > MaxTimeout {
  237. params.Timeout = MaxTimeout
  238. } else if params.Timeout <= 0 {
  239. params.Timeout = DefaultTimeout
  240. }
  241. if params.Command == "" {
  242. return NewTextErrorResponse("missing command"), nil
  243. }
  244. isSafeReadOnly := false
  245. cmdLower := strings.ToLower(params.Command)
  246. for _, safe := range safeCommands {
  247. if strings.HasPrefix(cmdLower, safe) {
  248. if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
  249. isSafeReadOnly = true
  250. break
  251. }
  252. }
  253. }
  254. sessionID, messageID := GetContextValues(ctx)
  255. if sessionID == "" || messageID == "" {
  256. return ToolResponse{}, fmt.Errorf("session ID and message ID are required for executing shell command")
  257. }
  258. if !isSafeReadOnly {
  259. shell := shell.GetPersistentShell(b.workingDir)
  260. p := b.permissions.Request(
  261. permission.CreatePermissionRequest{
  262. SessionID: sessionID,
  263. Path: shell.GetWorkingDir(),
  264. ToolCallID: call.ID,
  265. ToolName: BashToolName,
  266. Action: "execute",
  267. Description: fmt.Sprintf("Execute command: %s", params.Command),
  268. Params: BashPermissionsParams{
  269. Command: params.Command,
  270. },
  271. },
  272. )
  273. if !p {
  274. return ToolResponse{}, permission.ErrorPermissionDenied
  275. }
  276. }
  277. startTime := time.Now()
  278. if params.Timeout > 0 {
  279. var cancel context.CancelFunc
  280. ctx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Millisecond)
  281. defer cancel()
  282. }
  283. persistentShell := shell.GetPersistentShell(b.workingDir)
  284. stdout, stderr, err := persistentShell.Exec(ctx, params.Command)
  285. // Get the current working directory after command execution
  286. currentWorkingDir := persistentShell.GetWorkingDir()
  287. interrupted := shell.IsInterrupt(err)
  288. exitCode := shell.ExitCode(err)
  289. if exitCode == 0 && !interrupted && err != nil {
  290. return ToolResponse{}, fmt.Errorf("error executing command: %w", err)
  291. }
  292. stdout = truncateOutput(stdout)
  293. stderr = truncateOutput(stderr)
  294. errorMessage := stderr
  295. if errorMessage == "" && err != nil {
  296. errorMessage = err.Error()
  297. }
  298. if interrupted {
  299. if errorMessage != "" {
  300. errorMessage += "\n"
  301. }
  302. errorMessage += "Command was aborted before completion"
  303. } else if exitCode != 0 {
  304. if errorMessage != "" {
  305. errorMessage += "\n"
  306. }
  307. errorMessage += fmt.Sprintf("Exit code %d", exitCode)
  308. }
  309. hasBothOutputs := stdout != "" && stderr != ""
  310. if hasBothOutputs {
  311. stdout += "\n"
  312. }
  313. if errorMessage != "" {
  314. stdout += "\n" + errorMessage
  315. }
  316. metadata := BashResponseMetadata{
  317. StartTime: startTime.UnixMilli(),
  318. EndTime: time.Now().UnixMilli(),
  319. Output: stdout,
  320. WorkingDirectory: currentWorkingDir,
  321. }
  322. if stdout == "" {
  323. return WithResponseMetadata(NewTextResponse(BashNoOutput), metadata), nil
  324. }
  325. stdout += fmt.Sprintf("\n\n<cwd>%s</cwd>", currentWorkingDir)
  326. return WithResponseMetadata(NewTextResponse(stdout), metadata), nil
  327. }
  328. func truncateOutput(content string) string {
  329. if len(content) <= MaxOutputLength {
  330. return content
  331. }
  332. halfLength := MaxOutputLength / 2
  333. start := content[:halfLength]
  334. end := content[len(content)-halfLength:]
  335. truncatedLinesCount := countLines(content[halfLength : len(content)-halfLength])
  336. return fmt.Sprintf("%s\n\n... [%d lines truncated] ...\n\n%s", start, truncatedLinesCount, end)
  337. }
  338. func countLines(s string) int {
  339. if s == "" {
  340. return 0
  341. }
  342. return len(strings.Split(s, "\n"))
  343. }