session.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. package session
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "strings"
  7. "github.com/charmbracelet/crush/internal/db"
  8. "github.com/charmbracelet/crush/internal/event"
  9. "github.com/charmbracelet/crush/internal/pubsub"
  10. "github.com/google/uuid"
  11. )
  // Session is the record type exposed by Service for a single session.
//
// ParentSessionID and SummaryMessageID are nullable in the database layer
// (see fromDBItem / Save); here they are plain strings where "" means absent.
type Session struct {
	// Identity and hierarchy.
	ID, ParentSessionID, Title string

	// Usage counters.
	MessageCount, PromptTokens, CompletionTokens int64

	// SummaryMessageID references a summary message; "" when none is set.
	SummaryMessageID string

	Cost float64

	// CreatedAt and UpdatedAt are int64 timestamps.
	// NOTE(review): unit not visible here (presumably Unix time) — confirm against the schema.
	CreatedAt, UpdatedAt int64
}
  24. type Service interface {
  25. pubsub.Suscriber[Session]
  26. Create(ctx context.Context, title string) (Session, error)
  27. CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
  28. CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
  29. Get(ctx context.Context, id string) (Session, error)
  30. List(ctx context.Context) ([]Session, error)
  31. Save(ctx context.Context, session Session) (Session, error)
  32. Delete(ctx context.Context, id string) error
  33. // Agent tool session management
  34. CreateAgentToolSessionID(messageID, toolCallID string) string
  35. ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
  36. IsAgentToolSession(sessionID string) bool
  37. }
  38. type service struct {
  39. *pubsub.Broker[Session]
  40. q db.Querier
  41. }
  42. func (s *service) Create(ctx context.Context, title string) (Session, error) {
  43. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  44. ID: uuid.New().String(),
  45. Title: title,
  46. })
  47. if err != nil {
  48. return Session{}, err
  49. }
  50. session := s.fromDBItem(dbSession)
  51. s.Publish(pubsub.CreatedEvent, session)
  52. event.SessionCreated()
  53. return session, nil
  54. }
  55. func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
  56. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  57. ID: toolCallID,
  58. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  59. Title: title,
  60. })
  61. if err != nil {
  62. return Session{}, err
  63. }
  64. session := s.fromDBItem(dbSession)
  65. s.Publish(pubsub.CreatedEvent, session)
  66. return session, nil
  67. }
  68. func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
  69. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  70. ID: "title-" + parentSessionID,
  71. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  72. Title: "Generate a title",
  73. })
  74. if err != nil {
  75. return Session{}, err
  76. }
  77. session := s.fromDBItem(dbSession)
  78. s.Publish(pubsub.CreatedEvent, session)
  79. return session, nil
  80. }
  81. func (s *service) Delete(ctx context.Context, id string) error {
  82. session, err := s.Get(ctx, id)
  83. if err != nil {
  84. return err
  85. }
  86. err = s.q.DeleteSession(ctx, session.ID)
  87. if err != nil {
  88. return err
  89. }
  90. s.Publish(pubsub.DeletedEvent, session)
  91. event.SessionDeleted()
  92. return nil
  93. }
  94. func (s *service) Get(ctx context.Context, id string) (Session, error) {
  95. dbSession, err := s.q.GetSessionByID(ctx, id)
  96. if err != nil {
  97. return Session{}, err
  98. }
  99. return s.fromDBItem(dbSession), nil
  100. }
  101. func (s *service) Save(ctx context.Context, session Session) (Session, error) {
  102. dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
  103. ID: session.ID,
  104. Title: session.Title,
  105. PromptTokens: session.PromptTokens,
  106. CompletionTokens: session.CompletionTokens,
  107. SummaryMessageID: sql.NullString{
  108. String: session.SummaryMessageID,
  109. Valid: session.SummaryMessageID != "",
  110. },
  111. Cost: session.Cost,
  112. })
  113. if err != nil {
  114. return Session{}, err
  115. }
  116. session = s.fromDBItem(dbSession)
  117. s.Publish(pubsub.UpdatedEvent, session)
  118. return session, nil
  119. }
  120. func (s *service) List(ctx context.Context) ([]Session, error) {
  121. dbSessions, err := s.q.ListSessions(ctx)
  122. if err != nil {
  123. return nil, err
  124. }
  125. sessions := make([]Session, len(dbSessions))
  126. for i, dbSession := range dbSessions {
  127. sessions[i] = s.fromDBItem(dbSession)
  128. }
  129. return sessions, nil
  130. }
  131. func (s service) fromDBItem(item db.Session) Session {
  132. return Session{
  133. ID: item.ID,
  134. ParentSessionID: item.ParentSessionID.String,
  135. Title: item.Title,
  136. MessageCount: item.MessageCount,
  137. PromptTokens: item.PromptTokens,
  138. CompletionTokens: item.CompletionTokens,
  139. SummaryMessageID: item.SummaryMessageID.String,
  140. Cost: item.Cost,
  141. CreatedAt: item.CreatedAt,
  142. UpdatedAt: item.UpdatedAt,
  143. }
  144. }
  145. func NewService(q db.Querier) Service {
  146. broker := pubsub.NewBroker[Session]()
  147. return &service{
  148. broker,
  149. q,
  150. }
  151. }
  152. // CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
  153. func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
  154. return fmt.Sprintf("%s$$%s", messageID, toolCallID)
  155. }
  156. // ParseAgentToolSessionID parses an agent tool session ID into its components
  157. func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
  158. parts := strings.Split(sessionID, "$$")
  159. if len(parts) != 2 {
  160. return "", "", false
  161. }
  162. return parts[0], parts[1], true
  163. }
  164. // IsAgentToolSession checks if a session ID follows the agent tool session format
  165. func (s *service) IsAgentToolSession(sessionID string) bool {
  166. _, _, ok := s.ParseAgentToolSessionID(sessionID)
  167. return ok
  168. }