patch.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. package tools
import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/kujtimiihoxha/opencode/internal/config"
	"github.com/kujtimiihoxha/opencode/internal/diff"
	"github.com/kujtimiihoxha/opencode/internal/history"
	"github.com/kujtimiihoxha/opencode/internal/lsp"
	"github.com/kujtimiihoxha/opencode/internal/permission"
)
  16. type PatchParams struct {
  17. FilePath string `json:"file_path"`
  18. Patch string `json:"patch"`
  19. }
  20. type PatchPermissionsParams struct {
  21. FilePath string `json:"file_path"`
  22. Diff string `json:"diff"`
  23. }
  24. type PatchResponseMetadata struct {
  25. Diff string `json:"diff"`
  26. Additions int `json:"additions"`
  27. Removals int `json:"removals"`
  28. }
  29. type patchTool struct {
  30. lspClients map[string]*lsp.Client
  31. permissions permission.Service
  32. files history.Service
  33. }
const (
	// TODO: test if this works as expected

	// PatchToolName is the tool name reported by Info and used as the
	// ToolName of the permission requests Run raises.
	PatchToolName = "patch"
	// patchDescription is returned verbatim as the tool's Description by Info.
	//
	// NOTE(review): the text says file paths must be absolute, yet Run also
	// accepts relative paths and resolves them against the working directory;
	// it also refers to a "FileRead tool" while Run's error message tells the
	// caller to use the "View tool". Confirm which wording is intended.
	patchDescription = `Applies a patch to a file. This tool is similar to the edit tool but accepts a unified diff patch instead of old/new strings.
Before using this tool:
1. Use the FileRead tool to understand the file's contents and context
2. Verify the directory path is correct:
- Use the LS tool to verify the parent directory exists and is the correct location
To apply a patch, provide the following:
1. file_path: The absolute path to the file to modify (must be absolute, not relative)
2. patch: A unified diff patch to apply to the file
The tool will apply the patch to the specified file. The patch must be in unified diff format.
CRITICAL REQUIREMENTS FOR USING THIS TOOL:
1. PATCH FORMAT: The patch must be in unified diff format, which includes:
- File headers (--- a/file_path, +++ b/file_path)
- Hunk headers (@@ -start,count +start,count @@)
- Added lines (prefixed with +)
- Removed lines (prefixed with -)
2. CONTEXT: The patch must include sufficient context around the changes to ensure it applies correctly.
3. VERIFICATION: Before using this tool:
- Ensure the patch applies cleanly to the current state of the file
- Check that the file exists and you have read it first
WARNING: If you do not follow these requirements:
- The tool will fail if the patch doesn't apply cleanly
- You may change the wrong parts of the file if the context is insufficient
When applying patches:
- Ensure the patch results in idiomatic, correct code
- Do not leave the code in a broken state
- Always use absolute file paths (starting with /)
Remember: patches are a powerful way to make multiple related changes at once, but they require careful preparation.`
)
  65. func NewPatchTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service) BaseTool {
  66. return &patchTool{
  67. lspClients: lspClients,
  68. permissions: permissions,
  69. files: files,
  70. }
  71. }
  72. func (p *patchTool) Info() ToolInfo {
  73. return ToolInfo{
  74. Name: PatchToolName,
  75. Description: patchDescription,
  76. Parameters: map[string]any{
  77. "file_path": map[string]any{
  78. "type": "string",
  79. "description": "The absolute path to the file to modify",
  80. },
  81. "patch": map[string]any{
  82. "type": "string",
  83. "description": "The unified diff patch to apply",
  84. },
  85. },
  86. Required: []string{"file_path", "patch"},
  87. }
  88. }
  89. func (p *patchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
  90. var params PatchParams
  91. if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
  92. return NewTextErrorResponse("invalid parameters"), nil
  93. }
  94. if params.FilePath == "" {
  95. return NewTextErrorResponse("file_path is required"), nil
  96. }
  97. if params.Patch == "" {
  98. return NewTextErrorResponse("patch is required"), nil
  99. }
  100. if !filepath.IsAbs(params.FilePath) {
  101. wd := config.WorkingDirectory()
  102. params.FilePath = filepath.Join(wd, params.FilePath)
  103. }
  104. // Check if file exists
  105. fileInfo, err := os.Stat(params.FilePath)
  106. if err != nil {
  107. if os.IsNotExist(err) {
  108. return NewTextErrorResponse(fmt.Sprintf("file not found: %s", params.FilePath)), nil
  109. }
  110. return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
  111. }
  112. if fileInfo.IsDir() {
  113. return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil
  114. }
  115. if getLastReadTime(params.FilePath).IsZero() {
  116. return NewTextErrorResponse("you must read the file before patching it. Use the View tool first"), nil
  117. }
  118. modTime := fileInfo.ModTime()
  119. lastRead := getLastReadTime(params.FilePath)
  120. if modTime.After(lastRead) {
  121. return NewTextErrorResponse(
  122. fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
  123. params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
  124. )), nil
  125. }
  126. // Read the current file content
  127. content, err := os.ReadFile(params.FilePath)
  128. if err != nil {
  129. return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
  130. }
  131. oldContent := string(content)
  132. // Parse and apply the patch
  133. diffResult, err := diff.ParseUnifiedDiff(params.Patch)
  134. if err != nil {
  135. return NewTextErrorResponse(fmt.Sprintf("failed to parse patch: %v", err)), nil
  136. }
  137. // Apply the patch to get the new content
  138. newContent, err := applyPatch(oldContent, diffResult)
  139. if err != nil {
  140. return NewTextErrorResponse(fmt.Sprintf("failed to apply patch: %v", err)), nil
  141. }
  142. if oldContent == newContent {
  143. return NewTextErrorResponse("patch did not result in any changes to the file"), nil
  144. }
  145. sessionID, messageID := GetContextValues(ctx)
  146. if sessionID == "" || messageID == "" {
  147. return ToolResponse{}, fmt.Errorf("session ID and message ID are required for patching a file")
  148. }
  149. // Generate a diff for permission request and metadata
  150. diffText, additions, removals := diff.GenerateDiff(
  151. oldContent,
  152. newContent,
  153. params.FilePath,
  154. )
  155. // Request permission to apply the patch
  156. p.permissions.Request(
  157. permission.CreatePermissionRequest{
  158. Path: filepath.Dir(params.FilePath),
  159. ToolName: PatchToolName,
  160. Action: "patch",
  161. Description: fmt.Sprintf("Apply patch to file %s", params.FilePath),
  162. Params: PatchPermissionsParams{
  163. FilePath: params.FilePath,
  164. Diff: diffText,
  165. },
  166. },
  167. )
  168. // Write the new content to the file
  169. err = os.WriteFile(params.FilePath, []byte(newContent), 0o644)
  170. if err != nil {
  171. return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
  172. }
  173. // Update file history
  174. file, err := p.files.GetByPathAndSession(ctx, params.FilePath, sessionID)
  175. if err != nil {
  176. _, err = p.files.Create(ctx, sessionID, params.FilePath, oldContent)
  177. if err != nil {
  178. return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
  179. }
  180. }
  181. if file.Content != oldContent {
  182. // User manually changed the content, store an intermediate version
  183. _, err = p.files.CreateVersion(ctx, sessionID, params.FilePath, oldContent)
  184. if err != nil {
  185. fmt.Printf("Error creating file history version: %v\n", err)
  186. }
  187. }
  188. // Store the new version
  189. _, err = p.files.CreateVersion(ctx, sessionID, params.FilePath, newContent)
  190. if err != nil {
  191. fmt.Printf("Error creating file history version: %v\n", err)
  192. }
  193. recordFileWrite(params.FilePath)
  194. recordFileRead(params.FilePath)
  195. // Wait for LSP diagnostics and include them in the response
  196. waitForLspDiagnostics(ctx, params.FilePath, p.lspClients)
  197. text := fmt.Sprintf("<r>\nPatch applied to file: %s\n</r>\n", params.FilePath)
  198. text += getDiagnostics(params.FilePath, p.lspClients)
  199. return WithResponseMetadata(
  200. NewTextResponse(text),
  201. PatchResponseMetadata{
  202. Diff: diffText,
  203. Additions: additions,
  204. Removals: removals,
  205. }), nil
  206. }
  207. // applyPatch applies a parsed diff to a string and returns the resulting content
  208. func applyPatch(content string, diffResult diff.DiffResult) (string, error) {
  209. lines := strings.Split(content, "\n")
  210. // Process each hunk in the diff
  211. for _, hunk := range diffResult.Hunks {
  212. // Parse the hunk header to get line numbers
  213. var oldStart, oldCount, newStart, newCount int
  214. _, err := fmt.Sscanf(hunk.Header, "@@ -%d,%d +%d,%d @@", &oldStart, &oldCount, &newStart, &newCount)
  215. if err != nil {
  216. // Try alternative format with single line counts
  217. _, err = fmt.Sscanf(hunk.Header, "@@ -%d +%d @@", &oldStart, &newStart)
  218. if err != nil {
  219. return "", fmt.Errorf("invalid hunk header format: %s", hunk.Header)
  220. }
  221. oldCount = 1
  222. newCount = 1
  223. }
  224. // Adjust for 0-based array indexing
  225. oldStart--
  226. newStart--
  227. // Apply the changes
  228. newLines := make([]string, 0)
  229. newLines = append(newLines, lines[:oldStart]...)
  230. // Process the hunk lines in order
  231. currentOldLine := oldStart
  232. for _, line := range hunk.Lines {
  233. switch line.Kind {
  234. case diff.LineContext:
  235. newLines = append(newLines, line.Content)
  236. currentOldLine++
  237. case diff.LineRemoved:
  238. // Skip this line in the output (it's being removed)
  239. currentOldLine++
  240. case diff.LineAdded:
  241. // Add the new line
  242. newLines = append(newLines, line.Content)
  243. }
  244. }
  245. // Append the rest of the file
  246. newLines = append(newLines, lines[currentOldLine:]...)
  247. lines = newLines
  248. }
  249. return strings.Join(lines, "\n"), nil
  250. }