bash_test.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. package tools
  2. import (
  3. "context"
  4. "encoding/json"
  5. "os"
  6. "strings"
  7. "testing"
  8. "time"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. )
  12. func TestBashTool_Info(t *testing.T) {
  13. tool := NewBashTool(newMockPermissionService(true))
  14. info := tool.Info()
  15. assert.Equal(t, BashToolName, info.Name)
  16. assert.NotEmpty(t, info.Description)
  17. assert.Contains(t, info.Parameters, "command")
  18. assert.Contains(t, info.Parameters, "timeout")
  19. assert.Contains(t, info.Required, "command")
  20. }
  21. func TestBashTool_Run(t *testing.T) {
  22. // Save original working directory
  23. origWd, err := os.Getwd()
  24. require.NoError(t, err)
  25. defer func() {
  26. os.Chdir(origWd)
  27. }()
  28. t.Run("executes command successfully", func(t *testing.T) {
  29. tool := NewBashTool(newMockPermissionService(true))
  30. params := BashParams{
  31. Command: "echo 'Hello World'",
  32. }
  33. paramsJSON, err := json.Marshal(params)
  34. require.NoError(t, err)
  35. call := ToolCall{
  36. Name: BashToolName,
  37. Input: string(paramsJSON),
  38. }
  39. response, err := tool.Run(context.Background(), call)
  40. require.NoError(t, err)
  41. assert.Equal(t, "Hello World\n", response.Content)
  42. })
  43. t.Run("handles invalid parameters", func(t *testing.T) {
  44. tool := NewBashTool(newMockPermissionService(true))
  45. call := ToolCall{
  46. Name: BashToolName,
  47. Input: "invalid json",
  48. }
  49. response, err := tool.Run(context.Background(), call)
  50. require.NoError(t, err)
  51. assert.Contains(t, response.Content, "invalid parameters")
  52. })
  53. t.Run("handles missing command", func(t *testing.T) {
  54. tool := NewBashTool(newMockPermissionService(true))
  55. params := BashParams{
  56. Command: "",
  57. }
  58. paramsJSON, err := json.Marshal(params)
  59. require.NoError(t, err)
  60. call := ToolCall{
  61. Name: BashToolName,
  62. Input: string(paramsJSON),
  63. }
  64. response, err := tool.Run(context.Background(), call)
  65. require.NoError(t, err)
  66. assert.Contains(t, response.Content, "missing command")
  67. })
  68. t.Run("handles banned commands", func(t *testing.T) {
  69. tool := NewBashTool(newMockPermissionService(true))
  70. for _, bannedCmd := range bannedCommands {
  71. params := BashParams{
  72. Command: bannedCmd + " arg1 arg2",
  73. }
  74. paramsJSON, err := json.Marshal(params)
  75. require.NoError(t, err)
  76. call := ToolCall{
  77. Name: BashToolName,
  78. Input: string(paramsJSON),
  79. }
  80. response, err := tool.Run(context.Background(), call)
  81. require.NoError(t, err)
  82. assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
  83. }
  84. })
  85. t.Run("handles multi-word safe commands without permission check", func(t *testing.T) {
  86. tool := NewBashTool(newMockPermissionService(false))
  87. // Test with multi-word safe commands
  88. multiWordCommands := []string{
  89. "go env",
  90. }
  91. for _, cmd := range multiWordCommands {
  92. params := BashParams{
  93. Command: cmd,
  94. }
  95. paramsJSON, err := json.Marshal(params)
  96. require.NoError(t, err)
  97. call := ToolCall{
  98. Name: BashToolName,
  99. Input: string(paramsJSON),
  100. }
  101. response, err := tool.Run(context.Background(), call)
  102. require.NoError(t, err)
  103. assert.NotContains(t, response.Content, "permission denied",
  104. "Command %s should be allowed without permission", cmd)
  105. }
  106. })
  107. t.Run("handles permission denied", func(t *testing.T) {
  108. tool := NewBashTool(newMockPermissionService(false))
  109. // Test with a command that requires permission
  110. params := BashParams{
  111. Command: "mkdir test_dir",
  112. }
  113. paramsJSON, err := json.Marshal(params)
  114. require.NoError(t, err)
  115. call := ToolCall{
  116. Name: BashToolName,
  117. Input: string(paramsJSON),
  118. }
  119. response, err := tool.Run(context.Background(), call)
  120. require.NoError(t, err)
  121. assert.Contains(t, response.Content, "permission denied")
  122. })
  123. t.Run("handles command timeout", func(t *testing.T) {
  124. tool := NewBashTool(newMockPermissionService(true))
  125. params := BashParams{
  126. Command: "sleep 2",
  127. Timeout: 100, // 100ms timeout
  128. }
  129. paramsJSON, err := json.Marshal(params)
  130. require.NoError(t, err)
  131. call := ToolCall{
  132. Name: BashToolName,
  133. Input: string(paramsJSON),
  134. }
  135. response, err := tool.Run(context.Background(), call)
  136. require.NoError(t, err)
  137. assert.Contains(t, response.Content, "aborted")
  138. })
  139. t.Run("handles command with stderr output", func(t *testing.T) {
  140. tool := NewBashTool(newMockPermissionService(true))
  141. params := BashParams{
  142. Command: "echo 'error message' >&2",
  143. }
  144. paramsJSON, err := json.Marshal(params)
  145. require.NoError(t, err)
  146. call := ToolCall{
  147. Name: BashToolName,
  148. Input: string(paramsJSON),
  149. }
  150. response, err := tool.Run(context.Background(), call)
  151. require.NoError(t, err)
  152. assert.Contains(t, response.Content, "error message")
  153. })
  154. t.Run("handles command with both stdout and stderr", func(t *testing.T) {
  155. tool := NewBashTool(newMockPermissionService(true))
  156. params := BashParams{
  157. Command: "echo 'stdout message' && echo 'stderr message' >&2",
  158. }
  159. paramsJSON, err := json.Marshal(params)
  160. require.NoError(t, err)
  161. call := ToolCall{
  162. Name: BashToolName,
  163. Input: string(paramsJSON),
  164. }
  165. response, err := tool.Run(context.Background(), call)
  166. require.NoError(t, err)
  167. assert.Contains(t, response.Content, "stdout message")
  168. assert.Contains(t, response.Content, "stderr message")
  169. })
  170. t.Run("handles context cancellation", func(t *testing.T) {
  171. tool := NewBashTool(newMockPermissionService(true))
  172. params := BashParams{
  173. Command: "sleep 5",
  174. }
  175. paramsJSON, err := json.Marshal(params)
  176. require.NoError(t, err)
  177. call := ToolCall{
  178. Name: BashToolName,
  179. Input: string(paramsJSON),
  180. }
  181. ctx, cancel := context.WithCancel(context.Background())
  182. // Cancel the context after a short delay
  183. go func() {
  184. time.Sleep(100 * time.Millisecond)
  185. cancel()
  186. }()
  187. response, err := tool.Run(ctx, call)
  188. require.NoError(t, err)
  189. assert.Contains(t, response.Content, "aborted")
  190. })
  191. t.Run("respects max timeout", func(t *testing.T) {
  192. tool := NewBashTool(newMockPermissionService(true))
  193. params := BashParams{
  194. Command: "echo 'test'",
  195. Timeout: MaxTimeout + 1000, // Exceeds max timeout
  196. }
  197. paramsJSON, err := json.Marshal(params)
  198. require.NoError(t, err)
  199. call := ToolCall{
  200. Name: BashToolName,
  201. Input: string(paramsJSON),
  202. }
  203. response, err := tool.Run(context.Background(), call)
  204. require.NoError(t, err)
  205. assert.Equal(t, "test\n", response.Content)
  206. })
  207. t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
  208. tool := NewBashTool(newMockPermissionService(true))
  209. params := BashParams{
  210. Command: "echo 'test'",
  211. Timeout: -100, // Negative timeout
  212. }
  213. paramsJSON, err := json.Marshal(params)
  214. require.NoError(t, err)
  215. call := ToolCall{
  216. Name: BashToolName,
  217. Input: string(paramsJSON),
  218. }
  219. response, err := tool.Run(context.Background(), call)
  220. require.NoError(t, err)
  221. assert.Equal(t, "test\n", response.Content)
  222. })
  223. }
  224. func TestTruncateOutput(t *testing.T) {
  225. t.Run("does not truncate short output", func(t *testing.T) {
  226. output := "short output"
  227. result := truncateOutput(output)
  228. assert.Equal(t, output, result)
  229. })
  230. t.Run("truncates long output", func(t *testing.T) {
  231. // Create a string longer than MaxOutputLength
  232. longOutput := strings.Repeat("a\n", MaxOutputLength)
  233. result := truncateOutput(longOutput)
  234. // Check that the result is shorter than the original
  235. assert.Less(t, len(result), len(longOutput))
  236. // Check that the truncation message is included
  237. assert.Contains(t, result, "lines truncated")
  238. // Check that we have the beginning and end of the original string
  239. assert.True(t, strings.HasPrefix(result, "a\n"))
  240. assert.True(t, strings.HasSuffix(result, "a\n"))
  241. })
  242. }
  243. func TestCountLines(t *testing.T) {
  244. testCases := []struct {
  245. name string
  246. input string
  247. expected int
  248. }{
  249. {
  250. name: "empty string",
  251. input: "",
  252. expected: 0,
  253. },
  254. {
  255. name: "single line",
  256. input: "line1",
  257. expected: 1,
  258. },
  259. {
  260. name: "multiple lines",
  261. input: "line1\nline2\nline3",
  262. expected: 3,
  263. },
  264. {
  265. name: "trailing newline",
  266. input: "line1\nline2\n",
  267. expected: 3, // Empty string after last newline counts as a line
  268. },
  269. }
  270. for _, tc := range testCases {
  271. t.Run(tc.name, func(t *testing.T) {
  272. result := countLines(tc.input)
  273. assert.Equal(t, tc.expected, result)
  274. })
  275. }
  276. }