message.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. package message
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "github.com/google/uuid"
  8. "github.com/kujtimiihoxha/termai/internal/db"
  9. "github.com/kujtimiihoxha/termai/internal/llm/models"
  10. "github.com/kujtimiihoxha/termai/internal/pubsub"
  11. )
  // CreateMessageParams holds the caller-supplied fields for Service.Create.
// The message ID is generated by Create and the session ID is passed to it
// separately.
type CreateMessageParams struct {
	// Role is the author of the message. Create appends a Finish part with
	// reason "stop" when Role is not Assistant.
	Role MessageRole
	// Parts is the message content; Create serializes it with marshallParts.
	Parts []ContentPart
	// Model is stored by Create in the message's nullable model column.
	Model models.ModelID
}
  // Service manages chat messages. The implementation returned by NewService
// persists them through a db.Querier and publishes change events to
// subscribers via the embedded pubsub.Suscriber.
type Service interface {
	pubsub.Suscriber[Message]
	// Create stores a new message in the given session and publishes a
	// created event carrying the stored message.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update persists the message's parts and finish time and publishes an
	// updated event.
	Update(ctx context.Context, message Message) error
	// Get returns the message with the given ID.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages that belong to the given session.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// Delete removes a single message and publishes a deleted event.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages deletes every message in the given session.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
  // service is the Service implementation returned by NewService.
// The embedded Broker provides Publish (and presumably the Suscriber methods
// required by Service — confirm in the pubsub package); q runs the queries.
type service struct {
	*pubsub.Broker[Message]
	q db.Querier
}
  30. func NewService(q db.Querier) Service {
  31. return &service{
  32. Broker: pubsub.NewBroker[Message](),
  33. q: q,
  34. }
  35. }
  36. func (s *service) Delete(ctx context.Context, id string) error {
  37. message, err := s.Get(ctx, id)
  38. if err != nil {
  39. return err
  40. }
  41. err = s.q.DeleteMessage(ctx, message.ID)
  42. if err != nil {
  43. return err
  44. }
  45. s.Publish(pubsub.DeletedEvent, message)
  46. return nil
  47. }
  48. func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
  49. if params.Role != Assistant {
  50. params.Parts = append(params.Parts, Finish{
  51. Reason: "stop",
  52. })
  53. }
  54. partsJSON, err := marshallParts(params.Parts)
  55. if err != nil {
  56. return Message{}, err
  57. }
  58. dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
  59. ID: uuid.New().String(),
  60. SessionID: sessionID,
  61. Role: string(params.Role),
  62. Parts: string(partsJSON),
  63. Model: sql.NullString{String: string(params.Model), Valid: true},
  64. })
  65. if err != nil {
  66. return Message{}, err
  67. }
  68. message, err := s.fromDBItem(dbMessage)
  69. if err != nil {
  70. return Message{}, err
  71. }
  72. s.Publish(pubsub.CreatedEvent, message)
  73. return message, nil
  74. }
  75. func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
  76. messages, err := s.List(ctx, sessionID)
  77. if err != nil {
  78. return err
  79. }
  80. for _, message := range messages {
  81. if message.SessionID == sessionID {
  82. err = s.Delete(ctx, message.ID)
  83. if err != nil {
  84. return err
  85. }
  86. }
  87. }
  88. return nil
  89. }
  90. func (s *service) Update(ctx context.Context, message Message) error {
  91. parts, err := marshallParts(message.Parts)
  92. if err != nil {
  93. return err
  94. }
  95. finishedAt := sql.NullInt64{}
  96. if f := message.FinishPart(); f != nil {
  97. finishedAt.Int64 = f.Time
  98. finishedAt.Valid = true
  99. }
  100. err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
  101. ID: message.ID,
  102. Parts: string(parts),
  103. FinishedAt: finishedAt,
  104. })
  105. if err != nil {
  106. return err
  107. }
  108. s.Publish(pubsub.UpdatedEvent, message)
  109. return nil
  110. }
  111. func (s *service) Get(ctx context.Context, id string) (Message, error) {
  112. dbMessage, err := s.q.GetMessage(ctx, id)
  113. if err != nil {
  114. return Message{}, err
  115. }
  116. return s.fromDBItem(dbMessage)
  117. }
  118. func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
  119. dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
  120. if err != nil {
  121. return nil, err
  122. }
  123. messages := make([]Message, len(dbMessages))
  124. for i, dbMessage := range dbMessages {
  125. messages[i], err = s.fromDBItem(dbMessage)
  126. if err != nil {
  127. return nil, err
  128. }
  129. }
  130. return messages, nil
  131. }
  132. func (s *service) fromDBItem(item db.Message) (Message, error) {
  133. parts, err := unmarshallParts([]byte(item.Parts))
  134. if err != nil {
  135. return Message{}, err
  136. }
  137. return Message{
  138. ID: item.ID,
  139. SessionID: item.SessionID,
  140. Role: MessageRole(item.Role),
  141. Parts: parts,
  142. Model: models.ModelID(item.Model.String),
  143. CreatedAt: item.CreatedAt,
  144. UpdatedAt: item.UpdatedAt,
  145. }, nil
  146. }
  // partType is the discriminator written to the "type" field of each
