session.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. package session
  2. import (
  3. "context"
  4. "database/sql"
  5. "github.com/charmbracelet/crush/internal/db"
  6. "github.com/charmbracelet/crush/internal/pubsub"
  7. "github.com/google/uuid"
  8. )
  9. type Session struct {
  10. ID string
  11. ParentSessionID string
  12. Title string
  13. MessageCount int64
  14. PromptTokens int64
  15. CompletionTokens int64
  16. SummaryMessageID string
  17. Cost float64
  18. CreatedAt int64
  19. UpdatedAt int64
  20. }
  21. type Service interface {
  22. pubsub.Suscriber[Session]
  23. Create(ctx context.Context, title string) (Session, error)
  24. CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
  25. CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
  26. Get(ctx context.Context, id string) (Session, error)
  27. List(ctx context.Context) ([]Session, error)
  28. Save(ctx context.Context, session Session) (Session, error)
  29. Delete(ctx context.Context, id string) error
  30. }
  31. type service struct {
  32. *pubsub.Broker[Session]
  33. q db.Querier
  34. }
  35. func (s *service) Create(ctx context.Context, title string) (Session, error) {
  36. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  37. ID: uuid.New().String(),
  38. Title: title,
  39. })
  40. if err != nil {
  41. return Session{}, err
  42. }
  43. session := s.fromDBItem(dbSession)
  44. s.Publish(pubsub.CreatedEvent, session)
  45. return session, nil
  46. }
  47. func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
  48. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  49. ID: toolCallID,
  50. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  51. Title: title,
  52. })
  53. if err != nil {
  54. return Session{}, err
  55. }
  56. session := s.fromDBItem(dbSession)
  57. s.Publish(pubsub.CreatedEvent, session)
  58. return session, nil
  59. }
  60. func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
  61. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  62. ID: "title-" + parentSessionID,
  63. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  64. Title: "Generate a title",
  65. })
  66. if err != nil {
  67. return Session{}, err
  68. }
  69. session := s.fromDBItem(dbSession)
  70. s.Publish(pubsub.CreatedEvent, session)
  71. return session, nil
  72. }
  73. func (s *service) Delete(ctx context.Context, id string) error {
  74. session, err := s.Get(ctx, id)
  75. if err != nil {
  76. return err
  77. }
  78. err = s.q.DeleteSession(ctx, session.ID)
  79. if err != nil {
  80. return err
  81. }
  82. s.Publish(pubsub.DeletedEvent, session)
  83. return nil
  84. }
  85. func (s *service) Get(ctx context.Context, id string) (Session, error) {
  86. dbSession, err := s.q.GetSessionByID(ctx, id)
  87. if err != nil {
  88. return Session{}, err
  89. }
  90. return s.fromDBItem(dbSession), nil
  91. }
  92. func (s *service) Save(ctx context.Context, session Session) (Session, error) {
  93. dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
  94. ID: session.ID,
  95. Title: session.Title,
  96. PromptTokens: session.PromptTokens,
  97. CompletionTokens: session.CompletionTokens,
  98. SummaryMessageID: sql.NullString{
  99. String: session.SummaryMessageID,
  100. Valid: session.SummaryMessageID != "",
  101. },
  102. Cost: session.Cost,
  103. })
  104. if err != nil {
  105. return Session{}, err
  106. }
  107. session = s.fromDBItem(dbSession)
  108. s.Publish(pubsub.UpdatedEvent, session)
  109. return session, nil
  110. }
  111. func (s *service) List(ctx context.Context) ([]Session, error) {
  112. dbSessions, err := s.q.ListSessions(ctx)
  113. if err != nil {
  114. return nil, err
  115. }
  116. sessions := make([]Session, len(dbSessions))
  117. for i, dbSession := range dbSessions {
  118. sessions[i] = s.fromDBItem(dbSession)
  119. }
  120. return sessions, nil
  121. }
  122. func (s service) fromDBItem(item db.Session) Session {
  123. return Session{
  124. ID: item.ID,
  125. ParentSessionID: item.ParentSessionID.String,
  126. Title: item.Title,
  127. MessageCount: item.MessageCount,
  128. PromptTokens: item.PromptTokens,
  129. CompletionTokens: item.CompletionTokens,
  130. SummaryMessageID: item.SummaryMessageID.String,
  131. Cost: item.Cost,
  132. CreatedAt: item.CreatedAt,
  133. UpdatedAt: item.UpdatedAt,
  134. }
  135. }
  136. func NewService(q db.Querier) Service {
  137. broker := pubsub.NewBroker[Session]()
  138. return &service{
  139. broker,
  140. q,
  141. }
  142. }