session.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. package session
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "log/slog"
  8. "strings"
  9. "github.com/charmbracelet/crush/internal/db"
  10. "github.com/charmbracelet/crush/internal/event"
  11. "github.com/charmbracelet/crush/internal/pubsub"
  12. "github.com/google/uuid"
  13. "github.com/zeebo/xxh3"
  14. )
  // TodoStatus is the lifecycle state of a single Todo item.
type TodoStatus string

// Known TodoStatus values. Any value other than TodoStatusCompleted is
// treated as "not done" by HasIncompleteTodos.
const (
	// TodoStatusPending marks a todo that has not been started.
	TodoStatusPending TodoStatus = "pending"
	// TodoStatusInProgress marks a todo that is currently being worked on.
	TodoStatusInProgress TodoStatus = "in_progress"
	// TodoStatusCompleted marks a finished todo.
	TodoStatusCompleted TodoStatus = "completed"
)
  21. // HashID returns the XXH3 hash of a session ID (UUID) as a hex string.
  22. func HashID(id string) string {
  23. h := xxh3.New()
  24. h.WriteString(id)
  25. return fmt.Sprintf("%x", h.Sum(nil))
  26. }
  27. type Todo struct {
  28. Content string `json:"content"`
  29. Status TodoStatus `json:"status"`
  30. ActiveForm string `json:"active_form"`
  31. }
  32. // HasIncompleteTodos returns true if there are any non-completed todos.
  33. func HasIncompleteTodos(todos []Todo) bool {
  34. for _, todo := range todos {
  35. if todo.Status != TodoStatusCompleted {
  36. return true
  37. }
  38. }
  39. return false
  40. }
  41. type Session struct {
  42. ID string
  43. ParentSessionID string
  44. Title string
  45. MessageCount int64
  46. PromptTokens int64
  47. CompletionTokens int64
  48. SummaryMessageID string
  49. Cost float64
  50. Todos []Todo
  51. CreatedAt int64
  52. UpdatedAt int64
  53. }
  54. type Service interface {
  55. pubsub.Subscriber[Session]
  56. Create(ctx context.Context, title string) (Session, error)
  57. CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
  58. CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
  59. Get(ctx context.Context, id string) (Session, error)
  60. GetLast(ctx context.Context) (Session, error)
  61. List(ctx context.Context) ([]Session, error)
  62. Save(ctx context.Context, session Session) (Session, error)
  63. UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
  64. Rename(ctx context.Context, id string, title string) error
  65. Delete(ctx context.Context, id string) error
  66. // Agent tool session management
  67. CreateAgentToolSessionID(messageID, toolCallID string) string
  68. ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
  69. IsAgentToolSession(sessionID string) bool
  70. }
  71. type service struct {
  72. *pubsub.Broker[Session]
  73. db *sql.DB
  74. q *db.Queries
  75. }
  76. func (s *service) Create(ctx context.Context, title string) (Session, error) {
  77. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  78. ID: uuid.New().String(),
  79. Title: title,
  80. })
  81. if err != nil {
  82. return Session{}, err
  83. }
  84. session := s.fromDBItem(dbSession)
  85. s.Publish(pubsub.CreatedEvent, session)
  86. event.SessionCreated()
  87. return session, nil
  88. }
  89. func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
  90. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  91. ID: toolCallID,
  92. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  93. Title: title,
  94. })
  95. if err != nil {
  96. return Session{}, err
  97. }
  98. session := s.fromDBItem(dbSession)
  99. s.Publish(pubsub.CreatedEvent, session)
  100. return session, nil
  101. }
  102. func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
  103. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  104. ID: "title-" + parentSessionID,
  105. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  106. Title: "Generate a title",
  107. })
  108. if err != nil {
  109. return Session{}, err
  110. }
  111. session := s.fromDBItem(dbSession)
  112. s.Publish(pubsub.CreatedEvent, session)
  113. return session, nil
  114. }
  115. func (s *service) Delete(ctx context.Context, id string) error {
  116. tx, err := s.db.BeginTx(ctx, nil)
  117. if err != nil {
  118. return fmt.Errorf("beginning transaction: %w", err)
  119. }
  120. defer tx.Rollback() //nolint:errcheck
  121. qtx := s.q.WithTx(tx)
  122. dbSession, err := qtx.GetSessionByID(ctx, id)
  123. if err != nil {
  124. return err
  125. }
  126. if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
  127. return fmt.Errorf("deleting session messages: %w", err)
  128. }
  129. if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
  130. return fmt.Errorf("deleting session files: %w", err)
  131. }
  132. if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
  133. return fmt.Errorf("deleting session: %w", err)
  134. }
  135. if err = tx.Commit(); err != nil {
  136. return fmt.Errorf("committing transaction: %w", err)
  137. }
  138. session := s.fromDBItem(dbSession)
  139. s.Publish(pubsub.DeletedEvent, session)
  140. event.SessionDeleted()
  141. return nil
  142. }
  143. func (s *service) Get(ctx context.Context, id string) (Session, error) {
  144. dbSession, err := s.q.GetSessionByID(ctx, id)
  145. if err != nil {
  146. return Session{}, err
  147. }
  148. return s.fromDBItem(dbSession), nil
  149. }
  150. func (s *service) GetLast(ctx context.Context) (Session, error) {
  151. dbSession, err := s.q.GetLastSession(ctx)
  152. if err != nil {
  153. return Session{}, err
  154. }
  155. return s.fromDBItem(dbSession), nil
  156. }
  157. func (s *service) Save(ctx context.Context, session Session) (Session, error) {
  158. todosJSON, err := marshalTodos(session.Todos)
  159. if err != nil {
  160. return Session{}, err
  161. }
  162. dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
  163. ID: session.ID,
  164. Title: session.Title,
  165. PromptTokens: session.PromptTokens,
  166. CompletionTokens: session.CompletionTokens,
  167. SummaryMessageID: sql.NullString{
  168. String: session.SummaryMessageID,
  169. Valid: session.SummaryMessageID != "",
  170. },
  171. Cost: session.Cost,
  172. Todos: sql.NullString{
  173. String: todosJSON,
  174. Valid: todosJSON != "",
  175. },
  176. })
  177. if err != nil {
  178. return Session{}, err
  179. }
  180. session = s.fromDBItem(dbSession)
  181. s.Publish(pubsub.UpdatedEvent, session)
  182. return session, nil
  183. }
  184. // UpdateTitleAndUsage updates only the title and usage fields atomically.
  185. // This is safer than fetching, modifying, and saving the entire session.
  186. func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
  187. return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
  188. ID: sessionID,
  189. Title: title,
  190. PromptTokens: promptTokens,
  191. CompletionTokens: completionTokens,
  192. Cost: cost,
  193. })
  194. }
  195. // Rename updates only the title of a session without touching updated_at or
  196. // usage fields.
  197. func (s *service) Rename(ctx context.Context, id string, title string) error {
  198. return s.q.RenameSession(ctx, db.RenameSessionParams{
  199. ID: id,
  200. Title: title,
  201. })
  202. }
  203. func (s *service) List(ctx context.Context) ([]Session, error) {
  204. dbSessions, err := s.q.ListSessions(ctx)
  205. if err != nil {
  206. return nil, err
  207. }
  208. sessions := make([]Session, len(dbSessions))
  209. for i, dbSession := range dbSessions {
  210. sessions[i] = s.fromDBItem(dbSession)
  211. }
  212. return sessions, nil
  213. }
  214. func (s service) fromDBItem(item db.Session) Session {
  215. todos, err := unmarshalTodos(item.Todos.String)
  216. if err != nil {
  217. slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
  218. }
  219. return Session{
  220. ID: item.ID,
  221. ParentSessionID: item.ParentSessionID.String,
  222. Title: item.Title,
  223. MessageCount: item.MessageCount,
  224. PromptTokens: item.PromptTokens,
  225. CompletionTokens: item.CompletionTokens,
  226. SummaryMessageID: item.SummaryMessageID.String,
  227. Cost: item.Cost,
  228. Todos: todos,
  229. CreatedAt: item.CreatedAt,
  230. UpdatedAt: item.UpdatedAt,
  231. }
  232. }
  233. func marshalTodos(todos []Todo) (string, error) {
  234. if len(todos) == 0 {
  235. return "", nil
  236. }
  237. data, err := json.Marshal(todos)
  238. if err != nil {
  239. return "", err
  240. }
  241. return string(data), nil
  242. }
  243. func unmarshalTodos(data string) ([]Todo, error) {
  244. if data == "" {
  245. return []Todo{}, nil
  246. }
  247. var todos []Todo
  248. if err := json.Unmarshal([]byte(data), &todos); err != nil {
  249. return []Todo{}, err
  250. }
  251. return todos, nil
  252. }
  253. func NewService(q *db.Queries, conn *sql.DB) Service {
  254. broker := pubsub.NewBroker[Session]()
  255. return &service{
  256. Broker: broker,
  257. db: conn,
  258. q: q,
  259. }
  260. }
  261. // CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
  262. func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
  263. return fmt.Sprintf("%s$$%s", messageID, toolCallID)
  264. }
  265. // ParseAgentToolSessionID parses an agent tool session ID into its components
  266. func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
  267. parts := strings.Split(sessionID, "$$")
  268. if len(parts) != 2 {
  269. return "", "", false
  270. }
  271. return parts[0], parts[1], true
  272. }
  273. // IsAgentToolSession checks if a session ID follows the agent tool session format
  274. func (s *service) IsAgentToolSession(sessionID string) bool {
  275. _, _, ok := s.ParseAgentToolSessionID(sessionID)
  276. return ok
  277. }