bash_test.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. package tools
  2. import (
  3. "context"
  4. "encoding/json"
  5. "os"
  6. "strings"
  7. "testing"
  8. "time"
  9. "github.com/kujtimiihoxha/termai/internal/permission"
  10. "github.com/kujtimiihoxha/termai/internal/pubsub"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. )
  14. func TestBashTool_Info(t *testing.T) {
  15. tool := NewBashTool()
  16. info := tool.Info()
  17. assert.Equal(t, BashToolName, info.Name)
  18. assert.NotEmpty(t, info.Description)
  19. assert.Contains(t, info.Parameters, "command")
  20. assert.Contains(t, info.Parameters, "timeout")
  21. assert.Contains(t, info.Required, "command")
  22. }
  23. func TestBashTool_Run(t *testing.T) {
  24. // Setup a mock permission handler that always allows
  25. origPermission := permission.Default
  26. defer func() {
  27. permission.Default = origPermission
  28. }()
  29. permission.Default = newMockPermissionService(true)
  30. // Save original working directory
  31. origWd, err := os.Getwd()
  32. require.NoError(t, err)
  33. defer func() {
  34. os.Chdir(origWd)
  35. }()
  36. t.Run("executes command successfully", func(t *testing.T) {
  37. permission.Default = newMockPermissionService(true)
  38. tool := NewBashTool()
  39. params := BashParams{
  40. Command: "echo 'Hello World'",
  41. }
  42. paramsJSON, err := json.Marshal(params)
  43. require.NoError(t, err)
  44. call := ToolCall{
  45. Name: BashToolName,
  46. Input: string(paramsJSON),
  47. }
  48. response, err := tool.Run(context.Background(), call)
  49. require.NoError(t, err)
  50. assert.Equal(t, "Hello World\n", response.Content)
  51. })
  52. t.Run("handles invalid parameters", func(t *testing.T) {
  53. permission.Default = newMockPermissionService(true)
  54. tool := NewBashTool()
  55. call := ToolCall{
  56. Name: BashToolName,
  57. Input: "invalid json",
  58. }
  59. response, err := tool.Run(context.Background(), call)
  60. require.NoError(t, err)
  61. assert.Contains(t, response.Content, "invalid parameters")
  62. })
  63. t.Run("handles missing command", func(t *testing.T) {
  64. permission.Default = newMockPermissionService(true)
  65. tool := NewBashTool()
  66. params := BashParams{
  67. Command: "",
  68. }
  69. paramsJSON, err := json.Marshal(params)
  70. require.NoError(t, err)
  71. call := ToolCall{
  72. Name: BashToolName,
  73. Input: string(paramsJSON),
  74. }
  75. response, err := tool.Run(context.Background(), call)
  76. require.NoError(t, err)
  77. assert.Contains(t, response.Content, "missing command")
  78. })
  79. t.Run("handles banned commands", func(t *testing.T) {
  80. permission.Default = newMockPermissionService(true)
  81. tool := NewBashTool()
  82. for _, bannedCmd := range BannedCommands {
  83. params := BashParams{
  84. Command: bannedCmd + " arg1 arg2",
  85. }
  86. paramsJSON, err := json.Marshal(params)
  87. require.NoError(t, err)
  88. call := ToolCall{
  89. Name: BashToolName,
  90. Input: string(paramsJSON),
  91. }
  92. response, err := tool.Run(context.Background(), call)
  93. require.NoError(t, err)
  94. assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
  95. }
  96. })
  97. t.Run("handles multi-word safe commands without permission check", func(t *testing.T) {
  98. permission.Default = newMockPermissionService(false)
  99. tool := NewBashTool()
  100. // Test with multi-word safe commands
  101. multiWordCommands := []string{
  102. "git status",
  103. "git log -n 5",
  104. "docker ps",
  105. "go test ./...",
  106. "kubectl get pods",
  107. }
  108. for _, cmd := range multiWordCommands {
  109. params := BashParams{
  110. Command: cmd,
  111. }
  112. paramsJSON, err := json.Marshal(params)
  113. require.NoError(t, err)
  114. call := ToolCall{
  115. Name: BashToolName,
  116. Input: string(paramsJSON),
  117. }
  118. response, err := tool.Run(context.Background(), call)
  119. require.NoError(t, err)
  120. assert.NotContains(t, response.Content, "permission denied",
  121. "Command %s should be allowed without permission", cmd)
  122. }
  123. })
  124. t.Run("handles permission denied", func(t *testing.T) {
  125. permission.Default = newMockPermissionService(false)
  126. tool := NewBashTool()
  127. // Test with a command that requires permission
  128. params := BashParams{
  129. Command: "mkdir test_dir",
  130. }
  131. paramsJSON, err := json.Marshal(params)
  132. require.NoError(t, err)
  133. call := ToolCall{
  134. Name: BashToolName,
  135. Input: string(paramsJSON),
  136. }
  137. response, err := tool.Run(context.Background(), call)
  138. require.NoError(t, err)
  139. assert.Contains(t, response.Content, "permission denied")
  140. })
  141. t.Run("handles command timeout", func(t *testing.T) {
  142. permission.Default = newMockPermissionService(true)
  143. tool := NewBashTool()
  144. params := BashParams{
  145. Command: "sleep 2",
  146. Timeout: 100, // 100ms timeout
  147. }
  148. paramsJSON, err := json.Marshal(params)
  149. require.NoError(t, err)
  150. call := ToolCall{
  151. Name: BashToolName,
  152. Input: string(paramsJSON),
  153. }
  154. response, err := tool.Run(context.Background(), call)
  155. require.NoError(t, err)
  156. assert.Contains(t, response.Content, "aborted")
  157. })
  158. t.Run("handles command with stderr output", func(t *testing.T) {
  159. permission.Default = newMockPermissionService(true)
  160. tool := NewBashTool()
  161. params := BashParams{
  162. Command: "echo 'error message' >&2",
  163. }
  164. paramsJSON, err := json.Marshal(params)
  165. require.NoError(t, err)
  166. call := ToolCall{
  167. Name: BashToolName,
  168. Input: string(paramsJSON),
  169. }
  170. response, err := tool.Run(context.Background(), call)
  171. require.NoError(t, err)
  172. assert.Contains(t, response.Content, "error message")
  173. })
  174. t.Run("handles command with both stdout and stderr", func(t *testing.T) {
  175. permission.Default = newMockPermissionService(true)
  176. tool := NewBashTool()
  177. params := BashParams{
  178. Command: "echo 'stdout message' && echo 'stderr message' >&2",
  179. }
  180. paramsJSON, err := json.Marshal(params)
  181. require.NoError(t, err)
  182. call := ToolCall{
  183. Name: BashToolName,
  184. Input: string(paramsJSON),
  185. }
  186. response, err := tool.Run(context.Background(), call)
  187. require.NoError(t, err)
  188. assert.Contains(t, response.Content, "stdout message")
  189. assert.Contains(t, response.Content, "stderr message")
  190. })
  191. t.Run("handles context cancellation", func(t *testing.T) {
  192. permission.Default = newMockPermissionService(true)
  193. tool := NewBashTool()
  194. params := BashParams{
  195. Command: "sleep 5",
  196. }
  197. paramsJSON, err := json.Marshal(params)
  198. require.NoError(t, err)
  199. call := ToolCall{
  200. Name: BashToolName,
  201. Input: string(paramsJSON),
  202. }
  203. ctx, cancel := context.WithCancel(context.Background())
  204. // Cancel the context after a short delay
  205. go func() {
  206. time.Sleep(100 * time.Millisecond)
  207. cancel()
  208. }()
  209. response, err := tool.Run(ctx, call)
  210. require.NoError(t, err)
  211. assert.Contains(t, response.Content, "aborted")
  212. })
  213. t.Run("respects max timeout", func(t *testing.T) {
  214. permission.Default = newMockPermissionService(true)
  215. tool := NewBashTool()
  216. params := BashParams{
  217. Command: "echo 'test'",
  218. Timeout: MaxTimeout + 1000, // Exceeds max timeout
  219. }
  220. paramsJSON, err := json.Marshal(params)
  221. require.NoError(t, err)
  222. call := ToolCall{
  223. Name: BashToolName,
  224. Input: string(paramsJSON),
  225. }
  226. response, err := tool.Run(context.Background(), call)
  227. require.NoError(t, err)
  228. assert.Equal(t, "test\n", response.Content)
  229. })
  230. t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
  231. permission.Default = newMockPermissionService(true)
  232. tool := NewBashTool()
  233. params := BashParams{
  234. Command: "echo 'test'",
  235. Timeout: -100, // Negative timeout
  236. }
  237. paramsJSON, err := json.Marshal(params)
  238. require.NoError(t, err)
  239. call := ToolCall{
  240. Name: BashToolName,
  241. Input: string(paramsJSON),
  242. }
  243. response, err := tool.Run(context.Background(), call)
  244. require.NoError(t, err)
  245. assert.Equal(t, "test\n", response.Content)
  246. })
  247. }
  248. func TestTruncateOutput(t *testing.T) {
  249. t.Run("does not truncate short output", func(t *testing.T) {
  250. output := "short output"
  251. result := truncateOutput(output)
  252. assert.Equal(t, output, result)
  253. })
  254. t.Run("truncates long output", func(t *testing.T) {
  255. // Create a string longer than MaxOutputLength
  256. longOutput := strings.Repeat("a\n", MaxOutputLength)
  257. result := truncateOutput(longOutput)
  258. // Check that the result is shorter than the original
  259. assert.Less(t, len(result), len(longOutput))
  260. // Check that the truncation message is included
  261. assert.Contains(t, result, "lines truncated")
  262. // Check that we have the beginning and end of the original string
  263. assert.True(t, strings.HasPrefix(result, "a\n"))
  264. assert.True(t, strings.HasSuffix(result, "a\n"))
  265. })
  266. }
  267. func TestCountLines(t *testing.T) {
  268. testCases := []struct {
  269. name string
  270. input string
  271. expected int
  272. }{
  273. {
  274. name: "empty string",
  275. input: "",
  276. expected: 0,
  277. },
  278. {
  279. name: "single line",
  280. input: "line1",
  281. expected: 1,
  282. },
  283. {
  284. name: "multiple lines",
  285. input: "line1\nline2\nline3",
  286. expected: 3,
  287. },
  288. {
  289. name: "trailing newline",
  290. input: "line1\nline2\n",
  291. expected: 3, // Empty string after last newline counts as a line
  292. },
  293. }
  294. for _, tc := range testCases {
  295. t.Run(tc.name, func(t *testing.T) {
  296. result := countLines(tc.input)
  297. assert.Equal(t, tc.expected, result)
  298. })
  299. }
  300. }
  301. // Mock permission service for testing
  302. type mockPermissionService struct {
  303. *pubsub.Broker[permission.PermissionRequest]
  304. allow bool
  305. }
  306. func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
  307. // Not needed for tests
  308. }
  309. func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
  310. // Not needed for tests
  311. }
  312. func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
  313. // Not needed for tests
  314. }
  315. func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
  316. return m.allow
  317. }
  318. func newMockPermissionService(allow bool) permission.Service {
  319. return &mockPermissionService{
  320. Broker: pubsub.NewBroker[permission.PermissionRequest](),
  321. allow: allow,
  322. }
  323. }