bash_test.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. package tools
  2. import (
  3. "context"
  4. "encoding/json"
  5. "os"
  6. "strings"
  7. "testing"
  8. "time"
  9. "github.com/kujtimiihoxha/termai/internal/permission"
  10. "github.com/kujtimiihoxha/termai/internal/pubsub"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. )
  14. func TestBashTool_Info(t *testing.T) {
  15. tool := NewBashTool()
  16. info := tool.Info()
  17. assert.Equal(t, BashToolName, info.Name)
  18. assert.NotEmpty(t, info.Description)
  19. assert.Contains(t, info.Parameters, "command")
  20. assert.Contains(t, info.Parameters, "timeout")
  21. assert.Contains(t, info.Required, "command")
  22. }
  23. func TestBashTool_Run(t *testing.T) {
  24. // Setup a mock permission handler that always allows
  25. origPermission := permission.Default
  26. defer func() {
  27. permission.Default = origPermission
  28. }()
  29. permission.Default = newMockPermissionService(true)
  30. // Save original working directory
  31. origWd, err := os.Getwd()
  32. require.NoError(t, err)
  33. defer func() {
  34. os.Chdir(origWd)
  35. }()
  36. t.Run("executes command successfully", func(t *testing.T) {
  37. permission.Default = newMockPermissionService(true)
  38. tool := NewBashTool()
  39. params := BashParams{
  40. Command: "echo 'Hello World'",
  41. }
  42. paramsJSON, err := json.Marshal(params)
  43. require.NoError(t, err)
  44. call := ToolCall{
  45. Name: BashToolName,
  46. Input: string(paramsJSON),
  47. }
  48. response, err := tool.Run(context.Background(), call)
  49. require.NoError(t, err)
  50. assert.Equal(t, "Hello World\n", response.Content)
  51. })
  52. t.Run("handles invalid parameters", func(t *testing.T) {
  53. permission.Default = newMockPermissionService(true)
  54. tool := NewBashTool()
  55. call := ToolCall{
  56. Name: BashToolName,
  57. Input: "invalid json",
  58. }
  59. response, err := tool.Run(context.Background(), call)
  60. require.NoError(t, err)
  61. assert.Contains(t, response.Content, "invalid parameters")
  62. })
  63. t.Run("handles missing command", func(t *testing.T) {
  64. permission.Default = newMockPermissionService(true)
  65. tool := NewBashTool()
  66. params := BashParams{
  67. Command: "",
  68. }
  69. paramsJSON, err := json.Marshal(params)
  70. require.NoError(t, err)
  71. call := ToolCall{
  72. Name: BashToolName,
  73. Input: string(paramsJSON),
  74. }
  75. response, err := tool.Run(context.Background(), call)
  76. require.NoError(t, err)
  77. assert.Contains(t, response.Content, "missing command")
  78. })
  79. t.Run("handles banned commands", func(t *testing.T) {
  80. permission.Default = newMockPermissionService(true)
  81. tool := NewBashTool()
  82. for _, bannedCmd := range BannedCommands {
  83. params := BashParams{
  84. Command: bannedCmd + " arg1 arg2",
  85. }
  86. paramsJSON, err := json.Marshal(params)
  87. require.NoError(t, err)
  88. call := ToolCall{
  89. Name: BashToolName,
  90. Input: string(paramsJSON),
  91. }
  92. response, err := tool.Run(context.Background(), call)
  93. require.NoError(t, err)
  94. assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
  95. }
  96. })
  97. t.Run("handles safe read-only commands without permission check", func(t *testing.T) {
  98. permission.Default = newMockPermissionService(false)
  99. tool := NewBashTool()
  100. // Test with a safe read-only command
  101. params := BashParams{
  102. Command: "echo 'test'",
  103. }
  104. paramsJSON, err := json.Marshal(params)
  105. require.NoError(t, err)
  106. call := ToolCall{
  107. Name: BashToolName,
  108. Input: string(paramsJSON),
  109. }
  110. response, err := tool.Run(context.Background(), call)
  111. require.NoError(t, err)
  112. assert.Equal(t, "test\n", response.Content)
  113. })
  114. t.Run("handles permission denied", func(t *testing.T) {
  115. permission.Default = newMockPermissionService(false)
  116. tool := NewBashTool()
  117. // Test with a command that requires permission
  118. params := BashParams{
  119. Command: "mkdir test_dir",
  120. }
  121. paramsJSON, err := json.Marshal(params)
  122. require.NoError(t, err)
  123. call := ToolCall{
  124. Name: BashToolName,
  125. Input: string(paramsJSON),
  126. }
  127. response, err := tool.Run(context.Background(), call)
  128. require.NoError(t, err)
  129. assert.Contains(t, response.Content, "permission denied")
  130. })
  131. t.Run("handles command timeout", func(t *testing.T) {
  132. permission.Default = newMockPermissionService(true)
  133. tool := NewBashTool()
  134. params := BashParams{
  135. Command: "sleep 2",
  136. Timeout: 100, // 100ms timeout
  137. }
  138. paramsJSON, err := json.Marshal(params)
  139. require.NoError(t, err)
  140. call := ToolCall{
  141. Name: BashToolName,
  142. Input: string(paramsJSON),
  143. }
  144. response, err := tool.Run(context.Background(), call)
  145. require.NoError(t, err)
  146. assert.Contains(t, response.Content, "aborted")
  147. })
  148. t.Run("handles command with stderr output", func(t *testing.T) {
  149. permission.Default = newMockPermissionService(true)
  150. tool := NewBashTool()
  151. params := BashParams{
  152. Command: "echo 'error message' >&2",
  153. }
  154. paramsJSON, err := json.Marshal(params)
  155. require.NoError(t, err)
  156. call := ToolCall{
  157. Name: BashToolName,
  158. Input: string(paramsJSON),
  159. }
  160. response, err := tool.Run(context.Background(), call)
  161. require.NoError(t, err)
  162. assert.Contains(t, response.Content, "error message")
  163. })
  164. t.Run("handles command with both stdout and stderr", func(t *testing.T) {
  165. permission.Default = newMockPermissionService(true)
  166. tool := NewBashTool()
  167. params := BashParams{
  168. Command: "echo 'stdout message' && echo 'stderr message' >&2",
  169. }
  170. paramsJSON, err := json.Marshal(params)
  171. require.NoError(t, err)
  172. call := ToolCall{
  173. Name: BashToolName,
  174. Input: string(paramsJSON),
  175. }
  176. response, err := tool.Run(context.Background(), call)
  177. require.NoError(t, err)
  178. assert.Contains(t, response.Content, "stdout message")
  179. assert.Contains(t, response.Content, "stderr message")
  180. })
  181. t.Run("handles context cancellation", func(t *testing.T) {
  182. permission.Default = newMockPermissionService(true)
  183. tool := NewBashTool()
  184. params := BashParams{
  185. Command: "sleep 5",
  186. }
  187. paramsJSON, err := json.Marshal(params)
  188. require.NoError(t, err)
  189. call := ToolCall{
  190. Name: BashToolName,
  191. Input: string(paramsJSON),
  192. }
  193. ctx, cancel := context.WithCancel(context.Background())
  194. // Cancel the context after a short delay
  195. go func() {
  196. time.Sleep(100 * time.Millisecond)
  197. cancel()
  198. }()
  199. response, err := tool.Run(ctx, call)
  200. require.NoError(t, err)
  201. assert.Contains(t, response.Content, "aborted")
  202. })
  203. t.Run("respects max timeout", func(t *testing.T) {
  204. permission.Default = newMockPermissionService(true)
  205. tool := NewBashTool()
  206. params := BashParams{
  207. Command: "echo 'test'",
  208. Timeout: MaxTimeout + 1000, // Exceeds max timeout
  209. }
  210. paramsJSON, err := json.Marshal(params)
  211. require.NoError(t, err)
  212. call := ToolCall{
  213. Name: BashToolName,
  214. Input: string(paramsJSON),
  215. }
  216. response, err := tool.Run(context.Background(), call)
  217. require.NoError(t, err)
  218. assert.Equal(t, "test\n", response.Content)
  219. })
  220. t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
  221. permission.Default = newMockPermissionService(true)
  222. tool := NewBashTool()
  223. params := BashParams{
  224. Command: "echo 'test'",
  225. Timeout: -100, // Negative timeout
  226. }
  227. paramsJSON, err := json.Marshal(params)
  228. require.NoError(t, err)
  229. call := ToolCall{
  230. Name: BashToolName,
  231. Input: string(paramsJSON),
  232. }
  233. response, err := tool.Run(context.Background(), call)
  234. require.NoError(t, err)
  235. assert.Equal(t, "test\n", response.Content)
  236. })
  237. }
  238. func TestTruncateOutput(t *testing.T) {
  239. t.Run("does not truncate short output", func(t *testing.T) {
  240. output := "short output"
  241. result := truncateOutput(output)
  242. assert.Equal(t, output, result)
  243. })
  244. t.Run("truncates long output", func(t *testing.T) {
  245. // Create a string longer than MaxOutputLength
  246. longOutput := strings.Repeat("a\n", MaxOutputLength)
  247. result := truncateOutput(longOutput)
  248. // Check that the result is shorter than the original
  249. assert.Less(t, len(result), len(longOutput))
  250. // Check that the truncation message is included
  251. assert.Contains(t, result, "lines truncated")
  252. // Check that we have the beginning and end of the original string
  253. assert.True(t, strings.HasPrefix(result, "a\n"))
  254. assert.True(t, strings.HasSuffix(result, "a\n"))
  255. })
  256. }
  257. func TestCountLines(t *testing.T) {
  258. testCases := []struct {
  259. name string
  260. input string
  261. expected int
  262. }{
  263. {
  264. name: "empty string",
  265. input: "",
  266. expected: 0,
  267. },
  268. {
  269. name: "single line",
  270. input: "line1",
  271. expected: 1,
  272. },
  273. {
  274. name: "multiple lines",
  275. input: "line1\nline2\nline3",
  276. expected: 3,
  277. },
  278. {
  279. name: "trailing newline",
  280. input: "line1\nline2\n",
  281. expected: 3, // Empty string after last newline counts as a line
  282. },
  283. }
  284. for _, tc := range testCases {
  285. t.Run(tc.name, func(t *testing.T) {
  286. result := countLines(tc.input)
  287. assert.Equal(t, tc.expected, result)
  288. })
  289. }
  290. }
  291. // Mock permission service for testing
  292. type mockPermissionService struct {
  293. *pubsub.Broker[permission.PermissionRequest]
  294. allow bool
  295. }
  296. func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
  297. // Not needed for tests
  298. }
  299. func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
  300. // Not needed for tests
  301. }
  302. func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
  303. // Not needed for tests
  304. }
  305. func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
  306. return m.allow
  307. }
  308. func newMockPermissionService(allow bool) permission.Service {
  309. return &mockPermissionService{
  310. Broker: pubsub.NewBroker[permission.PermissionRequest](),
  311. allow: allow,
  312. }
  313. }