bash_test.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. package tools
  2. import (
  3. "context"
  4. "encoding/json"
  5. "os"
  6. "strings"
  7. "testing"
  8. "time"
  9. "github.com/kujtimiihoxha/termai/internal/permission"
  10. "github.com/kujtimiihoxha/termai/internal/pubsub"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. )
  14. func TestBashTool_Info(t *testing.T) {
  15. tool := NewBashTool(newMockPermissionService(true))
  16. info := tool.Info()
  17. assert.Equal(t, BashToolName, info.Name)
  18. assert.NotEmpty(t, info.Description)
  19. assert.Contains(t, info.Parameters, "command")
  20. assert.Contains(t, info.Parameters, "timeout")
  21. assert.Contains(t, info.Required, "command")
  22. }
  23. func TestBashTool_Run(t *testing.T) {
  24. // Save original working directory
  25. origWd, err := os.Getwd()
  26. require.NoError(t, err)
  27. defer func() {
  28. os.Chdir(origWd)
  29. }()
  30. t.Run("executes command successfully", func(t *testing.T) {
  31. tool := NewBashTool(newMockPermissionService(true))
  32. params := BashParams{
  33. Command: "echo 'Hello World'",
  34. }
  35. paramsJSON, err := json.Marshal(params)
  36. require.NoError(t, err)
  37. call := ToolCall{
  38. Name: BashToolName,
  39. Input: string(paramsJSON),
  40. }
  41. response, err := tool.Run(context.Background(), call)
  42. require.NoError(t, err)
  43. assert.Equal(t, "Hello World\n", response.Content)
  44. })
  45. t.Run("handles invalid parameters", func(t *testing.T) {
  46. tool := NewBashTool(newMockPermissionService(true))
  47. call := ToolCall{
  48. Name: BashToolName,
  49. Input: "invalid json",
  50. }
  51. response, err := tool.Run(context.Background(), call)
  52. require.NoError(t, err)
  53. assert.Contains(t, response.Content, "invalid parameters")
  54. })
  55. t.Run("handles missing command", func(t *testing.T) {
  56. tool := NewBashTool(newMockPermissionService(true))
  57. params := BashParams{
  58. Command: "",
  59. }
  60. paramsJSON, err := json.Marshal(params)
  61. require.NoError(t, err)
  62. call := ToolCall{
  63. Name: BashToolName,
  64. Input: string(paramsJSON),
  65. }
  66. response, err := tool.Run(context.Background(), call)
  67. require.NoError(t, err)
  68. assert.Contains(t, response.Content, "missing command")
  69. })
  70. t.Run("handles banned commands", func(t *testing.T) {
  71. tool := NewBashTool(newMockPermissionService(true))
  72. for _, bannedCmd := range bannedCommands {
  73. params := BashParams{
  74. Command: bannedCmd + " arg1 arg2",
  75. }
  76. paramsJSON, err := json.Marshal(params)
  77. require.NoError(t, err)
  78. call := ToolCall{
  79. Name: BashToolName,
  80. Input: string(paramsJSON),
  81. }
  82. response, err := tool.Run(context.Background(), call)
  83. require.NoError(t, err)
  84. assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
  85. }
  86. })
  87. t.Run("handles multi-word safe commands without permission check", func(t *testing.T) {
  88. tool := NewBashTool(newMockPermissionService(false))
  89. // Test with multi-word safe commands
  90. multiWordCommands := []string{
  91. "go env",
  92. }
  93. for _, cmd := range multiWordCommands {
  94. params := BashParams{
  95. Command: cmd,
  96. }
  97. paramsJSON, err := json.Marshal(params)
  98. require.NoError(t, err)
  99. call := ToolCall{
  100. Name: BashToolName,
  101. Input: string(paramsJSON),
  102. }
  103. response, err := tool.Run(context.Background(), call)
  104. require.NoError(t, err)
  105. assert.NotContains(t, response.Content, "permission denied",
  106. "Command %s should be allowed without permission", cmd)
  107. }
  108. })
  109. t.Run("handles permission denied", func(t *testing.T) {
  110. tool := NewBashTool(newMockPermissionService(false))
  111. // Test with a command that requires permission
  112. params := BashParams{
  113. Command: "mkdir test_dir",
  114. }
  115. paramsJSON, err := json.Marshal(params)
  116. require.NoError(t, err)
  117. call := ToolCall{
  118. Name: BashToolName,
  119. Input: string(paramsJSON),
  120. }
  121. response, err := tool.Run(context.Background(), call)
  122. require.NoError(t, err)
  123. assert.Contains(t, response.Content, "permission denied")
  124. })
  125. t.Run("handles command timeout", func(t *testing.T) {
  126. tool := NewBashTool(newMockPermissionService(true))
  127. params := BashParams{
  128. Command: "sleep 2",
  129. Timeout: 100, // 100ms timeout
  130. }
  131. paramsJSON, err := json.Marshal(params)
  132. require.NoError(t, err)
  133. call := ToolCall{
  134. Name: BashToolName,
  135. Input: string(paramsJSON),
  136. }
  137. response, err := tool.Run(context.Background(), call)
  138. require.NoError(t, err)
  139. assert.Contains(t, response.Content, "aborted")
  140. })
  141. t.Run("handles command with stderr output", func(t *testing.T) {
  142. tool := NewBashTool(newMockPermissionService(true))
  143. params := BashParams{
  144. Command: "echo 'error message' >&2",
  145. }
  146. paramsJSON, err := json.Marshal(params)
  147. require.NoError(t, err)
  148. call := ToolCall{
  149. Name: BashToolName,
  150. Input: string(paramsJSON),
  151. }
  152. response, err := tool.Run(context.Background(), call)
  153. require.NoError(t, err)
  154. assert.Contains(t, response.Content, "error message")
  155. })
  156. t.Run("handles command with both stdout and stderr", func(t *testing.T) {
  157. tool := NewBashTool(newMockPermissionService(true))
  158. params := BashParams{
  159. Command: "echo 'stdout message' && echo 'stderr message' >&2",
  160. }
  161. paramsJSON, err := json.Marshal(params)
  162. require.NoError(t, err)
  163. call := ToolCall{
  164. Name: BashToolName,
  165. Input: string(paramsJSON),
  166. }
  167. response, err := tool.Run(context.Background(), call)
  168. require.NoError(t, err)
  169. assert.Contains(t, response.Content, "stdout message")
  170. assert.Contains(t, response.Content, "stderr message")
  171. })
  172. t.Run("handles context cancellation", func(t *testing.T) {
  173. tool := NewBashTool(newMockPermissionService(true))
  174. params := BashParams{
  175. Command: "sleep 5",
  176. }
  177. paramsJSON, err := json.Marshal(params)
  178. require.NoError(t, err)
  179. call := ToolCall{
  180. Name: BashToolName,
  181. Input: string(paramsJSON),
  182. }
  183. ctx, cancel := context.WithCancel(context.Background())
  184. // Cancel the context after a short delay
  185. go func() {
  186. time.Sleep(100 * time.Millisecond)
  187. cancel()
  188. }()
  189. response, err := tool.Run(ctx, call)
  190. require.NoError(t, err)
  191. assert.Contains(t, response.Content, "aborted")
  192. })
  193. t.Run("respects max timeout", func(t *testing.T) {
  194. tool := NewBashTool(newMockPermissionService(true))
  195. params := BashParams{
  196. Command: "echo 'test'",
  197. Timeout: MaxTimeout + 1000, // Exceeds max timeout
  198. }
  199. paramsJSON, err := json.Marshal(params)
  200. require.NoError(t, err)
  201. call := ToolCall{
  202. Name: BashToolName,
  203. Input: string(paramsJSON),
  204. }
  205. response, err := tool.Run(context.Background(), call)
  206. require.NoError(t, err)
  207. assert.Equal(t, "test\n", response.Content)
  208. })
  209. t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
  210. tool := NewBashTool(newMockPermissionService(true))
  211. params := BashParams{
  212. Command: "echo 'test'",
  213. Timeout: -100, // Negative timeout
  214. }
  215. paramsJSON, err := json.Marshal(params)
  216. require.NoError(t, err)
  217. call := ToolCall{
  218. Name: BashToolName,
  219. Input: string(paramsJSON),
  220. }
  221. response, err := tool.Run(context.Background(), call)
  222. require.NoError(t, err)
  223. assert.Equal(t, "test\n", response.Content)
  224. })
  225. }
  226. func TestTruncateOutput(t *testing.T) {
  227. t.Run("does not truncate short output", func(t *testing.T) {
  228. output := "short output"
  229. result := truncateOutput(output)
  230. assert.Equal(t, output, result)
  231. })
  232. t.Run("truncates long output", func(t *testing.T) {
  233. // Create a string longer than MaxOutputLength
  234. longOutput := strings.Repeat("a\n", MaxOutputLength)
  235. result := truncateOutput(longOutput)
  236. // Check that the result is shorter than the original
  237. assert.Less(t, len(result), len(longOutput))
  238. // Check that the truncation message is included
  239. assert.Contains(t, result, "lines truncated")
  240. // Check that we have the beginning and end of the original string
  241. assert.True(t, strings.HasPrefix(result, "a\n"))
  242. assert.True(t, strings.HasSuffix(result, "a\n"))
  243. })
  244. }
  245. func TestCountLines(t *testing.T) {
  246. testCases := []struct {
  247. name string
  248. input string
  249. expected int
  250. }{
  251. {
  252. name: "empty string",
  253. input: "",
  254. expected: 0,
  255. },
  256. {
  257. name: "single line",
  258. input: "line1",
  259. expected: 1,
  260. },
  261. {
  262. name: "multiple lines",
  263. input: "line1\nline2\nline3",
  264. expected: 3,
  265. },
  266. {
  267. name: "trailing newline",
  268. input: "line1\nline2\n",
  269. expected: 3, // Empty string after last newline counts as a line
  270. },
  271. }
  272. for _, tc := range testCases {
  273. t.Run(tc.name, func(t *testing.T) {
  274. result := countLines(tc.input)
  275. assert.Equal(t, tc.expected, result)
  276. })
  277. }
  278. }
  279. // Mock permission service for testing
  280. type mockPermissionService struct {
  281. *pubsub.Broker[permission.PermissionRequest]
  282. allow bool
  283. }
  284. func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
  285. // Not needed for tests
  286. }
  287. func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
  288. // Not needed for tests
  289. }
  290. func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
  291. // Not needed for tests
  292. }
  293. func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
  294. return m.allow
  295. }
  296. func newMockPermissionService(allow bool) permission.Service {
  297. return &mockPermissionService{
  298. Broker: pubsub.NewBroker[permission.PermissionRequest](),
  299. allow: allow,
  300. }
  301. }