message.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. package message
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "time"
  8. "github.com/charmbracelet/crush/internal/db"
  9. "github.com/charmbracelet/crush/internal/pubsub"
  10. "github.com/google/uuid"
  11. )
  12. type CreateMessageParams struct {
  13. Role MessageRole
  14. Parts []ContentPart
  15. Model string
  16. Provider string
  17. IsSummaryMessage bool
  18. }
  19. type Service interface {
  20. pubsub.Subscriber[Message]
  21. Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
  22. Update(ctx context.Context, message Message) error
  23. Get(ctx context.Context, id string) (Message, error)
  24. List(ctx context.Context, sessionID string) ([]Message, error)
  25. Delete(ctx context.Context, id string) error
  26. DeleteSessionMessages(ctx context.Context, sessionID string) error
  27. }
  28. type service struct {
  29. *pubsub.Broker[Message]
  30. q db.Querier
  31. }
  32. func NewService(q db.Querier) Service {
  33. return &service{
  34. Broker: pubsub.NewBroker[Message](),
  35. q: q,
  36. }
  37. }
  38. func (s *service) Delete(ctx context.Context, id string) error {
  39. message, err := s.Get(ctx, id)
  40. if err != nil {
  41. return err
  42. }
  43. err = s.q.DeleteMessage(ctx, message.ID)
  44. if err != nil {
  45. return err
  46. }
  47. // Clone the message before publishing to avoid race conditions with
  48. // concurrent modifications to the Parts slice.
  49. s.Publish(ctx, pubsub.DeletedEvent, message.Clone())
  50. return nil
  51. }
  52. func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
  53. if params.Role != Assistant {
  54. params.Parts = append(params.Parts, Finish{
  55. Reason: "stop",
  56. })
  57. }
  58. partsJSON, err := marshallParts(params.Parts)
  59. if err != nil {
  60. return Message{}, err
  61. }
  62. isSummary := int64(0)
  63. if params.IsSummaryMessage {
  64. isSummary = 1
  65. }
  66. dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
  67. ID: uuid.New().String(),
  68. SessionID: sessionID,
  69. Role: string(params.Role),
  70. Parts: string(partsJSON),
  71. Model: sql.NullString{String: string(params.Model), Valid: true},
  72. Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
  73. IsSummaryMessage: isSummary,
  74. })
  75. if err != nil {
  76. return Message{}, err
  77. }
  78. message, err := s.fromDBItem(dbMessage)
  79. if err != nil {
  80. return Message{}, err
  81. }
  82. // Clone the message before publishing to avoid race conditions with
  83. // concurrent modifications to the Parts slice.
  84. s.Publish(ctx, pubsub.CreatedEvent, message.Clone())
  85. return message, nil
  86. }
  87. func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
  88. messages, err := s.List(ctx, sessionID)
  89. if err != nil {
  90. return err
  91. }
  92. for _, message := range messages {
  93. if message.SessionID == sessionID {
  94. err = s.Delete(ctx, message.ID)
  95. if err != nil {
  96. return err
  97. }
  98. }
  99. }
  100. return nil
  101. }
  102. func (s *service) Update(ctx context.Context, message Message) error {
  103. parts, err := marshallParts(message.Parts)
  104. if err != nil {
  105. return err
  106. }
  107. finishedAt := sql.NullInt64{}
  108. if f := message.FinishPart(); f != nil {
  109. finishedAt.Int64 = f.Time
  110. finishedAt.Valid = true
  111. }
  112. err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
  113. ID: message.ID,
  114. Parts: string(parts),
  115. FinishedAt: finishedAt,
  116. })
  117. if err != nil {
  118. return err
  119. }
  120. message.UpdatedAt = time.Now().Unix()
  121. // Clone the message before publishing to avoid race conditions with
  122. // concurrent modifications to the Parts slice.
  123. s.Publish(ctx, pubsub.UpdatedEvent, message.Clone())
  124. return nil
  125. }
  126. func (s *service) Get(ctx context.Context, id string) (Message, error) {
  127. dbMessage, err := s.q.GetMessage(ctx, id)
  128. if err != nil {
  129. return Message{}, err
  130. }
  131. return s.fromDBItem(dbMessage)
  132. }
  133. func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
  134. dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
  135. if err != nil {
  136. return nil, err
  137. }
  138. messages := make([]Message, len(dbMessages))
  139. for i, dbMessage := range dbMessages {
  140. messages[i], err = s.fromDBItem(dbMessage)
  141. if err != nil {
  142. return nil, err
  143. }
  144. }
  145. return messages, nil
  146. }
  147. func (s *service) fromDBItem(item db.Message) (Message, error) {
  148. parts, err := unmarshallParts([]byte(item.Parts))
  149. if err != nil {
  150. return Message{}, err
  151. }
  152. return Message{
  153. ID: item.ID,
  154. SessionID: item.SessionID,
  155. Role: MessageRole(item.Role),
  156. Parts: parts,
  157. Model: item.Model.String,
  158. Provider: item.Provider.String,
  159. CreatedAt: item.CreatedAt,
  160. UpdatedAt: item.UpdatedAt,
  161. IsSummaryMessage: item.IsSummaryMessage != 0,
  162. }, nil
  163. }
  164. type partType string
  165. const (
  166. reasoningType partType = "reasoning"
  167. textType partType = "text"
  168. imageURLType partType = "image_url"
  169. binaryType partType = "binary"
  170. toolCallType partType = "tool_call"
  171. toolResultType partType = "tool_result"
  172. finishType partType = "finish"
  173. )
  174. type partWrapper struct {
  175. Type partType `json:"type"`
  176. Data ContentPart `json:"data"`
  177. }
  178. func marshallParts(parts []ContentPart) ([]byte, error) {
  179. wrappedParts := make([]partWrapper, len(parts))
  180. for i, part := range parts {
  181. var typ partType
  182. switch part.(type) {
  183. case ReasoningContent:
  184. typ = reasoningType
  185. case TextContent:
  186. typ = textType
  187. case ImageURLContent:
  188. typ = imageURLType
  189. case BinaryContent:
  190. typ = binaryType
  191. case ToolCall:
  192. typ = toolCallType
  193. case ToolResult:
  194. typ = toolResultType
  195. case Finish:
  196. typ = finishType
  197. default:
  198. return nil, fmt.Errorf("unknown part type: %T", part)
  199. }
  200. wrappedParts[i] = partWrapper{
  201. Type: typ,
  202. Data: part,
  203. }
  204. }
  205. return json.Marshal(wrappedParts)
  206. }
  207. func unmarshallParts(data []byte) ([]ContentPart, error) {
  208. temp := []json.RawMessage{}
  209. if err := json.Unmarshal(data, &temp); err != nil {
  210. return nil, err
  211. }
  212. parts := make([]ContentPart, 0)
  213. for _, rawPart := range temp {
  214. var wrapper struct {
  215. Type partType `json:"type"`
  216. Data json.RawMessage `json:"data"`
  217. }
  218. if err := json.Unmarshal(rawPart, &wrapper); err != nil {
  219. return nil, err
  220. }
  221. switch wrapper.Type {
  222. case reasoningType:
  223. part := ReasoningContent{}
  224. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  225. return nil, err
  226. }
  227. parts = append(parts, part)
  228. case textType:
  229. part := TextContent{}
  230. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  231. return nil, err
  232. }
  233. parts = append(parts, part)
  234. case imageURLType:
  235. part := ImageURLContent{}
  236. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  237. return nil, err
  238. }
  239. parts = append(parts, part)
  240. case binaryType:
  241. part := BinaryContent{}
  242. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  243. return nil, err
  244. }
  245. parts = append(parts, part)
  246. case toolCallType:
  247. part := ToolCall{}
  248. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  249. return nil, err
  250. }
  251. parts = append(parts, part)
  252. case toolResultType:
  253. part := ToolResult{}
  254. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  255. return nil, err
  256. }
  257. parts = append(parts, part)
  258. case finishType:
  259. part := Finish{}
  260. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  261. return nil, err
  262. }
  263. parts = append(parts, part)
  264. default:
  265. return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
  266. }
  267. }
  268. return parts, nil
  269. }