message.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. package message
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "time"
  8. "github.com/charmbracelet/crush/internal/db"
  9. "github.com/charmbracelet/crush/internal/pubsub"
  10. "github.com/google/uuid"
  11. )
  12. type CreateMessageParams struct {
  13. Role MessageRole
  14. Parts []ContentPart
  15. Model string
  16. Provider string
  17. IsSummaryMessage bool
  18. }
  19. type Service interface {
  20. pubsub.Subscriber[Message]
  21. Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
  22. Update(ctx context.Context, message Message) error
  23. Get(ctx context.Context, id string) (Message, error)
  24. List(ctx context.Context, sessionID string) ([]Message, error)
  25. ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
  26. ListAllUserMessages(ctx context.Context) ([]Message, error)
  27. Delete(ctx context.Context, id string) error
  28. DeleteSessionMessages(ctx context.Context, sessionID string) error
  29. }
  30. type service struct {
  31. *pubsub.Broker[Message]
  32. q db.Querier
  33. }
  34. func NewService(q db.Querier) Service {
  35. return &service{
  36. Broker: pubsub.NewBroker[Message](),
  37. q: q,
  38. }
  39. }
  40. func (s *service) Delete(ctx context.Context, id string) error {
  41. message, err := s.Get(ctx, id)
  42. if err != nil {
  43. return err
  44. }
  45. err = s.q.DeleteMessage(ctx, message.ID)
  46. if err != nil {
  47. return err
  48. }
  49. // Clone the message before publishing to avoid race conditions with
  50. // concurrent modifications to the Parts slice.
  51. s.Publish(pubsub.DeletedEvent, message.Clone())
  52. return nil
  53. }
  54. func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
  55. if params.Role != Assistant {
  56. params.Parts = append(params.Parts, Finish{
  57. Reason: "stop",
  58. })
  59. }
  60. partsJSON, err := marshalParts(params.Parts)
  61. if err != nil {
  62. return Message{}, err
  63. }
  64. isSummary := int64(0)
  65. if params.IsSummaryMessage {
  66. isSummary = 1
  67. }
  68. dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
  69. ID: uuid.New().String(),
  70. SessionID: sessionID,
  71. Role: string(params.Role),
  72. Parts: string(partsJSON),
  73. Model: sql.NullString{String: string(params.Model), Valid: true},
  74. Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
  75. IsSummaryMessage: isSummary,
  76. })
  77. if err != nil {
  78. return Message{}, err
  79. }
  80. message, err := s.fromDBItem(dbMessage)
  81. if err != nil {
  82. return Message{}, err
  83. }
  84. // Clone the message before publishing to avoid race conditions with
  85. // concurrent modifications to the Parts slice.
  86. s.Publish(pubsub.CreatedEvent, message.Clone())
  87. return message, nil
  88. }
  89. func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
  90. messages, err := s.List(ctx, sessionID)
  91. if err != nil {
  92. return err
  93. }
  94. for _, message := range messages {
  95. if message.SessionID == sessionID {
  96. err = s.Delete(ctx, message.ID)
  97. if err != nil {
  98. return err
  99. }
  100. }
  101. }
  102. return nil
  103. }
  104. func (s *service) Update(ctx context.Context, message Message) error {
  105. parts, err := marshalParts(message.Parts)
  106. if err != nil {
  107. return err
  108. }
  109. finishedAt := sql.NullInt64{}
  110. if f := message.FinishPart(); f != nil {
  111. finishedAt.Int64 = f.Time
  112. finishedAt.Valid = true
  113. }
  114. err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
  115. ID: message.ID,
  116. Parts: string(parts),
  117. FinishedAt: finishedAt,
  118. })
  119. if err != nil {
  120. return err
  121. }
  122. message.UpdatedAt = time.Now().Unix()
  123. // Clone the message before publishing to avoid race conditions with
  124. // concurrent modifications to the Parts slice.
  125. s.Publish(pubsub.UpdatedEvent, message.Clone())
  126. return nil
  127. }
  128. func (s *service) Get(ctx context.Context, id string) (Message, error) {
  129. dbMessage, err := s.q.GetMessage(ctx, id)
  130. if err != nil {
  131. return Message{}, err
  132. }
  133. return s.fromDBItem(dbMessage)
  134. }
  135. func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
  136. dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
  137. if err != nil {
  138. return nil, err
  139. }
  140. messages := make([]Message, len(dbMessages))
  141. for i, dbMessage := range dbMessages {
  142. messages[i], err = s.fromDBItem(dbMessage)
  143. if err != nil {
  144. return nil, err
  145. }
  146. }
  147. return messages, nil
  148. }
  149. func (s *service) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) {
  150. dbMessages, err := s.q.ListUserMessagesBySession(ctx, sessionID)
  151. if err != nil {
  152. return nil, err
  153. }
  154. messages := make([]Message, len(dbMessages))
  155. for i, dbMessage := range dbMessages {
  156. messages[i], err = s.fromDBItem(dbMessage)
  157. if err != nil {
  158. return nil, err
  159. }
  160. }
  161. return messages, nil
  162. }
  163. func (s *service) ListAllUserMessages(ctx context.Context) ([]Message, error) {
  164. dbMessages, err := s.q.ListAllUserMessages(ctx)
  165. if err != nil {
  166. return nil, err
  167. }
  168. messages := make([]Message, len(dbMessages))
  169. for i, dbMessage := range dbMessages {
  170. messages[i], err = s.fromDBItem(dbMessage)
  171. if err != nil {
  172. return nil, err
  173. }
  174. }
  175. return messages, nil
  176. }
  177. func (s *service) fromDBItem(item db.Message) (Message, error) {
  178. parts, err := unmarshalParts([]byte(item.Parts))
  179. if err != nil {
  180. return Message{}, err
  181. }
  182. return Message{
  183. ID: item.ID,
  184. SessionID: item.SessionID,
  185. Role: MessageRole(item.Role),
  186. Parts: parts,
  187. Model: item.Model.String,
  188. Provider: item.Provider.String,
  189. CreatedAt: item.CreatedAt,
  190. UpdatedAt: item.UpdatedAt,
  191. IsSummaryMessage: item.IsSummaryMessage != 0,
  192. }, nil
  193. }
  194. type partType string
  195. const (
  196. reasoningType partType = "reasoning"
  197. textType partType = "text"
  198. imageURLType partType = "image_url"
  199. binaryType partType = "binary"
  200. toolCallType partType = "tool_call"
  201. toolResultType partType = "tool_result"
  202. finishType partType = "finish"
  203. )
  204. type partWrapper struct {
  205. Type partType `json:"type"`
  206. Data ContentPart `json:"data"`
  207. }
  208. func marshalParts(parts []ContentPart) ([]byte, error) {
  209. wrappedParts := make([]partWrapper, len(parts))
  210. for i, part := range parts {
  211. var typ partType
  212. switch part.(type) {
  213. case ReasoningContent:
  214. typ = reasoningType
  215. case TextContent:
  216. typ = textType
  217. case ImageURLContent:
  218. typ = imageURLType
  219. case BinaryContent:
  220. typ = binaryType
  221. case ToolCall:
  222. typ = toolCallType
  223. case ToolResult:
  224. typ = toolResultType
  225. case Finish:
  226. typ = finishType
  227. default:
  228. return nil, fmt.Errorf("unknown part type: %T", part)
  229. }
  230. wrappedParts[i] = partWrapper{
  231. Type: typ,
  232. Data: part,
  233. }
  234. }
  235. return json.Marshal(wrappedParts)
  236. }
  237. func unmarshalParts(data []byte) ([]ContentPart, error) {
  238. temp := []json.RawMessage{}
  239. if err := json.Unmarshal(data, &temp); err != nil {
  240. return nil, err
  241. }
  242. parts := make([]ContentPart, 0)
  243. for _, rawPart := range temp {
  244. var wrapper struct {
  245. Type partType `json:"type"`
  246. Data json.RawMessage `json:"data"`
  247. }
  248. if err := json.Unmarshal(rawPart, &wrapper); err != nil {
  249. return nil, err
  250. }
  251. switch wrapper.Type {
  252. case reasoningType:
  253. part := ReasoningContent{}
  254. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  255. return nil, err
  256. }
  257. parts = append(parts, part)
  258. case textType:
  259. part := TextContent{}
  260. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  261. return nil, err
  262. }
  263. parts = append(parts, part)
  264. case imageURLType:
  265. part := ImageURLContent{}
  266. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  267. return nil, err
  268. }
  269. parts = append(parts, part)
  270. case binaryType:
  271. part := BinaryContent{}
  272. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  273. return nil, err
  274. }
  275. parts = append(parts, part)
  276. case toolCallType:
  277. part := ToolCall{}
  278. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  279. return nil, err
  280. }
  281. parts = append(parts, part)
  282. case toolResultType:
  283. part := ToolResult{}
  284. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  285. return nil, err
  286. }
  287. parts = append(parts, part)
  288. case finishType:
  289. part := Finish{}
  290. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  291. return nil, err
  292. }
  293. parts = append(parts, part)
  294. default:
  295. return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
  296. }
  297. }
  298. return parts, nil
  299. }