edit.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. package tools
  2. import (
  3. "context"
  4. _ "embed"
  5. "encoding/json"
  6. "fmt"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "strings"
  11. "time"
  12. "github.com/charmbracelet/crush/internal/csync"
  13. "github.com/charmbracelet/crush/internal/diff"
  14. "github.com/charmbracelet/crush/internal/fsext"
  15. "github.com/charmbracelet/crush/internal/history"
  16. "github.com/charmbracelet/crush/internal/lsp"
  17. "github.com/charmbracelet/crush/internal/permission"
  18. )
  19. type EditParams struct {
  20. FilePath string `json:"file_path"`
  21. OldString string `json:"old_string"`
  22. NewString string `json:"new_string"`
  23. ReplaceAll bool `json:"replace_all,omitempty"`
  24. }
  25. type EditPermissionsParams struct {
  26. FilePath string `json:"file_path"`
  27. OldContent string `json:"old_content,omitempty"`
  28. NewContent string `json:"new_content,omitempty"`
  29. }
  30. type EditResponseMetadata struct {
  31. Additions int `json:"additions"`
  32. Removals int `json:"removals"`
  33. OldContent string `json:"old_content,omitempty"`
  34. NewContent string `json:"new_content,omitempty"`
  35. }
  36. type editTool struct {
  37. lspClients *csync.Map[string, *lsp.Client]
  38. permissions permission.Service
  39. files history.Service
  40. workingDir string
  41. }
  42. const EditToolName = "edit"
  43. //go:embed edit.md
  44. var editDescription []byte
  45. func NewEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
  46. return &editTool{
  47. lspClients: lspClients,
  48. permissions: permissions,
  49. files: files,
  50. workingDir: workingDir,
  51. }
  52. }
  53. func (e *editTool) Name() string {
  54. return EditToolName
  55. }
  56. func (e *editTool) Info() ToolInfo {
  57. return ToolInfo{
  58. Name: EditToolName,
  59. Description: string(editDescription),
  60. Parameters: map[string]any{
  61. "file_path": map[string]any{
  62. "type": "string",
  63. "description": "The absolute path to the file to modify",
  64. },
  65. "old_string": map[string]any{
  66. "type": "string",
  67. "description": "The text to replace",
  68. },
  69. "new_string": map[string]any{
  70. "type": "string",
  71. "description": "The text to replace it with",
  72. },
  73. "replace_all": map[string]any{
  74. "type": "boolean",
  75. "description": "Replace all occurrences of old_string (default false)",
  76. },
  77. },
  78. Required: []string{"file_path", "old_string", "new_string"},
  79. }
  80. }
  81. func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
  82. var params EditParams
  83. if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
  84. return NewTextErrorResponse("invalid parameters"), nil
  85. }
  86. if params.FilePath == "" {
  87. return NewTextErrorResponse("file_path is required"), nil
  88. }
  89. if !filepath.IsAbs(params.FilePath) {
  90. params.FilePath = filepath.Join(e.workingDir, params.FilePath)
  91. }
  92. var response ToolResponse
  93. var err error
  94. if params.OldString == "" {
  95. response, err = e.createNewFile(ctx, params.FilePath, params.NewString, call)
  96. if err != nil {
  97. return response, err
  98. }
  99. }
  100. if params.NewString == "" {
  101. response, err = e.deleteContent(ctx, params.FilePath, params.OldString, params.ReplaceAll, call)
  102. if err != nil {
  103. return response, err
  104. }
  105. }
  106. response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString, params.ReplaceAll, call)
  107. if err != nil {
  108. return response, err
  109. }
  110. if response.IsError {
  111. // Return early if there was an error during content replacement
  112. // This prevents unnecessary LSP diagnostics processing
  113. return response, nil
  114. }
  115. notifyLSPs(ctx, e.lspClients, params.FilePath)
  116. text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
  117. text += getDiagnostics(params.FilePath, e.lspClients)
  118. response.Content = text
  119. return response, nil
  120. }
  121. func (e *editTool) createNewFile(ctx context.Context, filePath, content string, call ToolCall) (ToolResponse, error) {
  122. fileInfo, err := os.Stat(filePath)
  123. if err == nil {
  124. if fileInfo.IsDir() {
  125. return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
  126. }
  127. return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", filePath)), nil
  128. } else if !os.IsNotExist(err) {
  129. return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
  130. }
  131. dir := filepath.Dir(filePath)
  132. if err = os.MkdirAll(dir, 0o755); err != nil {
  133. return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
  134. }
  135. sessionID, messageID := GetContextValues(ctx)
  136. if sessionID == "" || messageID == "" {
  137. return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
  138. }
  139. _, additions, removals := diff.GenerateDiff(
  140. "",
  141. content,
  142. strings.TrimPrefix(filePath, e.workingDir),
  143. )
  144. p := e.permissions.Request(
  145. permission.CreatePermissionRequest{
  146. SessionID: sessionID,
  147. Path: fsext.PathOrPrefix(filePath, e.workingDir),
  148. ToolCallID: call.ID,
  149. ToolName: EditToolName,
  150. Action: "write",
  151. Description: fmt.Sprintf("Create file %s", filePath),
  152. Params: EditPermissionsParams{
  153. FilePath: filePath,
  154. OldContent: "",
  155. NewContent: content,
  156. },
  157. },
  158. )
  159. if !p {
  160. return ToolResponse{}, permission.ErrorPermissionDenied
  161. }
  162. err = os.WriteFile(filePath, []byte(content), 0o644)
  163. if err != nil {
  164. return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
  165. }
  166. // File can't be in the history so we create a new file history
  167. _, err = e.files.Create(ctx, sessionID, filePath, "")
  168. if err != nil {
  169. // Log error but don't fail the operation
  170. return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
  171. }
  172. // Add the new content to the file history
  173. _, err = e.files.CreateVersion(ctx, sessionID, filePath, content)
  174. if err != nil {
  175. // Log error but don't fail the operation
  176. slog.Debug("Error creating file history version", "error", err)
  177. }
  178. recordFileWrite(filePath)
  179. recordFileRead(filePath)
  180. return WithResponseMetadata(
  181. NewTextResponse("File created: "+filePath),
  182. EditResponseMetadata{
  183. OldContent: "",
  184. NewContent: content,
  185. Additions: additions,
  186. Removals: removals,
  187. },
  188. ), nil
  189. }
  190. func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string, replaceAll bool, call ToolCall) (ToolResponse, error) {
  191. fileInfo, err := os.Stat(filePath)
  192. if err != nil {
  193. if os.IsNotExist(err) {
  194. return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
  195. }
  196. return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
  197. }
  198. if fileInfo.IsDir() {
  199. return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
  200. }
  201. if getLastReadTime(filePath).IsZero() {
  202. return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
  203. }
  204. modTime := fileInfo.ModTime()
  205. lastRead := getLastReadTime(filePath)
  206. if modTime.After(lastRead) {
  207. return NewTextErrorResponse(
  208. fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
  209. filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
  210. )), nil
  211. }
  212. content, err := os.ReadFile(filePath)
  213. if err != nil {
  214. return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
  215. }
  216. oldContent, isCrlf := fsext.ToUnixLineEndings(string(content))
  217. var newContent string
  218. var deletionCount int
  219. if replaceAll {
  220. newContent = strings.ReplaceAll(oldContent, oldString, "")
  221. deletionCount = strings.Count(oldContent, oldString)
  222. if deletionCount == 0 {
  223. return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
  224. }
  225. } else {
  226. index := strings.Index(oldContent, oldString)
  227. if index == -1 {
  228. return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
  229. }
  230. lastIndex := strings.LastIndex(oldContent, oldString)
  231. if index != lastIndex {
  232. return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil
  233. }
  234. newContent = oldContent[:index] + oldContent[index+len(oldString):]
  235. deletionCount = 1
  236. }
  237. sessionID, messageID := GetContextValues(ctx)
  238. if sessionID == "" || messageID == "" {
  239. return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
  240. }
  241. _, additions, removals := diff.GenerateDiff(
  242. oldContent,
  243. newContent,
  244. strings.TrimPrefix(filePath, e.workingDir),
  245. )
  246. p := e.permissions.Request(
  247. permission.CreatePermissionRequest{
  248. SessionID: sessionID,
  249. Path: fsext.PathOrPrefix(filePath, e.workingDir),
  250. ToolCallID: call.ID,
  251. ToolName: EditToolName,
  252. Action: "write",
  253. Description: fmt.Sprintf("Delete content from file %s", filePath),
  254. Params: EditPermissionsParams{
  255. FilePath: filePath,
  256. OldContent: oldContent,
  257. NewContent: newContent,
  258. },
  259. },
  260. )
  261. if !p {
  262. return ToolResponse{}, permission.ErrorPermissionDenied
  263. }
  264. if isCrlf {
  265. newContent, _ = fsext.ToWindowsLineEndings(newContent)
  266. }
  267. err = os.WriteFile(filePath, []byte(newContent), 0o644)
  268. if err != nil {
  269. return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
  270. }
  271. // Check if file exists in history
  272. file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
  273. if err != nil {
  274. _, err = e.files.Create(ctx, sessionID, filePath, oldContent)
  275. if err != nil {
  276. // Log error but don't fail the operation
  277. return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
  278. }
  279. }
  280. if file.Content != oldContent {
  281. // User Manually changed the content store an intermediate version
  282. _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
  283. if err != nil {
  284. slog.Debug("Error creating file history version", "error", err)
  285. }
  286. }
  287. // Store the new version
  288. _, err = e.files.CreateVersion(ctx, sessionID, filePath, "")
  289. if err != nil {
  290. slog.Debug("Error creating file history version", "error", err)
  291. }
  292. recordFileWrite(filePath)
  293. recordFileRead(filePath)
  294. return WithResponseMetadata(
  295. NewTextResponse("Content deleted from file: "+filePath),
  296. EditResponseMetadata{
  297. OldContent: oldContent,
  298. NewContent: newContent,
  299. Additions: additions,
  300. Removals: removals,
  301. },
  302. ), nil
  303. }
  304. func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string, replaceAll bool, call ToolCall) (ToolResponse, error) {
  305. fileInfo, err := os.Stat(filePath)
  306. if err != nil {
  307. if os.IsNotExist(err) {
  308. return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
  309. }
  310. return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
  311. }
  312. if fileInfo.IsDir() {
  313. return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
  314. }
  315. if getLastReadTime(filePath).IsZero() {
  316. return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
  317. }
  318. modTime := fileInfo.ModTime()
  319. lastRead := getLastReadTime(filePath)
  320. if modTime.After(lastRead) {
  321. return NewTextErrorResponse(
  322. fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
  323. filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
  324. )), nil
  325. }
  326. content, err := os.ReadFile(filePath)
  327. if err != nil {
  328. return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
  329. }
  330. oldContent, isCrlf := fsext.ToUnixLineEndings(string(content))
  331. var newContent string
  332. var replacementCount int
  333. if replaceAll {
  334. newContent = strings.ReplaceAll(oldContent, oldString, newString)
  335. replacementCount = strings.Count(oldContent, oldString)
  336. if replacementCount == 0 {
  337. return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
  338. }
  339. } else {
  340. index := strings.Index(oldContent, oldString)
  341. if index == -1 {
  342. return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
  343. }
  344. lastIndex := strings.LastIndex(oldContent, oldString)
  345. if index != lastIndex {
  346. return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil
  347. }
  348. newContent = oldContent[:index] + newString + oldContent[index+len(oldString):]
  349. replacementCount = 1
  350. }
  351. if oldContent == newContent {
  352. return NewTextErrorResponse("new content is the same as old content. No changes made."), nil
  353. }
  354. sessionID, messageID := GetContextValues(ctx)
  355. if sessionID == "" || messageID == "" {
  356. return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
  357. }
  358. _, additions, removals := diff.GenerateDiff(
  359. oldContent,
  360. newContent,
  361. strings.TrimPrefix(filePath, e.workingDir),
  362. )
  363. p := e.permissions.Request(
  364. permission.CreatePermissionRequest{
  365. SessionID: sessionID,
  366. Path: fsext.PathOrPrefix(filePath, e.workingDir),
  367. ToolCallID: call.ID,
  368. ToolName: EditToolName,
  369. Action: "write",
  370. Description: fmt.Sprintf("Replace content in file %s", filePath),
  371. Params: EditPermissionsParams{
  372. FilePath: filePath,
  373. OldContent: oldContent,
  374. NewContent: newContent,
  375. },
  376. },
  377. )
  378. if !p {
  379. return ToolResponse{}, permission.ErrorPermissionDenied
  380. }
  381. if isCrlf {
  382. newContent, _ = fsext.ToWindowsLineEndings(newContent)
  383. }
  384. err = os.WriteFile(filePath, []byte(newContent), 0o644)
  385. if err != nil {
  386. return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
  387. }
  388. // Check if file exists in history
  389. file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
  390. if err != nil {
  391. _, err = e.files.Create(ctx, sessionID, filePath, oldContent)
  392. if err != nil {
  393. // Log error but don't fail the operation
  394. return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
  395. }
  396. }
  397. if file.Content != oldContent {
  398. // User Manually changed the content store an intermediate version
  399. _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
  400. if err != nil {
  401. slog.Debug("Error creating file history version", "error", err)
  402. }
  403. }
  404. // Store the new version
  405. _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent)
  406. if err != nil {
  407. slog.Debug("Error creating file history version", "error", err)
  408. }
  409. recordFileWrite(filePath)
  410. recordFileRead(filePath)
  411. return WithResponseMetadata(
  412. NewTextResponse("Content replaced in file: "+filePath),
  413. EditResponseMetadata{
  414. OldContent: oldContent,
  415. NewContent: newContent,
  416. Additions: additions,
  417. Removals: removals,
  418. }), nil
  419. }