message.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. package message
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/google/uuid"
  7. "github.com/kujtimiihoxha/termai/internal/db"
  8. "github.com/kujtimiihoxha/termai/internal/pubsub"
  9. )
  10. type CreateMessageParams struct {
  11. Role MessageRole
  12. Parts []ContentPart
  13. }
  14. type Service interface {
  15. pubsub.Suscriber[Message]
  16. Create(sessionID string, params CreateMessageParams) (Message, error)
  17. Update(message Message) error
  18. Get(id string) (Message, error)
  19. List(sessionID string) ([]Message, error)
  20. Delete(id string) error
  21. DeleteSessionMessages(sessionID string) error
  22. }
  23. type service struct {
  24. *pubsub.Broker[Message]
  25. q db.Querier
  26. ctx context.Context
  27. }
  28. func NewService(ctx context.Context, q db.Querier) Service {
  29. return &service{
  30. Broker: pubsub.NewBroker[Message](),
  31. q: q,
  32. ctx: ctx,
  33. }
  34. }
  35. func (s *service) Delete(id string) error {
  36. message, err := s.Get(id)
  37. if err != nil {
  38. return err
  39. }
  40. err = s.q.DeleteMessage(s.ctx, message.ID)
  41. if err != nil {
  42. return err
  43. }
  44. s.Publish(pubsub.DeletedEvent, message)
  45. return nil
  46. }
  47. func (s *service) Create(sessionID string, params CreateMessageParams) (Message, error) {
  48. if params.Role != Assistant {
  49. params.Parts = append(params.Parts, Finish{
  50. Reason: "stop",
  51. })
  52. }
  53. partsJSON, err := marshallParts(params.Parts)
  54. if err != nil {
  55. return Message{}, err
  56. }
  57. dbMessage, err := s.q.CreateMessage(s.ctx, db.CreateMessageParams{
  58. ID: uuid.New().String(),
  59. SessionID: sessionID,
  60. Role: string(params.Role),
  61. Parts: string(partsJSON),
  62. })
  63. if err != nil {
  64. return Message{}, err
  65. }
  66. message, err := s.fromDBItem(dbMessage)
  67. if err != nil {
  68. return Message{}, err
  69. }
  70. s.Publish(pubsub.CreatedEvent, message)
  71. return message, nil
  72. }
  73. func (s *service) DeleteSessionMessages(sessionID string) error {
  74. messages, err := s.List(sessionID)
  75. if err != nil {
  76. return err
  77. }
  78. for _, message := range messages {
  79. if message.SessionID == sessionID {
  80. err = s.Delete(message.ID)
  81. if err != nil {
  82. return err
  83. }
  84. }
  85. }
  86. return nil
  87. }
  88. func (s *service) Update(message Message) error {
  89. parts, err := marshallParts(message.Parts)
  90. if err != nil {
  91. return err
  92. }
  93. err = s.q.UpdateMessage(s.ctx, db.UpdateMessageParams{
  94. ID: message.ID,
  95. Parts: string(parts),
  96. })
  97. if err != nil {
  98. return err
  99. }
  100. s.Publish(pubsub.UpdatedEvent, message)
  101. return nil
  102. }
  103. func (s *service) Get(id string) (Message, error) {
  104. dbMessage, err := s.q.GetMessage(s.ctx, id)
  105. if err != nil {
  106. return Message{}, err
  107. }
  108. return s.fromDBItem(dbMessage)
  109. }
  110. func (s *service) List(sessionID string) ([]Message, error) {
  111. dbMessages, err := s.q.ListMessagesBySession(s.ctx, sessionID)
  112. if err != nil {
  113. return nil, err
  114. }
  115. messages := make([]Message, len(dbMessages))
  116. for i, dbMessage := range dbMessages {
  117. messages[i], err = s.fromDBItem(dbMessage)
  118. if err != nil {
  119. return nil, err
  120. }
  121. }
  122. return messages, nil
  123. }
  124. func (s *service) fromDBItem(item db.Message) (Message, error) {
  125. parts, err := unmarshallParts([]byte(item.Parts))
  126. if err != nil {
  127. return Message{}, err
  128. }
  129. return Message{
  130. ID: item.ID,
  131. SessionID: item.SessionID,
  132. Role: MessageRole(item.Role),
  133. Parts: parts,
  134. CreatedAt: item.CreatedAt,
  135. UpdatedAt: item.UpdatedAt,
  136. }, nil
  137. }
  // partType is the discriminator written next to each serialized content part,
  // so the decoder knows which concrete type to rebuild. These strings end up
  // in the stored Parts JSON, so existing values must never change.
  type partType string

  const (
  	reasoningType  partType = "reasoning"   // ReasoningContent
  	textType       partType = "text"        // TextContent
  	imageURLType   partType = "image_url"   // ImageURLContent
  	binaryType     partType = "binary"      // BinaryContent
  	toolCallType   partType = "tool_call"   // ToolCall
  	toolResultType partType = "tool_result" // ToolResult
  	finishType     partType = "finish"      // Finish
  )
  148. type partWrapper struct {
  149. Type partType `json:"type"`
  150. Data ContentPart `json:"data"`
  151. }
  152. func marshallParts(parts []ContentPart) ([]byte, error) {
  153. wrappedParts := make([]partWrapper, len(parts))
  154. for i, part := range parts {
  155. var typ partType
  156. switch part.(type) {
  157. case ReasoningContent:
  158. typ = reasoningType
  159. case TextContent:
  160. typ = textType
  161. case ImageURLContent:
  162. typ = imageURLType
  163. case BinaryContent:
  164. typ = binaryType
  165. case ToolCall:
  166. typ = toolCallType
  167. case ToolResult:
  168. typ = toolResultType
  169. case Finish:
  170. typ = finishType
  171. default:
  172. return nil, fmt.Errorf("unknown part type: %T", part)
  173. }
  174. wrappedParts[i] = partWrapper{
  175. Type: typ,
  176. Data: part,
  177. }
  178. }
  179. return json.Marshal(wrappedParts)
  180. }
  181. func unmarshallParts(data []byte) ([]ContentPart, error) {
  182. temp := []json.RawMessage{}
  183. if err := json.Unmarshal(data, &temp); err != nil {
  184. return nil, err
  185. }
  186. parts := make([]ContentPart, 0)
  187. for _, rawPart := range temp {
  188. var wrapper struct {
  189. Type partType `json:"type"`
  190. Data json.RawMessage `json:"data"`
  191. }
  192. if err := json.Unmarshal(rawPart, &wrapper); err != nil {
  193. return nil, err
  194. }
  195. switch wrapper.Type {
  196. case reasoningType:
  197. part := ReasoningContent{}
  198. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  199. return nil, err
  200. }
  201. parts = append(parts, part)
  202. case textType:
  203. part := TextContent{}
  204. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  205. return nil, err
  206. }
  207. parts = append(parts, part)
  208. case imageURLType:
  209. part := ImageURLContent{}
  210. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  211. return nil, err
  212. }
  213. case binaryType:
  214. part := BinaryContent{}
  215. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  216. return nil, err
  217. }
  218. parts = append(parts, part)
  219. case toolCallType:
  220. part := ToolCall{}
  221. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  222. return nil, err
  223. }
  224. parts = append(parts, part)
  225. case toolResultType:
  226. part := ToolResult{}
  227. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  228. return nil, err
  229. }
  230. parts = append(parts, part)
  231. case finishType:
  232. part := Finish{}
  233. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  234. return nil, err
  235. }
  236. parts = append(parts, part)
  237. default:
  238. return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
  239. }
  240. }
  241. return parts, nil
  242. }