session.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. package session
  import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/sst/opencode/internal/db"
	"github.com/sst/opencode/internal/pubsub"
)
  // Session is the service-level representation of a stored session row.
//
// ParentSessionID and Summary are empty strings when the corresponding
// database columns are NULL; SummarizedAt is the zero time when the session
// has no summary timestamp.
type Session struct {
	ID              string
	ParentSessionID string // empty for sessions created without a parent
	Title           string

	// Counters and cost as stored in the database row.
	MessageCount, PromptTokens, CompletionTokens int64
	Cost                                         float64

	Summary      string
	SummarizedAt time.Time

	CreatedAt, UpdatedAt time.Time
}
  25. const (
  26. EventSessionCreated pubsub.EventType = "session_created"
  27. EventSessionUpdated pubsub.EventType = "session_updated"
  28. EventSessionDeleted pubsub.EventType = "session_deleted"
  29. )
  30. type Service interface {
  31. pubsub.Subscriber[Session]
  32. Create(ctx context.Context, title string) (Session, error)
  33. CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
  34. Get(ctx context.Context, id string) (Session, error)
  35. List(ctx context.Context) ([]Session, error)
  36. Update(ctx context.Context, session Session) (Session, error)
  37. Delete(ctx context.Context, id string) error
  38. }
  39. type service struct {
  40. db *db.Queries
  41. broker *pubsub.Broker[Session]
  42. mu sync.RWMutex
  43. }
  44. var globalSessionService *service
  45. func InitService(dbConn *sql.DB) error {
  46. if globalSessionService != nil {
  47. return fmt.Errorf("session service already initialized")
  48. }
  49. queries := db.New(dbConn)
  50. broker := pubsub.NewBroker[Session]()
  51. globalSessionService = &service{
  52. db: queries,
  53. broker: broker,
  54. }
  55. return nil
  56. }
  57. func GetService() Service {
  58. if globalSessionService == nil {
  59. panic("session service not initialized. Call session.InitService() first.")
  60. }
  61. return globalSessionService
  62. }
  63. func (s *service) Create(ctx context.Context, title string) (Session, error) {
  64. s.mu.Lock()
  65. defer s.mu.Unlock()
  66. if title == "" {
  67. title = "New Session - " + time.Now().Format("2006-01-02 15:04:05")
  68. }
  69. dbSessParams := db.CreateSessionParams{
  70. ID: uuid.New().String(),
  71. Title: title,
  72. }
  73. dbSession, err := s.db.CreateSession(ctx, dbSessParams)
  74. if err != nil {
  75. return Session{}, fmt.Errorf("db.CreateSession: %w", err)
  76. }
  77. session := s.fromDBItem(dbSession)
  78. s.broker.Publish(EventSessionCreated, session)
  79. return session, nil
  80. }
  81. func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
  82. s.mu.Lock()
  83. defer s.mu.Unlock()
  84. if title == "" {
  85. title = "Task Session - " + time.Now().Format("2006-01-02 15:04:05")
  86. }
  87. if toolCallID == "" {
  88. toolCallID = uuid.New().String()
  89. }
  90. dbSessParams := db.CreateSessionParams{
  91. ID: toolCallID,
  92. ParentSessionID: sql.NullString{String: parentSessionID, Valid: parentSessionID != ""},
  93. Title: title,
  94. }
  95. dbSession, err := s.db.CreateSession(ctx, dbSessParams)
  96. if err != nil {
  97. return Session{}, fmt.Errorf("db.CreateTaskSession: %w", err)
  98. }
  99. session := s.fromDBItem(dbSession)
  100. s.broker.Publish(EventSessionCreated, session)
  101. return session, nil
  102. }
  103. func (s *service) Get(ctx context.Context, id string) (Session, error) {
  104. s.mu.RLock()
  105. defer s.mu.RUnlock()
  106. dbSession, err := s.db.GetSessionByID(ctx, id)
  107. if err != nil {
  108. if err == sql.ErrNoRows {
  109. return Session{}, fmt.Errorf("session ID '%s' not found", id)
  110. }
  111. return Session{}, fmt.Errorf("db.GetSessionByID: %w", err)
  112. }
  113. return s.fromDBItem(dbSession), nil
  114. }
  115. func (s *service) List(ctx context.Context) ([]Session, error) {
  116. s.mu.RLock()
  117. defer s.mu.RUnlock()
  118. dbSessions, err := s.db.ListSessions(ctx)
  119. if err != nil {
  120. return nil, fmt.Errorf("db.ListSessions: %w", err)
  121. }
  122. sessions := make([]Session, len(dbSessions))
  123. for i, dbSess := range dbSessions {
  124. sessions[i] = s.fromDBItem(dbSess)
  125. }
  126. return sessions, nil
  127. }
  128. func (s *service) Update(ctx context.Context, session Session) (Session, error) {
  129. s.mu.Lock()
  130. defer s.mu.Unlock()
  131. if session.ID == "" {
  132. return Session{}, fmt.Errorf("cannot update session with empty ID")
  133. }
  134. params := db.UpdateSessionParams{
  135. ID: session.ID,
  136. Title: session.Title,
  137. PromptTokens: session.PromptTokens,
  138. CompletionTokens: session.CompletionTokens,
  139. Cost: session.Cost,
  140. Summary: sql.NullString{String: session.Summary, Valid: session.Summary != ""},
  141. SummarizedAt: sql.NullString{String: session.SummarizedAt.UTC().Format(time.RFC3339Nano), Valid: !session.SummarizedAt.IsZero()},
  142. }
  143. dbSession, err := s.db.UpdateSession(ctx, params)
  144. if err != nil {
  145. return Session{}, fmt.Errorf("db.UpdateSession: %w", err)
  146. }
  147. updatedSession := s.fromDBItem(dbSession)
  148. s.broker.Publish(EventSessionUpdated, updatedSession)
  149. return updatedSession, nil
  150. }
  151. func (s *service) Delete(ctx context.Context, id string) error {
  152. s.mu.Lock()
  153. dbSess, err := s.db.GetSessionByID(ctx, id)
  154. if err != nil {
  155. s.mu.Unlock()
  156. if err == sql.ErrNoRows {
  157. return fmt.Errorf("session ID '%s' not found for deletion", id)
  158. }
  159. return fmt.Errorf("db.GetSessionByID before delete: %w", err)
  160. }
  161. sessionToPublish := s.fromDBItem(dbSess)
  162. s.mu.Unlock()
  163. s.mu.Lock()
  164. defer s.mu.Unlock()
  165. err = s.db.DeleteSession(ctx, id)
  166. if err != nil {
  167. return fmt.Errorf("db.DeleteSession: %w", err)
  168. }
  169. s.broker.Publish(EventSessionDeleted, sessionToPublish)
  170. return nil
  171. }
  172. func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Session] {
  173. return s.broker.Subscribe(ctx)
  174. }
  175. func (s *service) fromDBItem(item db.Session) Session {
  176. var summarizedAt time.Time
  177. if item.SummarizedAt.Valid {
  178. parsedTime, err := time.Parse(time.RFC3339Nano, item.SummarizedAt.String)
  179. if err == nil {
  180. summarizedAt = parsedTime
  181. }
  182. }
  183. createdAt, _ := time.Parse(time.RFC3339Nano, item.CreatedAt)
  184. updatedAt, _ := time.Parse(time.RFC3339Nano, item.UpdatedAt)
  185. return Session{
  186. ID: item.ID,
  187. ParentSessionID: item.ParentSessionID.String,
  188. Title: item.Title,
  189. MessageCount: item.MessageCount,
  190. PromptTokens: item.PromptTokens,
  191. CompletionTokens: item.CompletionTokens,
  192. Cost: item.Cost,
  193. Summary: item.Summary.String,
  194. SummarizedAt: summarizedAt,
  195. CreatedAt: createdAt,
  196. UpdatedAt: updatedAt,
  197. }
  198. }
  199. func Create(ctx context.Context, title string) (Session, error) {
  200. return GetService().Create(ctx, title)
  201. }
  202. func CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
  203. return GetService().CreateTaskSession(ctx, toolCallID, parentSessionID, title)
  204. }
  205. func Get(ctx context.Context, id string) (Session, error) {
  206. return GetService().Get(ctx, id)
  207. }
  208. func List(ctx context.Context) ([]Session, error) {
  209. return GetService().List(ctx)
  210. }
  211. func Update(ctx context.Context, session Session) (Session, error) {
  212. return GetService().Update(ctx, session)
  213. }
  214. func Delete(ctx context.Context, id string) error {
  215. return GetService().Delete(ctx, id)
  216. }
  217. func Subscribe(ctx context.Context) <-chan pubsub.Event[Session] {
  218. return GetService().Subscribe(ctx)
  219. }