service.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. // Package filetracker provides functionality to track file reads in sessions.
  2. package filetracker
  import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"time"

	"github.com/charmbracelet/crush/internal/db"
)
  // Service records and queries which files have been read in each session.
type Service interface {
	// ListReadFiles lists the paths of every file recorded as read in the
	// given session.
	ListReadFiles(ctx context.Context, sessionID string) ([]string, error)

	// LastReadTime reports the most recent recorded read of path in the given
	// session. It returns the zero time.Time when no read has been recorded.
	LastReadTime(ctx context.Context, sessionID, path string) time.Time

	// RecordRead notes that path was read in the given session.
	RecordRead(ctx context.Context, sessionID, path string)
}
  22. type service struct {
  23. q *db.Queries
  24. }
  25. // NewService creates a new file tracker service.
  26. func NewService(q *db.Queries) Service {
  27. return &service{q: q}
  28. }
  29. // RecordRead records when a file was read.
  30. func (s *service) RecordRead(ctx context.Context, sessionID, path string) {
  31. if err := s.q.RecordFileRead(ctx, db.RecordFileReadParams{
  32. SessionID: sessionID,
  33. Path: relpath(path),
  34. }); err != nil {
  35. slog.Error("Error recording file read", "error", err, "file", path)
  36. }
  37. }
  38. // LastReadTime returns when a file was last read.
  39. // Returns zero time if never read.
  40. func (s *service) LastReadTime(ctx context.Context, sessionID, path string) time.Time {
  41. readFile, err := s.q.GetFileRead(ctx, db.GetFileReadParams{
  42. SessionID: sessionID,
  43. Path: relpath(path),
  44. })
  45. if err != nil {
  46. return time.Time{}
  47. }
  48. return time.Unix(readFile.ReadAt, 0)
  49. }
  // relpath normalizes path into the form stored in the file-read table: the
// cleaned path expressed relative to the current working directory.
//
// Paths that are already relative are returned cleaned. They cannot be passed
// to filepath.Rel against an absolute base without an error, and treating
// them as cwd-relative matches how ListReadFiles resolves stored paths. If the
// working directory is unavailable, or path cannot be made relative to it
// (e.g. a different Windows volume), a warning is logged and the cleaned
// absolute path is returned.
func relpath(path string) string {
	path = filepath.Clean(path)
	if !filepath.IsAbs(path) {
		return path
	}
	basepath, err := os.Getwd()
	if err != nil {
		slog.Warn("Error getting basepath", "error", err)
		return path
	}
	rel, err := filepath.Rel(basepath, path)
	if err != nil {
		slog.Warn("Error getting relpath", "error", err, "path", path)
		return path
	}
	return rel
}
  64. // ListReadFiles returns the paths of all files read in a session.
  65. func (s *service) ListReadFiles(ctx context.Context, sessionID string) ([]string, error) {
  66. readFiles, err := s.q.ListSessionReadFiles(ctx, sessionID)
  67. if err != nil {
  68. return nil, fmt.Errorf("listing read files: %w", err)
  69. }
  70. basepath, err := os.Getwd()
  71. if err != nil {
  72. return nil, fmt.Errorf("getting working directory: %w", err)
  73. }
  74. paths := make([]string, 0, len(readFiles))
  75. for _, rf := range readFiles {
  76. paths = append(paths, filepath.Join(basepath, rf.Path))
  77. }
  78. return paths, nil
  79. }