message.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. package message
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "github.com/google/uuid"
  7. "github.com/kujtimiihoxha/termai/internal/db"
  8. "github.com/kujtimiihoxha/termai/internal/pubsub"
  9. )
  // MessageRole identifies which participant a Message is attributed to.
  // Its underlying string is what gets persisted: Create stores
  // string(params.Role) and fromDBItem converts the stored text back.
  type MessageRole string

  // Roles declared by this package.
  const (
  	// Assistant is the one role that Create stores with Finished == false.
  	Assistant MessageRole = "assistant"
  	// User, System and Tool messages are stored by Create as already finished.
  	User   MessageRole = "user"
  	System MessageRole = "system"
  	Tool   MessageRole = "tool"
  )
  // ToolResult is one tool outcome attached to a Message. Slices of
  // ToolResult are JSON-encoded by Create and Update before being stored,
  // and decoded again by fromDBItem. The struct has no JSON tags, so the Go
  // field names are the JSON keys.
  type ToolResult struct {
  	// ToolCallID presumably matches the ID of the ToolCall this result
  	// answers; nothing in this file enforces that link.
  	ToolCallID string
  	// Content is the textual payload of the result.
  	Content string
  	// IsError flags the result as a failure (set by the producer; this
  	// package only stores it).
  	IsError bool
  	// TODO: support for images
  }
  // ToolCall is one tool invocation attached to a Message. Slices of
  // ToolCall are JSON-encoded by Create and Update before being stored, and
  // decoded again by fromDBItem. The struct has no JSON tags, so the Go
  // field names are the JSON keys.
  type ToolCall struct {
  	// ID identifies this call.
  	ID string
  	// Name is the name of the tool being called.
  	Name string
  	// Input is kept as an opaque string here; presumably serialized tool
  	// arguments, but this package never parses it.
  	Input string
  	// Type is stored verbatim; its allowed values are not defined in this file.
  	Type string
  }
  29. type Message struct {
  30. ID string
  31. SessionID string
  32. // NEW
  33. Role MessageRole
  34. Content string
  35. Thinking string
  36. Finished bool
  37. ToolResults []ToolResult
  38. ToolCalls []ToolCall
  39. CreatedAt int64
  40. UpdatedAt int64
  41. }
  42. type CreateMessageParams struct {
  43. Role MessageRole
  44. Content string
  45. ToolCalls []ToolCall
  46. ToolResults []ToolResult
  47. }
  48. type Service interface {
  49. pubsub.Suscriber[Message]
  50. Create(sessionID string, params CreateMessageParams) (Message, error)
  51. Update(message Message) error
  52. Get(id string) (Message, error)
  53. List(sessionID string) ([]Message, error)
  54. Delete(id string) error
  55. DeleteSessionMessages(sessionID string) error
  56. }
  57. type service struct {
  58. *pubsub.Broker[Message]
  59. q db.Querier
  60. ctx context.Context
  61. }
  62. func (s *service) Delete(id string) error {
  63. message, err := s.Get(id)
  64. if err != nil {
  65. return err
  66. }
  67. err = s.q.DeleteMessage(s.ctx, message.ID)
  68. if err != nil {
  69. return err
  70. }
  71. s.Publish(pubsub.DeletedEvent, message)
  72. return nil
  73. }
  74. func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
  75. toolCallsStr, err := json.Marshal(params.ToolCalls)
  76. if err != nil {
  77. return Message{}, err
  78. }
  79. toolResultsStr, err := json.Marshal(params.ToolResults)
  80. if err != nil {
  81. return Message{}, err
  82. }
  83. dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
  84. ID: uuid.New().String(),
  85. SessionID: sessionID,
  86. Role: string(params.Role),
  87. Finished: params.Role != Assistant,
  88. Content: params.Content,
  89. ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true},
  90. ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
  91. })
  92. if err != nil {
  93. return Message{}, err
  94. }
  95. message, err := s.fromDBItem(dbMessage)
  96. if err != nil {
  97. return Message{}, err
  98. }
  99. s.Publish(pubsub.CreatedEvent, message)
  100. return message, nil
  101. }
  102. func (s *service) DeleteSessionMessages(sessionID string) error {
  103. messages, err := s.List(sessionID)
  104. if err != nil {
  105. return err
  106. }
  107. for _, message := range messages {
  108. if message.SessionID == sessionID {
  109. err = s.Delete(message.ID)
  110. if err != nil {
  111. return err
  112. }
  113. }
  114. }
  115. return nil
  116. }
  117. func (s *service) Update(message Message) error {
  118. toolCallsStr, err := json.Marshal(message.ToolCalls)
  119. if err != nil {
  120. return err
  121. }
  122. toolResultsStr, err := json.Marshal(message.ToolResults)
  123. if err != nil {
  124. return err
  125. }
  126. err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
  127. ID: message.ID,
  128. Content: message.Content,
  129. Thinking: message.Thinking,
  130. Finished: message.Finished,
  131. ToolCalls: sql.NullString{String: string(toolCallsStr), Valid: true},
  132. ToolResults: sql.NullString{String: string(toolResultsStr), Valid: true},
  133. })
  134. if err != nil {
  135. return err
  136. }
  137. s.Publish(pubsub.UpdatedEvent, message)
  138. return nil
  139. }
  140. func (s *service) Get(id string) (Message, error) {
  141. dbMessage, err := s.q.GetMessage(s.ctx, id)
  142. if err != nil {
  143. return Message{}, err
  144. }
  145. return s.fromDBItem(dbMessage)
  146. }
  147. func (s *service) List(sessionID string) ([]Message, error) {
  148. dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID)
  149. if err != nil {
  150. return nil, err
  151. }
  152. messages := make([]Message, len(dbMessages))
  153. for i, dbMessage := range dbMessages {
  154. messages[i], err = s.fromDBItem(dbMessage)
  155. if err != nil {
  156. return nil, err
  157. }
  158. }
  159. return messages, nil
  160. }
  161. func (s *service) fromDBItem(item db.Message) (Message, error) {
  162. toolCalls := make([]ToolCall, 0)
  163. if item.ToolCalls.Valid {
  164. err := json.Unmarshal([]byte(item.ToolCalls.String), &toolCalls)
  165. if err != nil {
  166. return Message{}, err
  167. }
  168. }
  169. toolResults := make([]ToolResult, 0)
  170. if item.ToolResults.Valid {
  171. err := json.Unmarshal([]byte(item.ToolResults.String), &toolResults)
  172. if err != nil {
  173. return Message{}, err
  174. }
  175. }
  176. return Message{
  177. ID: item.ID,
  178. SessionID: item.SessionID,
  179. Role: MessageRole(item.Role),
  180. Content: item.Content,
  181. Thinking: item.Thinking,
  182. Finished: item.Finished,
  183. ToolCalls: toolCalls,
  184. ToolResults: toolResults,
  185. CreatedAt: item.CreatedAt,
  186. UpdatedAt: item.UpdatedAt,
  187. }, nil
  188. }
  189. func NewService(ctx context.Context, q db.Querier) Service {
  190. return &service{
  191. Broker: pubsub.NewBroker[Message](),
  192. q: q,
  193. ctx: ctx,
  194. }
  195. }