session.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. package session
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "log/slog"
  8. "strings"
  9. "github.com/charmbracelet/crush/internal/db"
  10. "github.com/charmbracelet/crush/internal/event"
  11. "github.com/charmbracelet/crush/internal/pubsub"
  12. "github.com/google/uuid"
  13. )
  // TodoStatus is the state of a single Todo item. It is a plain string so it
  // serializes directly inside the session's JSON-encoded todo list.
  type TodoStatus string

  // Todo is one entry in a session's todo list. The JSON tags define the
  // stored encoding (see marshalTodos / unmarshalTodos).
  type Todo struct {
  	// Content is the text of the item.
  	Content string `json:"content"`
  	// Status is the item's current TodoStatus.
  	Status TodoStatus `json:"status"`
  	// ActiveForm is stored alongside Content; presumably an alternate
  	// phrasing shown while the item is in progress — confirm with producer.
  	ActiveForm string `json:"active_form"`
  }

  // The TodoStatus values defined by this package.
  const (
  	TodoStatusPending    = TodoStatus("pending")
  	TodoStatusInProgress = TodoStatus("in_progress")
  	TodoStatusCompleted  = TodoStatus("completed")
  )
  25. type Session struct {
  26. ID string
  27. ParentSessionID string
  28. Title string
  29. MessageCount int64
  30. PromptTokens int64
  31. CompletionTokens int64
  32. SummaryMessageID string
  33. Cost float64
  34. Todos []Todo
  35. CreatedAt int64
  36. UpdatedAt int64
  37. }
  38. type Service interface {
  39. pubsub.Subscriber[Session]
  40. Create(ctx context.Context, title string) (Session, error)
  41. CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
  42. CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
  43. Get(ctx context.Context, id string) (Session, error)
  44. List(ctx context.Context) ([]Session, error)
  45. Save(ctx context.Context, session Session) (Session, error)
  46. UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
  47. Delete(ctx context.Context, id string) error
  48. // Agent tool session management
  49. CreateAgentToolSessionID(messageID, toolCallID string) string
  50. ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
  51. IsAgentToolSession(sessionID string) bool
  52. }
  53. type service struct {
  54. *pubsub.Broker[Session]
  55. q db.Querier
  56. }
  // Create inserts a new top-level session (no parent) with a fresh random
  // UUID and the given title, publishes a CreatedEvent, and records the
  // event.SessionCreated metric. The stored row is returned.
  func (s *service) Create(ctx context.Context, title string) (Session, error) {
  	params := db.CreateSessionParams{ID: uuid.New().String(), Title: title}
  	row, err := s.q.CreateSession(ctx, params)
  	if err != nil {
  		return Session{}, err
  	}
  	created := s.fromDBItem(row)
  	s.Publish(pubsub.CreatedEvent, created)
  	event.SessionCreated()
  	return created, nil
  }
  // CreateTaskSession inserts a child session of parentSessionID that reuses
  // toolCallID as its ID, then publishes a CreatedEvent. The parent link is
  // always stored as non-NULL, even when parentSessionID is "". Unlike
  // Create, no event.SessionCreated call is made.
  func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
  	parent := sql.NullString{String: parentSessionID, Valid: true}
  	row, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  		ID:              toolCallID,
  		ParentSessionID: parent,
  		Title:           title,
  	})
  	if err != nil {
  		return Session{}, err
  	}
  	task := s.fromDBItem(row)
  	s.Publish(pubsub.CreatedEvent, task)
  	return task, nil
  }
  // CreateTitleSession inserts the child session with ID
  // "title-<parentSessionID>" and title "Generate a title" under
  // parentSessionID, then publishes a CreatedEvent.
  //
  // NOTE(review): the ID is derived from the parent, so a second call for the
  // same parent will presumably collide on insert — confirm against schema.
  func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
  	params := db.CreateSessionParams{
  		ID:              "title-" + parentSessionID,
  		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  		Title:           "Generate a title",
  	}
  	row, err := s.q.CreateSession(ctx, params)
  	if err != nil {
  		return Session{}, err
  	}
  	titled := s.fromDBItem(row)
  	s.Publish(pubsub.CreatedEvent, titled)
  	return titled, nil
  }
  // Delete removes the session identified by id. The session is fetched
  // first: a lookup error aborts before anything is deleted, and the fetched
  // record is what gets published in the DeletedEvent. The
  // event.SessionDeleted metric is recorded after a successful delete.
  func (s *service) Delete(ctx context.Context, id string) error {
  	victim, err := s.Get(ctx, id)
  	if err != nil {
  		return err
  	}
  	if err := s.q.DeleteSession(ctx, victim.ID); err != nil {
  		return err
  	}
  	s.Publish(pubsub.DeletedEvent, victim)
  	event.SessionDeleted()
  	return nil
  }
  // Get loads the session with the given id. Query errors (including
  // whatever not-found error the Querier reports) are returned unwrapped.
  func (s *service) Get(ctx context.Context, id string) (Session, error) {
  	row, err := s.q.GetSessionByID(ctx, id)
  	if err != nil {
  		return Session{}, err
  	}
  	found := s.fromDBItem(row)
  	return found, nil
  }
  // Save writes the session's title, token counts, summary message ID, cost
  // and todos, then publishes the row as stored via an UpdatedEvent and
  // returns it. An empty SummaryMessageID, or an empty todo list (which
  // marshalTodos encodes as ""), is written as NULL. Fields outside the
  // update parameters (e.g. MessageCount, timestamps) are not written.
  func (s *service) Save(ctx context.Context, session Session) (Session, error) {
  	encodedTodos, err := marshalTodos(session.Todos)
  	if err != nil {
  		return Session{}, err
  	}

  	// Empty strings map to SQL NULL.
  	nullIfEmpty := func(v string) sql.NullString {
  		return sql.NullString{String: v, Valid: v != ""}
  	}

  	params := db.UpdateSessionParams{
  		ID:               session.ID,
  		Title:            session.Title,
  		PromptTokens:     session.PromptTokens,
  		CompletionTokens: session.CompletionTokens,
  		SummaryMessageID: nullIfEmpty(session.SummaryMessageID),
  		Cost:             session.Cost,
  		Todos:            nullIfEmpty(encodedTodos),
  	}
  	row, err := s.q.UpdateSession(ctx, params)
  	if err != nil {
  		return Session{}, err
  	}

  	saved := s.fromDBItem(row)
  	s.Publish(pubsub.UpdatedEvent, saved)
  	return saved, nil
  }
  // UpdateTitleAndUsage writes only the title, prompt/completion token
  // counts and cost of sessionID through one dedicated query. This avoids a
  // fetch-modify-Save round trip, which would rewrite every column Save
  // touches from a possibly stale copy.
  //
  // NOTE(review): unlike Save, no UpdatedEvent is published, so subscribers
  // are not told about the change. Whether usage values are set or
  // accumulated is decided by the SQL query, not visible here.
  func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
  	params := db.UpdateSessionTitleAndUsageParams{
  		ID:               sessionID,
  		Title:            title,
  		PromptTokens:     promptTokens,
  		CompletionTokens: completionTokens,
  		Cost:             cost,
  	}
  	return s.q.UpdateSessionTitleAndUsage(ctx, params)
  }
  // List returns every session yielded by ListSessions, converted with
  // fromDBItem, in query order. On success the result is non-nil (possibly
  // empty); on error it is nil.
  func (s *service) List(ctx context.Context) ([]Session, error) {
  	rows, err := s.q.ListSessions(ctx)
  	if err != nil {
  		return nil, err
  	}
  	out := make([]Session, 0, len(rows))
  	for _, row := range rows {
  		out = append(out, s.fromDBItem(row))
  	}
  	return out, nil
  }
  // fromDBItem converts a db.Session row into a Session.
  //
  // NULL ParentSessionID / SummaryMessageID columns become "" (the String of
  // an invalid sql.NullString). Todos are decoded best-effort: this function
  // has no error return, so a malformed todo payload is logged and replaced
  // by the empty list unmarshalTodos returns on failure.
  //
  // It uses a pointer receiver for consistency with every other service
  // method; the receiver itself is not read.
  func (s *service) fromDBItem(item db.Session) Session {
  	todos, err := unmarshalTodos(item.Todos.String)
  	if err != nil {
  		slog.Error("failed to unmarshal todos", "session_id", item.ID, "error", err)
  	}
  	return Session{
  		ID:               item.ID,
  		ParentSessionID:  item.ParentSessionID.String,
  		Title:            item.Title,
  		MessageCount:     item.MessageCount,
  		PromptTokens:     item.PromptTokens,
  		CompletionTokens: item.CompletionTokens,
  		SummaryMessageID: item.SummaryMessageID.String,
  		Cost:             item.Cost,
  		Todos:            todos,
  		CreatedAt:        item.CreatedAt,
  		UpdatedAt:        item.UpdatedAt,
  	}
  }
  // marshalTodos encodes todos as a JSON array for storage. A nil or empty
  // list yields "" (rather than "null" or "[]"), which Save stores as NULL.
  func marshalTodos(todos []Todo) (string, error) {
  	if len(todos) > 0 {
  		raw, err := json.Marshal(todos)
  		if err != nil {
  			return "", err
  		}
  		return string(raw), nil
  	}
  	return "", nil
  }
  // unmarshalTodos decodes a stored JSON todo list. An empty string (no
  // todos / NULL column) yields an empty, non-nil slice. On a decode error
  // it returns an empty, non-nil slice together with the error, so callers
  // that log and continue still have a usable value.
  func unmarshalTodos(data string) ([]Todo, error) {
  	empty := []Todo{}
  	if data == "" {
  		return empty, nil
  	}
  	var decoded []Todo
  	if err := json.Unmarshal([]byte(data), &decoded); err != nil {
  		return empty, err
  	}
  	return decoded, nil
  }
  204. func NewService(q db.Querier) Service {
  205. broker := pubsub.NewBroker[Session]()
  206. return &service{
  207. broker,
  208. q,
  209. }
  210. }
  // agentToolSessionIDSeparator joins the message ID and the tool call ID in
  // an agent tool session ID. Creation and parsing share this constant so
  // the two cannot drift apart.
  const agentToolSessionIDSeparator = "$$"

  // CreateAgentToolSessionID builds an agent tool session ID of the form
  // "messageID$$toolCallID". The inputs are not validated; IDs containing
  // "$" may not round-trip through ParseAgentToolSessionID.
  func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
  	return fmt.Sprintf("%s%s%s", messageID, agentToolSessionIDSeparator, toolCallID)
  }

  // ParseAgentToolSessionID splits an agent tool session ID into its message
  // ID and tool call ID. ok is true only when sessionID contains the
  // separator exactly once (non-overlapping matches, scanning from the
  // left). Either returned part may be empty.
  func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
  	before, after, found := strings.Cut(sessionID, agentToolSessionIDSeparator)
  	// A second separator after the first means the ID is ambiguous; reject it.
  	if !found || strings.Contains(after, agentToolSessionIDSeparator) {
  		return "", "", false
  	}
  	return before, after, true
  }

  // IsAgentToolSession reports whether sessionID has the shape accepted by
  // ParseAgentToolSessionID.
  func (s *service) IsAgentToolSession(sessionID string) bool {
  	_, _, ok := s.ParseAgentToolSessionID(sessionID)
  	return ok
  }