write_test.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. package tools
  2. import (
  3. "context"
  4. "encoding/json"
  5. "os"
  6. "path/filepath"
  7. "testing"
  8. "time"
  9. "github.com/kujtimiihoxha/termai/internal/lsp"
  10. "github.com/kujtimiihoxha/termai/internal/permission"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. )
  14. func TestWriteTool_Info(t *testing.T) {
  15. tool := NewWriteTool(make(map[string]*lsp.Client))
  16. info := tool.Info()
  17. assert.Equal(t, WriteToolName, info.Name)
  18. assert.NotEmpty(t, info.Description)
  19. assert.Contains(t, info.Parameters, "file_path")
  20. assert.Contains(t, info.Parameters, "content")
  21. assert.Contains(t, info.Required, "file_path")
  22. assert.Contains(t, info.Required, "content")
  23. }
  24. func TestWriteTool_Run(t *testing.T) {
  25. // Setup a mock permission handler that always allows
  26. origPermission := permission.Default
  27. defer func() {
  28. permission.Default = origPermission
  29. }()
  30. permission.Default = newMockPermissionService(true)
  31. // Create a temporary directory for testing
  32. tempDir, err := os.MkdirTemp("", "write_tool_test")
  33. require.NoError(t, err)
  34. defer os.RemoveAll(tempDir)
  35. t.Run("creates a new file successfully", func(t *testing.T) {
  36. permission.Default = newMockPermissionService(true)
  37. tool := NewWriteTool(make(map[string]*lsp.Client))
  38. filePath := filepath.Join(tempDir, "new_file.txt")
  39. content := "This is a test content"
  40. params := WriteParams{
  41. FilePath: filePath,
  42. Content: content,
  43. }
  44. paramsJSON, err := json.Marshal(params)
  45. require.NoError(t, err)
  46. call := ToolCall{
  47. Name: WriteToolName,
  48. Input: string(paramsJSON),
  49. }
  50. response, err := tool.Run(context.Background(), call)
  51. require.NoError(t, err)
  52. assert.Contains(t, response.Content, "successfully written")
  53. // Verify file was created with correct content
  54. fileContent, err := os.ReadFile(filePath)
  55. require.NoError(t, err)
  56. assert.Equal(t, content, string(fileContent))
  57. })
  58. t.Run("creates file with nested directories", func(t *testing.T) {
  59. permission.Default = newMockPermissionService(true)
  60. tool := NewWriteTool(make(map[string]*lsp.Client))
  61. filePath := filepath.Join(tempDir, "nested/dirs/new_file.txt")
  62. content := "Content in nested directory"
  63. params := WriteParams{
  64. FilePath: filePath,
  65. Content: content,
  66. }
  67. paramsJSON, err := json.Marshal(params)
  68. require.NoError(t, err)
  69. call := ToolCall{
  70. Name: WriteToolName,
  71. Input: string(paramsJSON),
  72. }
  73. response, err := tool.Run(context.Background(), call)
  74. require.NoError(t, err)
  75. assert.Contains(t, response.Content, "successfully written")
  76. // Verify file was created with correct content
  77. fileContent, err := os.ReadFile(filePath)
  78. require.NoError(t, err)
  79. assert.Equal(t, content, string(fileContent))
  80. })
  81. t.Run("updates existing file", func(t *testing.T) {
  82. permission.Default = newMockPermissionService(true)
  83. tool := NewWriteTool(make(map[string]*lsp.Client))
  84. // Create a file first
  85. filePath := filepath.Join(tempDir, "existing_file.txt")
  86. initialContent := "Initial content"
  87. err := os.WriteFile(filePath, []byte(initialContent), 0o644)
  88. require.NoError(t, err)
  89. // Record the file read to avoid modification time check failure
  90. recordFileRead(filePath)
  91. // Update the file
  92. updatedContent := "Updated content"
  93. params := WriteParams{
  94. FilePath: filePath,
  95. Content: updatedContent,
  96. }
  97. paramsJSON, err := json.Marshal(params)
  98. require.NoError(t, err)
  99. call := ToolCall{
  100. Name: WriteToolName,
  101. Input: string(paramsJSON),
  102. }
  103. response, err := tool.Run(context.Background(), call)
  104. require.NoError(t, err)
  105. assert.Contains(t, response.Content, "successfully written")
  106. // Verify file was updated with correct content
  107. fileContent, err := os.ReadFile(filePath)
  108. require.NoError(t, err)
  109. assert.Equal(t, updatedContent, string(fileContent))
  110. })
  111. t.Run("handles invalid parameters", func(t *testing.T) {
  112. permission.Default = newMockPermissionService(true)
  113. tool := NewWriteTool(make(map[string]*lsp.Client))
  114. call := ToolCall{
  115. Name: WriteToolName,
  116. Input: "invalid json",
  117. }
  118. response, err := tool.Run(context.Background(), call)
  119. require.NoError(t, err)
  120. assert.Contains(t, response.Content, "error parsing parameters")
  121. })
  122. t.Run("handles missing file_path", func(t *testing.T) {
  123. permission.Default = newMockPermissionService(true)
  124. tool := NewWriteTool(make(map[string]*lsp.Client))
  125. params := WriteParams{
  126. FilePath: "",
  127. Content: "Some content",
  128. }
  129. paramsJSON, err := json.Marshal(params)
  130. require.NoError(t, err)
  131. call := ToolCall{
  132. Name: WriteToolName,
  133. Input: string(paramsJSON),
  134. }
  135. response, err := tool.Run(context.Background(), call)
  136. require.NoError(t, err)
  137. assert.Contains(t, response.Content, "file_path is required")
  138. })
  139. t.Run("handles missing content", func(t *testing.T) {
  140. permission.Default = newMockPermissionService(true)
  141. tool := NewWriteTool(make(map[string]*lsp.Client))
  142. params := WriteParams{
  143. FilePath: filepath.Join(tempDir, "file.txt"),
  144. Content: "",
  145. }
  146. paramsJSON, err := json.Marshal(params)
  147. require.NoError(t, err)
  148. call := ToolCall{
  149. Name: WriteToolName,
  150. Input: string(paramsJSON),
  151. }
  152. response, err := tool.Run(context.Background(), call)
  153. require.NoError(t, err)
  154. assert.Contains(t, response.Content, "content is required")
  155. })
  156. t.Run("handles writing to a directory path", func(t *testing.T) {
  157. permission.Default = newMockPermissionService(true)
  158. tool := NewWriteTool(make(map[string]*lsp.Client))
  159. // Create a directory
  160. dirPath := filepath.Join(tempDir, "test_dir")
  161. err := os.Mkdir(dirPath, 0o755)
  162. require.NoError(t, err)
  163. params := WriteParams{
  164. FilePath: dirPath,
  165. Content: "Some content",
  166. }
  167. paramsJSON, err := json.Marshal(params)
  168. require.NoError(t, err)
  169. call := ToolCall{
  170. Name: WriteToolName,
  171. Input: string(paramsJSON),
  172. }
  173. response, err := tool.Run(context.Background(), call)
  174. require.NoError(t, err)
  175. assert.Contains(t, response.Content, "Path is a directory")
  176. })
  177. t.Run("handles permission denied", func(t *testing.T) {
  178. permission.Default = newMockPermissionService(false)
  179. tool := NewWriteTool(make(map[string]*lsp.Client))
  180. filePath := filepath.Join(tempDir, "permission_denied.txt")
  181. params := WriteParams{
  182. FilePath: filePath,
  183. Content: "Content that should not be written",
  184. }
  185. paramsJSON, err := json.Marshal(params)
  186. require.NoError(t, err)
  187. call := ToolCall{
  188. Name: WriteToolName,
  189. Input: string(paramsJSON),
  190. }
  191. response, err := tool.Run(context.Background(), call)
  192. require.NoError(t, err)
  193. assert.Contains(t, response.Content, "Permission denied")
  194. // Verify file was not created
  195. _, err = os.Stat(filePath)
  196. assert.True(t, os.IsNotExist(err))
  197. })
  198. t.Run("detects file modified since last read", func(t *testing.T) {
  199. permission.Default = newMockPermissionService(true)
  200. tool := NewWriteTool(make(map[string]*lsp.Client))
  201. // Create a file
  202. filePath := filepath.Join(tempDir, "modified_file.txt")
  203. initialContent := "Initial content"
  204. err := os.WriteFile(filePath, []byte(initialContent), 0o644)
  205. require.NoError(t, err)
  206. // Record an old read time
  207. fileRecordMutex.Lock()
  208. fileRecords[filePath] = fileRecord{
  209. path: filePath,
  210. readTime: time.Now().Add(-1 * time.Hour),
  211. }
  212. fileRecordMutex.Unlock()
  213. // Try to update the file
  214. params := WriteParams{
  215. FilePath: filePath,
  216. Content: "Updated content",
  217. }
  218. paramsJSON, err := json.Marshal(params)
  219. require.NoError(t, err)
  220. call := ToolCall{
  221. Name: WriteToolName,
  222. Input: string(paramsJSON),
  223. }
  224. response, err := tool.Run(context.Background(), call)
  225. require.NoError(t, err)
  226. assert.Contains(t, response.Content, "has been modified since it was last read")
  227. // Verify file was not modified
  228. fileContent, err := os.ReadFile(filePath)
  229. require.NoError(t, err)
  230. assert.Equal(t, initialContent, string(fileContent))
  231. })
  232. t.Run("skips writing when content is identical", func(t *testing.T) {
  233. permission.Default = newMockPermissionService(true)
  234. tool := NewWriteTool(make(map[string]*lsp.Client))
  235. // Create a file
  236. filePath := filepath.Join(tempDir, "identical_content.txt")
  237. content := "Content that won't change"
  238. err := os.WriteFile(filePath, []byte(content), 0o644)
  239. require.NoError(t, err)
  240. // Record a read time
  241. recordFileRead(filePath)
  242. // Try to write the same content
  243. params := WriteParams{
  244. FilePath: filePath,
  245. Content: content,
  246. }
  247. paramsJSON, err := json.Marshal(params)
  248. require.NoError(t, err)
  249. call := ToolCall{
  250. Name: WriteToolName,
  251. Input: string(paramsJSON),
  252. }
  253. response, err := tool.Run(context.Background(), call)
  254. require.NoError(t, err)
  255. assert.Contains(t, response.Content, "already contains the exact content")
  256. })
  257. }