session.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. package session
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "log/slog"
  8. "strings"
  9. "github.com/charmbracelet/crush/internal/config"
  10. "github.com/charmbracelet/crush/internal/db"
  11. "github.com/charmbracelet/crush/internal/event"
  12. "github.com/charmbracelet/crush/internal/pubsub"
  13. "github.com/google/uuid"
  14. "github.com/zeebo/xxh3"
  15. )
  16. type TodoStatus string
  17. const (
  18. TodoStatusPending TodoStatus = "pending"
  19. TodoStatusInProgress TodoStatus = "in_progress"
  20. TodoStatusCompleted TodoStatus = "completed"
  21. )
  22. // HashID returns the XXH3 hash of a session ID (UUID) as a hex string.
  23. func HashID(id string) string {
  24. h := xxh3.New()
  25. h.WriteString(id)
  26. return fmt.Sprintf("%x", h.Sum(nil))
  27. }
  28. type Todo struct {
  29. Content string `json:"content"`
  30. Status TodoStatus `json:"status"`
  31. ActiveForm string `json:"active_form"`
  32. }
  33. // HasIncompleteTodos returns true if there are any non-completed todos.
  34. func HasIncompleteTodos(todos []Todo) bool {
  35. for _, todo := range todos {
  36. if todo.Status != TodoStatusCompleted {
  37. return true
  38. }
  39. }
  40. return false
  41. }
  42. type Session struct {
  43. ID string
  44. ParentSessionID string
  45. Title string
  46. MessageCount int64
  47. PromptTokens int64
  48. CompletionTokens int64
  49. SummaryMessageID string
  50. Cost float64
  51. Todos []Todo
  52. Models map[config.SelectedModelType]config.SelectedModel
  53. CreatedAt int64
  54. UpdatedAt int64
  55. }
  56. type Service interface {
  57. pubsub.Subscriber[Session]
  58. Create(ctx context.Context, title string) (Session, error)
  59. CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
  60. CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
  61. Get(ctx context.Context, id string) (Session, error)
  62. GetLast(ctx context.Context) (Session, error)
  63. List(ctx context.Context) ([]Session, error)
  64. Save(ctx context.Context, session Session) (Session, error)
  65. SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error)
  66. UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error
  67. UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
  68. Rename(ctx context.Context, id string, title string) error
  69. Delete(ctx context.Context, id string) error
  70. // Agent tool session management
  71. CreateAgentToolSessionID(messageID, toolCallID string) string
  72. ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
  73. IsAgentToolSession(sessionID string) bool
  74. }
  75. type service struct {
  76. *pubsub.Broker[Session]
  77. db *sql.DB
  78. q *db.Queries
  79. }
  80. func (s *service) Create(ctx context.Context, title string) (Session, error) {
  81. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  82. ID: uuid.New().String(),
  83. Title: title,
  84. })
  85. if err != nil {
  86. return Session{}, err
  87. }
  88. session := s.fromDBItem(dbSession)
  89. s.Publish(pubsub.CreatedEvent, session)
  90. event.SessionCreated()
  91. return session, nil
  92. }
  93. func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
  94. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  95. ID: toolCallID,
  96. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  97. Title: title,
  98. })
  99. if err != nil {
  100. return Session{}, err
  101. }
  102. session := s.fromDBItem(dbSession)
  103. s.Publish(pubsub.CreatedEvent, session)
  104. return session, nil
  105. }
  106. func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
  107. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  108. ID: "title-" + parentSessionID,
  109. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  110. Title: "Generate a title",
  111. })
  112. if err != nil {
  113. return Session{}, err
  114. }
  115. session := s.fromDBItem(dbSession)
  116. s.Publish(pubsub.CreatedEvent, session)
  117. return session, nil
  118. }
  119. func (s *service) Delete(ctx context.Context, id string) error {
  120. tx, err := s.db.BeginTx(ctx, nil)
  121. if err != nil {
  122. return fmt.Errorf("beginning transaction: %w", err)
  123. }
  124. defer tx.Rollback() //nolint:errcheck
  125. qtx := s.q.WithTx(tx)
  126. dbSession, err := qtx.GetSessionByID(ctx, id)
  127. if err != nil {
  128. return err
  129. }
  130. if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
  131. return fmt.Errorf("deleting session messages: %w", err)
  132. }
  133. if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
  134. return fmt.Errorf("deleting session files: %w", err)
  135. }
  136. if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
  137. return fmt.Errorf("deleting session: %w", err)
  138. }
  139. if err = tx.Commit(); err != nil {
  140. return fmt.Errorf("committing transaction: %w", err)
  141. }
  142. session := s.fromDBItem(dbSession)
  143. s.Publish(pubsub.DeletedEvent, session)
  144. event.SessionDeleted()
  145. return nil
  146. }
  147. func (s *service) Get(ctx context.Context, id string) (Session, error) {
  148. dbSession, err := s.q.GetSessionByID(ctx, id)
  149. if err != nil {
  150. return Session{}, err
  151. }
  152. return s.fromDBItem(dbSession), nil
  153. }
  154. func (s *service) GetLast(ctx context.Context) (Session, error) {
  155. dbSession, err := s.q.GetLastSession(ctx)
  156. if err != nil {
  157. return Session{}, err
  158. }
  159. return s.fromDBItem(dbSession), nil
  160. }
  161. func (s *service) Save(ctx context.Context, session Session) (Session, error) {
  162. todosJSON, err := marshalTodos(session.Todos)
  163. if err != nil {
  164. return Session{}, err
  165. }
  166. dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
  167. ID: session.ID,
  168. Title: session.Title,
  169. PromptTokens: session.PromptTokens,
  170. CompletionTokens: session.CompletionTokens,
  171. SummaryMessageID: sql.NullString{
  172. String: session.SummaryMessageID,
  173. Valid: session.SummaryMessageID != "",
  174. },
  175. Cost: session.Cost,
  176. Todos: sql.NullString{
  177. String: todosJSON,
  178. Valid: todosJSON != "",
  179. },
  180. })
  181. if err != nil {
  182. return Session{}, err
  183. }
  184. session = s.fromDBItem(dbSession)
  185. s.Publish(pubsub.UpdatedEvent, session)
  186. return session, nil
  187. }
  188. // UpdateTitleAndUsage updates only the title and usage fields atomically.
  189. // This is safer than fetching, modifying, and saving the entire session.
  190. func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
  191. return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
  192. ID: sessionID,
  193. Title: title,
  194. PromptTokens: promptTokens,
  195. CompletionTokens: completionTokens,
  196. Cost: cost,
  197. })
  198. }
  199. // Rename updates only the title of a session without touching updated_at or
  200. // usage fields.
  201. func (s *service) Rename(ctx context.Context, id string, title string) error {
  202. return s.q.RenameSession(ctx, db.RenameSessionParams{
  203. ID: id,
  204. Title: title,
  205. })
  206. }
  207. func (s *service) List(ctx context.Context) ([]Session, error) {
  208. dbSessions, err := s.q.ListSessions(ctx)
  209. if err != nil {
  210. return nil, err
  211. }
  212. sessions := make([]Session, len(dbSessions))
  213. for i, dbSession := range dbSessions {
  214. sessions[i] = s.fromDBItem(dbSession)
  215. }
  216. return sessions, nil
  217. }
  218. func (s service) fromDBItem(item db.Session) Session {
  219. todos, err := unmarshalTodos(item.Todos.String)
  220. if err != nil {
  221. slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
  222. }
  223. models, err := unmarshalModels(item.Models.String)
  224. if err != nil {
  225. slog.Error("Failed to unmarshal models", "session_id", item.ID, "error", err)
  226. }
  227. return Session{
  228. ID: item.ID,
  229. ParentSessionID: item.ParentSessionID.String,
  230. Title: item.Title,
  231. MessageCount: item.MessageCount,
  232. PromptTokens: item.PromptTokens,
  233. CompletionTokens: item.CompletionTokens,
  234. SummaryMessageID: item.SummaryMessageID.String,
  235. Cost: item.Cost,
  236. Todos: todos,
  237. Models: models,
  238. CreatedAt: item.CreatedAt,
  239. UpdatedAt: item.UpdatedAt,
  240. }
  241. }
  242. func marshalTodos(todos []Todo) (string, error) {
  243. if len(todos) == 0 {
  244. return "", nil
  245. }
  246. data, err := json.Marshal(todos)
  247. if err != nil {
  248. return "", err
  249. }
  250. return string(data), nil
  251. }
  252. func unmarshalTodos(data string) ([]Todo, error) {
  253. if data == "" {
  254. return []Todo{}, nil
  255. }
  256. var todos []Todo
  257. if err := json.Unmarshal([]byte(data), &todos); err != nil {
  258. return []Todo{}, err
  259. }
  260. return todos, nil
  261. }
  262. func marshalModels(models map[config.SelectedModelType]config.SelectedModel) (string, error) {
  263. if len(models) == 0 {
  264. return "", nil
  265. }
  266. data, err := json.Marshal(models)
  267. if err != nil {
  268. return "", err
  269. }
  270. return string(data), nil
  271. }
  272. func unmarshalModels(data string) (map[config.SelectedModelType]config.SelectedModel, error) {
  273. if data == "" {
  274. return nil, nil
  275. }
  276. var models map[config.SelectedModelType]config.SelectedModel
  277. if err := json.Unmarshal([]byte(data), &models); err != nil {
  278. return nil, err
  279. }
  280. return models, nil
  281. }
  282. func (s *service) UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error {
  283. modelsJSON, err := marshalModels(models)
  284. if err != nil {
  285. return fmt.Errorf("failed to marshal models: %w", err)
  286. }
  287. _, err = s.q.UpdateSessionModels(ctx, db.UpdateSessionModelsParams{
  288. Models: sql.NullString{String: modelsJSON, Valid: modelsJSON != ""},
  289. ID: id,
  290. })
  291. return err
  292. }
  293. // SaveWithModels saves the session and then persists the models column as a
  294. // second operation. This is intentionally non-atomic: if the models update
  295. // fails, the session fields are still saved (which is equivalent to the
  296. // pre-feature behavior where models were never persisted). The next agent turn
  297. // will retry the models write, so transient failures are self-healing.
  298. func (s *service) SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error) {
  299. saved, err := s.Save(ctx, session)
  300. if err != nil {
  301. return Session{}, err
  302. }
  303. if err := s.UpdateSessionModels(ctx, session.ID, models); err != nil {
  304. return Session{}, fmt.Errorf("failed to persist models: %w", err)
  305. }
  306. return saved, nil
  307. }
  308. func NewService(q *db.Queries, conn *sql.DB) Service {
  309. broker := pubsub.NewBroker[Session]()
  310. return &service{
  311. Broker: broker,
  312. db: conn,
  313. q: q,
  314. }
  315. }
  316. // CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
  317. func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
  318. return fmt.Sprintf("%s$$%s", messageID, toolCallID)
  319. }
  320. // ParseAgentToolSessionID parses an agent tool session ID into its components
  321. func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
  322. parts := strings.Split(sessionID, "$$")
  323. if len(parts) != 2 {
  324. return "", "", false
  325. }
  326. return parts[0], parts[1], true
  327. }
  328. // IsAgentToolSession checks if a session ID follows the agent tool session format
  329. func (s *service) IsAgentToolSession(sessionID string) bool {
  330. _, _, ok := s.ParseAgentToolSessionID(sessionID)
  331. return ok
  332. }