session.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. package session
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "log/slog"
  8. "strings"
  9. "github.com/charmbracelet/crush/internal/db"
  10. "github.com/charmbracelet/crush/internal/event"
  11. "github.com/charmbracelet/crush/internal/pubsub"
  12. "github.com/google/uuid"
  13. )
  14. type TodoStatus string
  15. const (
  16. TodoStatusPending TodoStatus = "pending"
  17. TodoStatusInProgress TodoStatus = "in_progress"
  18. TodoStatusCompleted TodoStatus = "completed"
  19. )
  20. type Todo struct {
  21. Content string `json:"content"`
  22. Status TodoStatus `json:"status"`
  23. ActiveForm string `json:"active_form"`
  24. }
  25. type Session struct {
  26. ID string
  27. ParentSessionID string
  28. Title string
  29. MessageCount int64
  30. PromptTokens int64
  31. CompletionTokens int64
  32. SummaryMessageID string
  33. Cost float64
  34. Todos []Todo
  35. CreatedAt int64
  36. UpdatedAt int64
  37. }
  38. type Service interface {
  39. pubsub.Subscriber[Session]
  40. Create(ctx context.Context, title string) (Session, error)
  41. CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
  42. CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
  43. Get(ctx context.Context, id string) (Session, error)
  44. List(ctx context.Context) ([]Session, error)
  45. Save(ctx context.Context, session Session) (Session, error)
  46. Delete(ctx context.Context, id string) error
  47. // Agent tool session management
  48. CreateAgentToolSessionID(messageID, toolCallID string) string
  49. ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
  50. IsAgentToolSession(sessionID string) bool
  51. }
  52. type service struct {
  53. *pubsub.Broker[Session]
  54. q db.Querier
  55. }
  56. func (s *service) Create(ctx context.Context, title string) (Session, error) {
  57. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  58. ID: uuid.New().String(),
  59. Title: title,
  60. })
  61. if err != nil {
  62. return Session{}, err
  63. }
  64. session := s.fromDBItem(dbSession)
  65. s.Publish(pubsub.CreatedEvent, session)
  66. event.SessionCreated()
  67. return session, nil
  68. }
  69. func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
  70. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  71. ID: toolCallID,
  72. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  73. Title: title,
  74. })
  75. if err != nil {
  76. return Session{}, err
  77. }
  78. session := s.fromDBItem(dbSession)
  79. s.Publish(pubsub.CreatedEvent, session)
  80. return session, nil
  81. }
  82. func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
  83. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  84. ID: "title-" + parentSessionID,
  85. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  86. Title: "Generate a title",
  87. })
  88. if err != nil {
  89. return Session{}, err
  90. }
  91. session := s.fromDBItem(dbSession)
  92. s.Publish(pubsub.CreatedEvent, session)
  93. return session, nil
  94. }
  95. func (s *service) Delete(ctx context.Context, id string) error {
  96. session, err := s.Get(ctx, id)
  97. if err != nil {
  98. return err
  99. }
  100. err = s.q.DeleteSession(ctx, session.ID)
  101. if err != nil {
  102. return err
  103. }
  104. s.Publish(pubsub.DeletedEvent, session)
  105. event.SessionDeleted()
  106. return nil
  107. }
  108. func (s *service) Get(ctx context.Context, id string) (Session, error) {
  109. dbSession, err := s.q.GetSessionByID(ctx, id)
  110. if err != nil {
  111. return Session{}, err
  112. }
  113. return s.fromDBItem(dbSession), nil
  114. }
  115. func (s *service) Save(ctx context.Context, session Session) (Session, error) {
  116. todosJSON, err := marshalTodos(session.Todos)
  117. if err != nil {
  118. return Session{}, err
  119. }
  120. dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
  121. ID: session.ID,
  122. Title: session.Title,
  123. PromptTokens: session.PromptTokens,
  124. CompletionTokens: session.CompletionTokens,
  125. SummaryMessageID: sql.NullString{
  126. String: session.SummaryMessageID,
  127. Valid: session.SummaryMessageID != "",
  128. },
  129. Cost: session.Cost,
  130. Todos: sql.NullString{
  131. String: todosJSON,
  132. Valid: todosJSON != "",
  133. },
  134. })
  135. if err != nil {
  136. return Session{}, err
  137. }
  138. session = s.fromDBItem(dbSession)
  139. s.Publish(pubsub.UpdatedEvent, session)
  140. return session, nil
  141. }
  142. func (s *service) List(ctx context.Context) ([]Session, error) {
  143. dbSessions, err := s.q.ListSessions(ctx)
  144. if err != nil {
  145. return nil, err
  146. }
  147. sessions := make([]Session, len(dbSessions))
  148. for i, dbSession := range dbSessions {
  149. sessions[i] = s.fromDBItem(dbSession)
  150. }
  151. return sessions, nil
  152. }
  153. func (s service) fromDBItem(item db.Session) Session {
  154. todos, err := unmarshalTodos(item.Todos.String)
  155. if err != nil {
  156. slog.Error("failed to unmarshal todos", "session_id", item.ID, "error", err)
  157. }
  158. return Session{
  159. ID: item.ID,
  160. ParentSessionID: item.ParentSessionID.String,
  161. Title: item.Title,
  162. MessageCount: item.MessageCount,
  163. PromptTokens: item.PromptTokens,
  164. CompletionTokens: item.CompletionTokens,
  165. SummaryMessageID: item.SummaryMessageID.String,
  166. Cost: item.Cost,
  167. Todos: todos,
  168. CreatedAt: item.CreatedAt,
  169. UpdatedAt: item.UpdatedAt,
  170. }
  171. }
  172. func marshalTodos(todos []Todo) (string, error) {
  173. if len(todos) == 0 {
  174. return "", nil
  175. }
  176. data, err := json.Marshal(todos)
  177. if err != nil {
  178. return "", err
  179. }
  180. return string(data), nil
  181. }
  182. func unmarshalTodos(data string) ([]Todo, error) {
  183. if data == "" {
  184. return []Todo{}, nil
  185. }
  186. var todos []Todo
  187. if err := json.Unmarshal([]byte(data), &todos); err != nil {
  188. return []Todo{}, err
  189. }
  190. return todos, nil
  191. }
  192. func NewService(q db.Querier) Service {
  193. broker := pubsub.NewBroker[Session]()
  194. return &service{
  195. broker,
  196. q,
  197. }
  198. }
  199. // CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
  200. func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
  201. return fmt.Sprintf("%s$$%s", messageID, toolCallID)
  202. }
  203. // ParseAgentToolSessionID parses an agent tool session ID into its components
  204. func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
  205. parts := strings.Split(sessionID, "$$")
  206. if len(parts) != 2 {
  207. return "", "", false
  208. }
  209. return parts[0], parts[1], true
  210. }
  211. // IsAgentToolSession checks if a session ID follows the agent tool session format
  212. func (s *service) IsAgentToolSession(sessionID string) bool {
  213. _, _, ok := s.ParseAgentToolSessionID(sessionID)
  214. return ok
  215. }