write.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. package tools
  2. import (
  3. "context"
  4. _ "embed"
  5. "encoding/json"
  6. "fmt"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "strings"
  11. "time"
  12. "github.com/charmbracelet/crush/internal/csync"
  13. "github.com/charmbracelet/crush/internal/diff"
  14. "github.com/charmbracelet/crush/internal/fsext"
  15. "github.com/charmbracelet/crush/internal/history"
  16. "github.com/charmbracelet/crush/internal/lsp"
  17. "github.com/charmbracelet/crush/internal/permission"
  18. )
// writeDescription is the tool description returned by Info, embedded at
// build time from write.md.
//
//go:embed write.md
var writeDescription []byte
  21. type WriteParams struct {
  22. FilePath string `json:"file_path"`
  23. Content string `json:"content"`
  24. }
  25. type WritePermissionsParams struct {
  26. FilePath string `json:"file_path"`
  27. OldContent string `json:"old_content,omitempty"`
  28. NewContent string `json:"new_content,omitempty"`
  29. }
  30. type writeTool struct {
  31. lspClients *csync.Map[string, *lsp.Client]
  32. permissions permission.Service
  33. files history.Service
  34. workingDir string
  35. }
  36. type WriteResponseMetadata struct {
  37. Diff string `json:"diff"`
  38. Additions int `json:"additions"`
  39. Removals int `json:"removals"`
  40. }
  41. const WriteToolName = "write"
  42. func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
  43. return &writeTool{
  44. lspClients: lspClients,
  45. permissions: permissions,
  46. files: files,
  47. workingDir: workingDir,
  48. }
  49. }
  50. func (w *writeTool) Name() string {
  51. return WriteToolName
  52. }
  53. func (w *writeTool) Info() ToolInfo {
  54. return ToolInfo{
  55. Name: WriteToolName,
  56. Description: string(writeDescription),
  57. Parameters: map[string]any{
  58. "file_path": map[string]any{
  59. "type": "string",
  60. "description": "The path to the file to write",
  61. },
  62. "content": map[string]any{
  63. "type": "string",
  64. "description": "The content to write to the file",
  65. },
  66. },
  67. Required: []string{"file_path", "content"},
  68. }
  69. }
  70. func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
  71. var params WriteParams
  72. if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
  73. return NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
  74. }
  75. if params.FilePath == "" {
  76. return NewTextErrorResponse("file_path is required"), nil
  77. }
  78. if params.Content == "" {
  79. return NewTextErrorResponse("content is required"), nil
  80. }
  81. filePath := params.FilePath
  82. if !filepath.IsAbs(filePath) {
  83. filePath = filepath.Join(w.workingDir, filePath)
  84. }
  85. fileInfo, err := os.Stat(filePath)
  86. if err == nil {
  87. if fileInfo.IsDir() {
  88. return NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
  89. }
  90. modTime := fileInfo.ModTime()
  91. lastRead := getLastReadTime(filePath)
  92. if modTime.After(lastRead) {
  93. return NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
  94. filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
  95. }
  96. oldContent, readErr := os.ReadFile(filePath)
  97. if readErr == nil && string(oldContent) == params.Content {
  98. return NewTextErrorResponse(fmt.Sprintf("File %s already contains the exact content. No changes made.", filePath)), nil
  99. }
  100. } else if !os.IsNotExist(err) {
  101. return ToolResponse{}, fmt.Errorf("error checking file: %w", err)
  102. }
  103. dir := filepath.Dir(filePath)
  104. if err = os.MkdirAll(dir, 0o755); err != nil {
  105. return ToolResponse{}, fmt.Errorf("error creating directory: %w", err)
  106. }
  107. oldContent := ""
  108. if fileInfo != nil && !fileInfo.IsDir() {
  109. oldBytes, readErr := os.ReadFile(filePath)
  110. if readErr == nil {
  111. oldContent = string(oldBytes)
  112. }
  113. }
  114. sessionID, messageID := GetContextValues(ctx)
  115. if sessionID == "" || messageID == "" {
  116. return ToolResponse{}, fmt.Errorf("session_id and message_id are required")
  117. }
  118. diff, additions, removals := diff.GenerateDiff(
  119. oldContent,
  120. params.Content,
  121. strings.TrimPrefix(filePath, w.workingDir),
  122. )
  123. p := w.permissions.Request(
  124. permission.CreatePermissionRequest{
  125. SessionID: sessionID,
  126. Path: fsext.PathOrPrefix(filePath, w.workingDir),
  127. ToolCallID: call.ID,
  128. ToolName: WriteToolName,
  129. Action: "write",
  130. Description: fmt.Sprintf("Create file %s", filePath),
  131. Params: WritePermissionsParams{
  132. FilePath: filePath,
  133. OldContent: oldContent,
  134. NewContent: params.Content,
  135. },
  136. },
  137. )
  138. if !p {
  139. return ToolResponse{}, permission.ErrorPermissionDenied
  140. }
  141. err = os.WriteFile(filePath, []byte(params.Content), 0o644)
  142. if err != nil {
  143. return ToolResponse{}, fmt.Errorf("error writing file: %w", err)
  144. }
  145. // Check if file exists in history
  146. file, err := w.files.GetByPathAndSession(ctx, filePath, sessionID)
  147. if err != nil {
  148. _, err = w.files.Create(ctx, sessionID, filePath, oldContent)
  149. if err != nil {
  150. // Log error but don't fail the operation
  151. return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
  152. }
  153. }
  154. if file.Content != oldContent {
  155. // User Manually changed the content store an intermediate version
  156. _, err = w.files.CreateVersion(ctx, sessionID, filePath, oldContent)
  157. if err != nil {
  158. slog.Debug("Error creating file history version", "error", err)
  159. }
  160. }
  161. // Store the new version
  162. _, err = w.files.CreateVersion(ctx, sessionID, filePath, params.Content)
  163. if err != nil {
  164. slog.Debug("Error creating file history version", "error", err)
  165. }
  166. recordFileWrite(filePath)
  167. recordFileRead(filePath)
  168. notifyLSPs(ctx, w.lspClients, params.FilePath)
  169. result := fmt.Sprintf("File successfully written: %s", filePath)
  170. result = fmt.Sprintf("<result>\n%s\n</result>", result)
  171. result += getDiagnostics(filePath, w.lspClients)
  172. return WithResponseMetadata(NewTextResponse(result),
  173. WriteResponseMetadata{
  174. Diff: diff,
  175. Additions: additions,
  176. Removals: removals,
  177. },
  178. ), nil
  179. }