session.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. package session
  2. import (
  3. "context"
  4. "database/sql"
  5. "github.com/charmbracelet/crush/internal/db"
  6. "github.com/charmbracelet/crush/internal/event"
  7. "github.com/charmbracelet/crush/internal/pubsub"
  8. "github.com/google/uuid"
  9. )
  // Session is the in-memory representation of a stored session row as
  // produced by fromDBItem.
  type Session struct {
  	// ID is the session's unique identifier.
  	ID string
  	// ParentSessionID is the ID of the parent session; it is the empty
  	// string when the stored value carries no string.
  	ParentSessionID string
  	// Title is the session's title.
  	Title string
  	// MessageCount is copied from the stored row's MessageCount.
  	MessageCount int64
  	// PromptTokens is copied from the stored row's PromptTokens.
  	PromptTokens int64
  	// CompletionTokens is copied from the stored row's CompletionTokens.
  	CompletionTokens int64
  	// SummaryMessageID is the ID of the summary message; an empty string is
  	// written back as an invalid (unset) value by Save.
  	SummaryMessageID string
  	// Cost is copied from the stored row's Cost.
  	Cost float64
  	// CreatedAt is the stored creation timestamp.
  	// NOTE(review): unit (seconds vs. milliseconds) is not visible here — confirm.
  	CreatedAt int64
  	// UpdatedAt is the stored last-update timestamp (same unit caveat as CreatedAt).
  	UpdatedAt int64
  }
  22. type Service interface {
  23. pubsub.Suscriber[Session]
  24. Create(ctx context.Context, title string) (Session, error)
  25. CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
  26. CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
  27. Get(ctx context.Context, id string) (Session, error)
  28. List(ctx context.Context) ([]Session, error)
  29. Save(ctx context.Context, session Session) (Session, error)
  30. Delete(ctx context.Context, id string) error
  31. }
  32. type service struct {
  33. *pubsub.Broker[Session]
  34. q db.Querier
  35. }
  36. func (s *service) Create(ctx context.Context, title string) (Session, error) {
  37. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  38. ID: uuid.New().String(),
  39. Title: title,
  40. })
  41. if err != nil {
  42. return Session{}, err
  43. }
  44. session := s.fromDBItem(dbSession)
  45. s.Publish(pubsub.CreatedEvent, session)
  46. event.SessionCreated()
  47. return session, nil
  48. }
  49. func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
  50. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  51. ID: toolCallID,
  52. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  53. Title: title,
  54. })
  55. if err != nil {
  56. return Session{}, err
  57. }
  58. session := s.fromDBItem(dbSession)
  59. s.Publish(pubsub.CreatedEvent, session)
  60. return session, nil
  61. }
  62. func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
  63. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  64. ID: "title-" + parentSessionID,
  65. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  66. Title: "Generate a title",
  67. })
  68. if err != nil {
  69. return Session{}, err
  70. }
  71. session := s.fromDBItem(dbSession)
  72. s.Publish(pubsub.CreatedEvent, session)
  73. return session, nil
  74. }
  75. func (s *service) Delete(ctx context.Context, id string) error {
  76. session, err := s.Get(ctx, id)
  77. if err != nil {
  78. return err
  79. }
  80. err = s.q.DeleteSession(ctx, session.ID)
  81. if err != nil {
  82. return err
  83. }
  84. s.Publish(pubsub.DeletedEvent, session)
  85. event.SessionDeleted()
  86. return nil
  87. }
  88. func (s *service) Get(ctx context.Context, id string) (Session, error) {
  89. dbSession, err := s.q.GetSessionByID(ctx, id)
  90. if err != nil {
  91. return Session{}, err
  92. }
  93. return s.fromDBItem(dbSession), nil
  94. }
  95. func (s *service) Save(ctx context.Context, session Session) (Session, error) {
  96. dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
  97. ID: session.ID,
  98. Title: session.Title,
  99. PromptTokens: session.PromptTokens,
  100. CompletionTokens: session.CompletionTokens,
  101. SummaryMessageID: sql.NullString{
  102. String: session.SummaryMessageID,
  103. Valid: session.SummaryMessageID != "",
  104. },
  105. Cost: session.Cost,
  106. })
  107. if err != nil {
  108. return Session{}, err
  109. }
  110. session = s.fromDBItem(dbSession)
  111. s.Publish(pubsub.UpdatedEvent, session)
  112. return session, nil
  113. }
  114. func (s *service) List(ctx context.Context) ([]Session, error) {
  115. dbSessions, err := s.q.ListSessions(ctx)
  116. if err != nil {
  117. return nil, err
  118. }
  119. sessions := make([]Session, len(dbSessions))
  120. for i, dbSession := range dbSessions {
  121. sessions[i] = s.fromDBItem(dbSession)
  122. }
  123. return sessions, nil
  124. }
  125. func (s service) fromDBItem(item db.Session) Session {
  126. return Session{
  127. ID: item.ID,
  128. ParentSessionID: item.ParentSessionID.String,
  129. Title: item.Title,
  130. MessageCount: item.MessageCount,
  131. PromptTokens: item.PromptTokens,
  132. CompletionTokens: item.CompletionTokens,
  133. SummaryMessageID: item.SummaryMessageID.String,
  134. Cost: item.Cost,
  135. CreatedAt: item.CreatedAt,
  136. UpdatedAt: item.UpdatedAt,
  137. }
  138. }
  139. func NewService(q db.Querier) Service {
  140. broker := pubsub.NewBroker[Session]()
  141. return &service{
  142. broker,
  143. q,
  144. }
  145. }