// serialized content part by marshallParts and read by unmarshallParts.
type partType string

// Discriminator values, one per concrete ContentPart type. These strings end
// up in the stored message parts, and unmarshallParts rejects any value it does
// not recognize, so renaming one makes previously stored rows undecodable.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
  // partWrapper is the JSON envelope marshallParts emits for each part:
// {"type": <partType>, "data": <part>}. Decoding reads the same shape into a
// struct whose data field is a json.RawMessage, so the concrete Go type can
// be chosen from "type" before "data" is decoded.
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
  161. func marshallParts(parts []ContentPart) ([]byte, error) {
  162. wrappedParts := make([]partWrapper, len(parts))
  163. for i, part := range parts {
  164. var typ partType
  165. switch part.(type) {
  166. case ReasoningContent:
  167. typ = reasoningType
  168. case TextContent:
  169. typ = textType
  170. case ImageURLContent:
  171. typ = imageURLType
  172. case BinaryContent:
  173. typ = binaryType
  174. case ToolCall:
  175. typ = toolCallType
  176. case ToolResult:
  177. typ = toolResultType
  178. case Finish:
  179. typ = finishType
  180. default:
  181. return nil, fmt.Errorf("unknown part type: %T", part)
  182. }
  183. wrappedParts[i] = partWrapper{
  184. Type: typ,
  185. Data: part,
  186. }
  187. }
  188. return json.Marshal(wrappedParts)
  189. }
  190. func unmarshallParts(data []byte) ([]ContentPart, error) {
  191. temp := []json.RawMessage{}
  192. if err := json.Unmarshal(data, &temp); err != nil {
  193. return nil, err
  194. }
  195. parts := make([]ContentPart, 0)
  196. for _, rawPart := range temp {
  197. var wrapper struct {
  198. Type partType `json:"type"`
  199. Data json.RawMessage `json:"data"`
  200. }
  201. if err := json.Unmarshal(rawPart, &wrapper); err != nil {
  202. return nil, err
  203. }
  204. switch wrapper.Type {
  205. case reasoningType:
  206. part := ReasoningContent{}
  207. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  208. return nil, err
  209. }
  210. parts = append(parts, part)
  211. case textType:
  212. part := TextContent{}
  213. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  214. return nil, err
  215. }
  216. parts = append(parts, part)
  217. case imageURLType:
  218. part := ImageURLContent{}
  219. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  220. return nil, err
  221. }
  222. case binaryType:
  223. part := BinaryContent{}
  224. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  225. return nil, err
  226. }
  227. parts = append(parts, part)
  228. case toolCallType:
  229. part := ToolCall{}
  230. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  231. return nil, err
  232. }
  233. parts = append(parts, part)
  234. case toolResultType:
  235. part := ToolResult{}
  236. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  237. return nil, err
  238. }
  239. parts = append(parts, part)
  240. case finishType:
  241. part := Finish{}
  242. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  243. return nil, err
  244. }
  245. parts = append(parts, part)
  246. default:
  247. return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
  248. }
  249. }
  250. return parts, nil
  251. }