session.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. package session
  2. import (
  3. "context"
  4. "database/sql"
  5. "github.com/google/uuid"
  6. "github.com/opencode-ai/opencode/internal/db"
  7. "github.com/opencode-ai/opencode/internal/pubsub"
  8. )
  // Session is the application-level view of a stored chat session.
  //
  // Nullable database fields are flattened to plain Go values. A missing
  // parent or summary reads as "", and a missing summarization time reads as 0.
  type Session struct {
  	ID, ParentSessionID, Title string

  	// Usage accounting for the session.
  	MessageCount, PromptTokens, CompletionTokens int64
  	Cost                                         float64

  	// Summary state; empty Summary / zero SummarizedAt mean "not summarized".
  	Summary      string
  	SummarizedAt int64

  	CreatedAt, UpdatedAt int64
  }
  22. type Service interface {
  23. pubsub.Suscriber[Session]
  24. Create(ctx context.Context, title string) (Session, error)
  25. CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
  26. CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
  27. Get(ctx context.Context, id string) (Session, error)
  28. List(ctx context.Context) ([]Session, error)
  29. Save(ctx context.Context, session Session) (Session, error)
  30. Delete(ctx context.Context, id string) error
  31. }
  32. type service struct {
  33. *pubsub.Broker[Session]
  34. q db.Querier
  35. }
  36. func (s *service) Create(ctx context.Context, title string) (Session, error) {
  37. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  38. ID: uuid.New().String(),
  39. Title: title,
  40. })
  41. if err != nil {
  42. return Session{}, err
  43. }
  44. session := s.fromDBItem(dbSession)
  45. s.Publish(pubsub.CreatedEvent, session)
  46. return session, nil
  47. }
  48. func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
  49. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  50. ID: toolCallID,
  51. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  52. Title: title,
  53. })
  54. if err != nil {
  55. return Session{}, err
  56. }
  57. session := s.fromDBItem(dbSession)
  58. s.Publish(pubsub.CreatedEvent, session)
  59. return session, nil
  60. }
  61. func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
  62. dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
  63. ID: "title-" + parentSessionID,
  64. ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
  65. Title: "Generate a title",
  66. })
  67. if err != nil {
  68. return Session{}, err
  69. }
  70. session := s.fromDBItem(dbSession)
  71. s.Publish(pubsub.CreatedEvent, session)
  72. return session, nil
  73. }
  74. func (s *service) Delete(ctx context.Context, id string) error {
  75. session, err := s.Get(ctx, id)
  76. if err != nil {
  77. return err
  78. }
  79. err = s.q.DeleteSession(ctx, session.ID)
  80. if err != nil {
  81. return err
  82. }
  83. s.Publish(pubsub.DeletedEvent, session)
  84. return nil
  85. }
  86. func (s *service) Get(ctx context.Context, id string) (Session, error) {
  87. dbSession, err := s.q.GetSessionByID(ctx, id)
  88. if err != nil {
  89. return Session{}, err
  90. }
  91. return s.fromDBItem(dbSession), nil
  92. }
  93. func (s *service) Save(ctx context.Context, session Session) (Session, error) {
  94. summary := sql.NullString{}
  95. if session.Summary != "" {
  96. summary.String = session.Summary
  97. summary.Valid = true
  98. }
  99. summarizedAt := sql.NullInt64{}
  100. if session.SummarizedAt != 0 {
  101. summarizedAt.Int64 = session.SummarizedAt
  102. summarizedAt.Valid = true
  103. }
  104. dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
  105. ID: session.ID,
  106. Title: session.Title,
  107. PromptTokens: session.PromptTokens,
  108. CompletionTokens: session.CompletionTokens,
  109. Cost: session.Cost,
  110. Summary: summary,
  111. SummarizedAt: summarizedAt,
  112. })
  113. if err != nil {
  114. return Session{}, err
  115. }
  116. session = s.fromDBItem(dbSession)
  117. s.Publish(pubsub.UpdatedEvent, session)
  118. return session, nil
  119. }
  120. func (s *service) List(ctx context.Context) ([]Session, error) {
  121. dbSessions, err := s.q.ListSessions(ctx)
  122. if err != nil {
  123. return nil, err
  124. }
  125. sessions := make([]Session, len(dbSessions))
  126. for i, dbSession := range dbSessions {
  127. sessions[i] = s.fromDBItem(dbSession)
  128. }
  129. return sessions, nil
  130. }
  131. func (s service) fromDBItem(item db.Session) Session {
  132. return Session{
  133. ID: item.ID,
  134. ParentSessionID: item.ParentSessionID.String,
  135. Title: item.Title,
  136. MessageCount: item.MessageCount,
  137. PromptTokens: item.PromptTokens,
  138. CompletionTokens: item.CompletionTokens,
  139. Cost: item.Cost,
  140. Summary: item.Summary.String,
  141. SummarizedAt: item.SummarizedAt.Int64,
  142. CreatedAt: item.CreatedAt,
  143. UpdatedAt: item.UpdatedAt,
  144. }
  145. }
  146. func NewService(q db.Querier) Service {
  147. broker := pubsub.NewBroker[Session]()
  148. return &service{
  149. broker,
  150. q,
  151. }
  152. }