write.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package tools
  2. import (
  3. "context"
  4. _ "embed"
  5. "fmt"
  6. "log/slog"
  7. "os"
  8. "path/filepath"
  9. "strings"
  10. "time"
  11. "charm.land/fantasy"
  12. "github.com/charmbracelet/crush/internal/csync"
  13. "github.com/charmbracelet/crush/internal/diff"
  14. "github.com/charmbracelet/crush/internal/filepathext"
  15. "github.com/charmbracelet/crush/internal/filetracker"
  16. "github.com/charmbracelet/crush/internal/fsext"
  17. "github.com/charmbracelet/crush/internal/history"
  18. "github.com/charmbracelet/crush/internal/lsp"
  19. "github.com/charmbracelet/crush/internal/permission"
  20. )
  // writeDescription holds the contents of write.md, embedded at build time.
  // NewWriteTool converts it to a string and uses it as the tool description.
  //
  //go:embed write.md
  var writeDescription []byte
  // WriteParams is the argument payload accepted by the write tool handler.
  // The handler rejects the call when either field is empty.
  type WriteParams struct {
  	// FilePath is the target path; the handler resolves it against the
  	// tool's working directory before touching the filesystem.
  	FilePath string `json:"file_path" description:"The path to the file to write"`
  	// Content is the complete new content written to the file.
  	Content string `json:"content" description:"The content to write to the file"`
  }
  // WritePermissionsParams is the payload attached to the permission request
  // issued by the write tool before it modifies a file.
  type WritePermissionsParams struct {
  	// FilePath is the resolved path of the file about to be written.
  	FilePath string `json:"file_path"`
  	// OldContent is the file's content before the write; it is empty when the
  	// file does not exist yet or could not be read.
  	OldContent string `json:"old_content,omitempty"`
  	// NewContent is the content that will be written.
  	NewContent string `json:"new_content,omitempty"`
  }
  // WriteResponseMetadata is the metadata attached to a successful write tool
  // response. All three fields come from diff.GenerateDiff applied to the old
  // and new file content.
  type WriteResponseMetadata struct {
  	// Diff is the generated diff between the old and new content.
  	Diff string `json:"diff"`
  	// Additions is the additions count reported by the diff generator.
  	Additions int `json:"additions"`
  	// Removals is the removals count reported by the diff generator.
  	Removals int `json:"removals"`
  }
  // WriteToolName is the name the write tool is registered under; it is also
  // sent as the tool name in the tool's permission requests.
  const WriteToolName = "write"
  38. func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) fantasy.AgentTool {
  39. return fantasy.NewAgentTool(
  40. WriteToolName,
  41. string(writeDescription),
  42. func(ctx context.Context, params WriteParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
  43. if params.FilePath == "" {
  44. return fantasy.NewTextErrorResponse("file_path is required"), nil
  45. }
  46. if params.Content == "" {
  47. return fantasy.NewTextErrorResponse("content is required"), nil
  48. }
  49. filePath := filepathext.SmartJoin(workingDir, params.FilePath)
  50. fileInfo, err := os.Stat(filePath)
  51. if err == nil {
  52. if fileInfo.IsDir() {
  53. return fantasy.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
  54. }
  55. modTime := fileInfo.ModTime()
  56. lastRead := filetracker.LastReadTime(filePath)
  57. if modTime.After(lastRead) {
  58. return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
  59. filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
  60. }
  61. oldContent, readErr := os.ReadFile(filePath)
  62. if readErr == nil && string(oldContent) == params.Content {
  63. return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
  64. }
  65. } else if !os.IsNotExist(err) {
  66. return fantasy.ToolResponse{}, fmt.Errorf("error checking file: %w", err)
  67. }
  68. dir := filepath.Dir(filePath)
  69. if err = os.MkdirAll(dir, 0o755); err != nil {
  70. return fantasy.ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
  71. }
  72. oldContent := ""
  73. if fileInfo != nil && !fileInfo.IsDir() {
  74. oldBytes, readErr := os.ReadFile(filePath)
  75. if readErr == nil {
  76. oldContent = string(oldBytes)
  77. }
  78. }
  79. sessionID := GetSessionFromContext(ctx)
  80. if sessionID == "" {
  81. return fantasy.ToolResponse{}, fmt.Errorf("session_id is required")
  82. }
  83. diff, additions, removals := diff.GenerateDiff(
  84. oldContent,
  85. params.Content,
  86. strings.TrimPrefix(filePath, workingDir),
  87. )
  88. p, err := permissions.Request(ctx,
  89. permission.CreatePermissionRequest{
  90. SessionID: sessionID,
  91. Path: fsext.PathOrPrefix(filePath, workingDir),
  92. ToolCallID: call.ID,
  93. ToolName: WriteToolName,
  94. Action: "write",
  95. Description: fmt.Sprintf("Create file %s", filePath),
  96. Params: WritePermissionsParams{
  97. FilePath: filePath,
  98. OldContent: oldContent,
  99. NewContent: params.Content,
  100. },
  101. },
  102. )
  103. if err != nil {
  104. return fantasy.ToolResponse{}, err
  105. }
  106. if !p {
  107. return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
  108. }
  109. err = os.WriteFile(filePath, []byte(params.Content), 0o644)
  110. if err != nil {
  111. return fantasy.ToolResponse{}, fmt.Errorf("error writing file: %w", err)
  112. }
  113. // Check if file exists in history
  114. file, err := files.GetByPathAndSession(ctx, filePath, sessionID)
  115. if err != nil {
  116. _, err = files.Create(ctx, sessionID, filePath, oldContent)
  117. if err != nil {
  118. // Log error but don't fail the operation
  119. return fantasy.ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
  120. }
  121. }
  122. if file.Content != oldContent {
  123. // User manually changed the content; store an intermediate version
  124. _, err = files.CreateVersion(ctx, sessionID, filePath, oldContent)
  125. if err != nil {
  126. slog.Error("Error creating file history version", "error", err)
  127. }
  128. }
  129. // Store the new version
  130. _, err = files.CreateVersion(ctx, sessionID, filePath, params.Content)
  131. if err != nil {
  132. slog.Error("Error creating file history version", "error", err)
  133. }
  134. filetracker.RecordWrite(filePath)
  135. filetracker.RecordRead(filePath)
  136. notifyLSPs(ctx, lspClients, params.FilePath)
  137. result := fmt.Sprintf("File successfully written: %s", filePath)
  138. result = fmt.Sprintf("<result>\n%s\n</result>", result)
  139. result += getDiagnostics(filePath, lspClients)
  140. return fantasy.WithResponseMetadata(fantasy.NewTextResponse(result),
  141. WriteResponseMetadata{
  142. Diff: diff,
  143. Additions: additions,
  144. Removals: removals,
  145. },
  146. ), nil
  147. })
  148. }