session.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. package session
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "log/slog"
  8. "strings"
  9. "github.com/charmbracelet/crush/internal/db"
  10. "github.com/charmbracelet/crush/internal/event"
  11. "github.com/charmbracelet/crush/internal/pubsub"
  12. "github.com/google/uuid"
  13. )
  // TodoStatus is the state of a Todo. It is a plain string and is serialized
  // as such in the todo's "status" JSON field.
  type TodoStatus string

  // Known TodoStatus values. Only TodoStatusCompleted is treated as finished
  // by HasIncompleteTodos; every other value counts as incomplete.
  const (
  	TodoStatusPending    TodoStatus = "pending"
  	TodoStatusInProgress TodoStatus = "in_progress"
  	TodoStatusCompleted  TodoStatus = "completed"
  )
  20. type Todo struct {
  21. Content string `json:"content"`
  22. Status TodoStatus `json:"status"`
  23. ActiveForm string `json:"active_form"`
  24. }
  25. // HasIncompleteTodos returns true if there are any non-completed todos.
  26. func HasIncompleteTodos(todos []Todo) bool {
  27. for _, todo := range todos {
  28. if todo.Status != TodoStatusCompleted {
  29. return true
  30. }
  31. }
  32. return false
  33. }
  34. type Session struct {
  35. ID string
  36. ParentSessionID string
  37. Title string
  38. MessageCount int64
  39. PromptTokens int64
  40. CompletionTokens int64
  41. SummaryMessageID string
  42. Cost float64
  43. Todos []Todo
  44. CreatedAt int64
  45. UpdatedAt int64
  46. }
  47. type Service interface {
  48. pubsub.Subscriber[Session]
  49. Create(ctx context.Context, title string) (Session, error)
  50. CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
  51. CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
  52. Get(ctx context.Context, id string) (Session, error)
  53. List(ctx context.Context) ([]Session, error)
  54. Save(ctx context.Context, session Session) (Session, error)
  55. UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
  56. Delete(ctx context.Context, id string) error
  57. // Agent tool session management
  58. CreateAgentToolSessionID(messageID, toolCallID string) string
  59. ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
  60. IsAgentToolSession(sessionID string) bool
  61. }
  62. type service struct {
  63. *pubsub.Broker[Session]
  64. db *sql.DB
  65. q *db.Queries
  66. }
  67. func (s *service) Create(ctx context.Context, title string) (Session, error) {
  68. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  69. ID: uuid.New().String(),
  70. Title: title,
  71. })
  72. if err != nil {
  73. return Session{}, err
  74. }
  75. session := s.fromDBItem(dbSession)
  76. s.Publish(pubsub.CreatedEvent, session)
  77. event.SessionCreated()
  78. return session, nil
  79. }
  80. func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
  81. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  82. ID: toolCallID,
  83. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  84. Title: title,
  85. })
  86. if err != nil {
  87. return Session{}, err
  88. }
  89. session := s.fromDBItem(dbSession)
  90. s.Publish(pubsub.CreatedEvent, session)
  91. return session, nil
  92. }
  93. func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
  94. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  95. ID: "title-" + parentSessionID,
  96. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  97. Title: "Generate a title",
  98. })
  99. if err != nil {
  100. return Session{}, err
  101. }
  102. session := s.fromDBItem(dbSession)
  103. s.Publish(pubsub.CreatedEvent, session)
  104. return session, nil
  105. }
  106. func (s *service) Delete(ctx context.Context, id string) error {
  107. tx, err := s.db.BeginTx(ctx, nil)
  108. if err != nil {
  109. return fmt.Errorf("beginning transaction: %w", err)
  110. }
  111. defer tx.Rollback() //nolint:errcheck
  112. qtx := s.q.WithTx(tx)
  113. dbSession, err := qtx.GetSessionByID(ctx, id)
  114. if err != nil {
  115. return err
  116. }
  117. if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
  118. return fmt.Errorf("deleting session messages: %w", err)
  119. }
  120. if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
  121. return fmt.Errorf("deleting session files: %w", err)
  122. }
  123. if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
  124. return fmt.Errorf("deleting session: %w", err)
  125. }
  126. if err = tx.Commit(); err != nil {
  127. return fmt.Errorf("committing transaction: %w", err)
  128. }
  129. session := s.fromDBItem(dbSession)
  130. s.Publish(pubsub.DeletedEvent, session)
  131. event.SessionDeleted()
  132. return nil
  133. }
  134. func (s *service) Get(ctx context.Context, id string) (Session, error) {
  135. dbSession, err := s.q.GetSessionByID(ctx, id)
  136. if err != nil {
  137. return Session{}, err
  138. }
  139. return s.fromDBItem(dbSession), nil
  140. }
  141. func (s *service) Save(ctx context.Context, session Session) (Session, error) {
  142. todosJSON, err := marshalTodos(session.Todos)
  143. if err != nil {
  144. return Session{}, err
  145. }
  146. dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
  147. ID: session.ID,
  148. Title: session.Title,
  149. PromptTokens: session.PromptTokens,
  150. CompletionTokens: session.CompletionTokens,
  151. SummaryMessageID: sql.NullString{
  152. String: session.SummaryMessageID,
  153. Valid: session.SummaryMessageID != "",
  154. },
  155. Cost: session.Cost,
  156. Todos: sql.NullString{
  157. String: todosJSON,
  158. Valid: todosJSON != "",
  159. },
  160. })
  161. if err != nil {
  162. return Session{}, err
  163. }
  164. session = s.fromDBItem(dbSession)
  165. s.Publish(pubsub.UpdatedEvent, session)
  166. return session, nil
  167. }
  168. // UpdateTitleAndUsage updates only the title and usage fields atomically.
  169. // This is safer than fetching, modifying, and saving the entire session.
  170. func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
  171. return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
  172. ID: sessionID,
  173. Title: title,
  174. PromptTokens: promptTokens,
  175. CompletionTokens: completionTokens,
  176. Cost: cost,
  177. })
  178. }
  179. func (s *service) List(ctx context.Context) ([]Session, error) {
  180. dbSessions, err := s.q.ListSessions(ctx)
  181. if err != nil {
  182. return nil, err
  183. }
  184. sessions := make([]Session, len(dbSessions))
  185. for i, dbSession := range dbSessions {
  186. sessions[i] = s.fromDBItem(dbSession)
  187. }
  188. return sessions, nil
  189. }
  190. func (s service) fromDBItem(item db.Session) Session {
  191. todos, err := unmarshalTodos(item.Todos.String)
  192. if err != nil {
  193. slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
  194. }
  195. return Session{
  196. ID: item.ID,
  197. ParentSessionID: item.ParentSessionID.String,
  198. Title: item.Title,
  199. MessageCount: item.MessageCount,
  200. PromptTokens: item.PromptTokens,
  201. CompletionTokens: item.CompletionTokens,
  202. SummaryMessageID: item.SummaryMessageID.String,
  203. Cost: item.Cost,
  204. Todos: todos,
  205. CreatedAt: item.CreatedAt,
  206. UpdatedAt: item.UpdatedAt,
  207. }
  208. }
  209. func marshalTodos(todos []Todo) (string, error) {
  210. if len(todos) == 0 {
  211. return "", nil
  212. }
  213. data, err := json.Marshal(todos)
  214. if err != nil {
  215. return "", err
  216. }
  217. return string(data), nil
  218. }
  219. func unmarshalTodos(data string) ([]Todo, error) {
  220. if data == "" {
  221. return []Todo{}, nil
  222. }
  223. var todos []Todo
  224. if err := json.Unmarshal([]byte(data), &todos); err != nil {
  225. return []Todo{}, err
  226. }
  227. return todos, nil
  228. }
  229. func NewService(q *db.Queries, conn *sql.DB) Service {
  230. broker := pubsub.NewBroker[Session]()
  231. return &service{
  232. Broker: broker,
  233. db: conn,
  234. q: q,
  235. }
  236. }
  237. // CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
  238. func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
  239. return fmt.Sprintf("%s$$%s", messageID, toolCallID)
  240. }
  241. // ParseAgentToolSessionID parses an agent tool session ID into its components
  242. func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
  243. parts := strings.Split(sessionID, "$$")
  244. if len(parts) != 2 {
  245. return "", "", false
  246. }
  247. return parts[0], parts[1], true
  248. }
  249. // IsAgentToolSession checks if a session ID follows the agent tool session format
  250. func (s *service) IsAgentToolSession(sessionID string) bool {
  251. _, _, ok := s.ParseAgentToolSessionID(sessionID)
  252. return ok
  253. }