message.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. package message
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "time"
  8. "github.com/charmbracelet/crush/internal/db"
  9. "github.com/charmbracelet/crush/internal/pubsub"
  10. "github.com/google/uuid"
  11. )
  12. type CreateMessageParams struct {
  13. Role MessageRole
  14. Parts []ContentPart
  15. Model string
  16. Provider string
  17. IsSummaryMessage bool
  18. }
  19. type Service interface {
  20. pubsub.Subscriber[Message]
  21. Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
  22. Update(ctx context.Context, message Message) error
  23. Get(ctx context.Context, id string) (Message, error)
  24. List(ctx context.Context, sessionID string) ([]Message, error)
  25. ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
  26. ListAllUserMessages(ctx context.Context) ([]Message, error)
  27. Copy(ctx context.Context, sessionID string, message Message) (Message, error)
  28. Delete(ctx context.Context, id string) error
  29. DeleteSessionMessages(ctx context.Context, sessionID string) error
  30. DeleteMessagesFrom(ctx context.Context, sessionID, messageID string) error
  31. }
  32. type service struct {
  33. *pubsub.Broker[Message]
  34. q db.Querier
  35. }
  36. func NewService(q db.Querier) Service {
  37. return &service{
  38. Broker: pubsub.NewBroker[Message](),
  39. q: q,
  40. }
  41. }
  42. func (s *service) Delete(ctx context.Context, id string) error {
  43. message, err := s.Get(ctx, id)
  44. if err != nil {
  45. return err
  46. }
  47. err = s.q.DeleteMessage(ctx, message.ID)
  48. if err != nil {
  49. return err
  50. }
  51. // Clone the message before publishing to avoid race conditions with
  52. // concurrent modifications to the Parts slice.
  53. s.Publish(pubsub.DeletedEvent, message.Clone())
  54. return nil
  55. }
  56. func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
  57. if params.Role != Assistant {
  58. params.Parts = append(params.Parts, Finish{
  59. Reason: "stop",
  60. })
  61. }
  62. partsJSON, err := marshalParts(params.Parts)
  63. if err != nil {
  64. return Message{}, err
  65. }
  66. isSummary := int64(0)
  67. if params.IsSummaryMessage {
  68. isSummary = 1
  69. }
  70. dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
  71. ID: uuid.New().String(),
  72. SessionID: sessionID,
  73. Role: string(params.Role),
  74. Parts: string(partsJSON),
  75. Model: sql.NullString{String: string(params.Model), Valid: true},
  76. Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
  77. IsSummaryMessage: isSummary,
  78. })
  79. if err != nil {
  80. return Message{}, err
  81. }
  82. message, err := s.fromDBItem(dbMessage)
  83. if err != nil {
  84. return Message{}, err
  85. }
  86. // Clone the message before publishing to avoid race conditions with
  87. // concurrent modifications to the Parts slice.
  88. s.Publish(pubsub.CreatedEvent, message.Clone())
  89. return message, nil
  90. }
  91. func (s *service) Copy(ctx context.Context, sessionID string, message Message) (Message, error) {
  92. params := CreateMessageParams{
  93. Role: message.Role,
  94. Parts: message.Parts,
  95. Model: message.Model,
  96. Provider: message.Provider,
  97. IsSummaryMessage: message.IsSummaryMessage,
  98. }
  99. return s.Create(ctx, sessionID, params)
  100. }
  101. func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
  102. messages, err := s.List(ctx, sessionID)
  103. if err != nil {
  104. return err
  105. }
  106. for _, message := range messages {
  107. if message.SessionID == sessionID {
  108. err = s.Delete(ctx, message.ID)
  109. if err != nil {
  110. return err
  111. }
  112. }
  113. }
  114. return nil
  115. }
  116. func (s *service) DeleteMessagesFrom(ctx context.Context, sessionID, messageID string) error {
  117. allMessages, err := s.List(ctx, sessionID)
  118. if err != nil {
  119. return err
  120. }
  121. targetIndex := -1
  122. for i, msg := range allMessages {
  123. if msg.ID == messageID {
  124. targetIndex = i
  125. break
  126. }
  127. }
  128. if targetIndex == -1 {
  129. return fmt.Errorf("message not found: %s", messageID)
  130. }
  131. for i := len(allMessages) - 1; i >= targetIndex; i-- {
  132. if err := s.Delete(ctx, allMessages[i].ID); err != nil {
  133. return err
  134. }
  135. }
  136. return nil
  137. }
  138. func (s *service) Update(ctx context.Context, message Message) error {
  139. parts, err := marshalParts(message.Parts)
  140. if err != nil {
  141. return err
  142. }
  143. finishedAt := sql.NullInt64{}
  144. if f := message.FinishPart(); f != nil {
  145. finishedAt.Int64 = f.Time
  146. finishedAt.Valid = true
  147. }
  148. err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
  149. ID: message.ID,
  150. Parts: string(parts),
  151. FinishedAt: finishedAt,
  152. })
  153. if err != nil {
  154. return err
  155. }
  156. message.UpdatedAt = time.Now().Unix()
  157. // Clone the message before publishing to avoid race conditions with
  158. // concurrent modifications to the Parts slice.
  159. s.Publish(pubsub.UpdatedEvent, message.Clone())
  160. return nil
  161. }
  162. func (s *service) Get(ctx context.Context, id string) (Message, error) {
  163. dbMessage, err := s.q.GetMessage(ctx, id)
  164. if err != nil {
  165. return Message{}, err
  166. }
  167. return s.fromDBItem(dbMessage)
  168. }
  169. func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
  170. dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
  171. if err != nil {
  172. return nil, err
  173. }
  174. messages := make([]Message, len(dbMessages))
  175. for i, dbMessage := range dbMessages {
  176. messages[i], err = s.fromDBItem(dbMessage)
  177. if err != nil {
  178. return nil, err
  179. }
  180. }
  181. return messages, nil
  182. }
  183. func (s *service) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) {
  184. dbMessages, err := s.q.ListUserMessagesBySession(ctx, sessionID)
  185. if err != nil {
  186. return nil, err
  187. }
  188. messages := make([]Message, len(dbMessages))
  189. for i, dbMessage := range dbMessages {
  190. messages[i], err = s.fromDBItem(dbMessage)
  191. if err != nil {
  192. return nil, err
  193. }
  194. }
  195. return messages, nil
  196. }
  197. func (s *service) ListAllUserMessages(ctx context.Context) ([]Message, error) {
  198. dbMessages, err := s.q.ListAllUserMessages(ctx)
  199. if err != nil {
  200. return nil, err
  201. }
  202. messages := make([]Message, len(dbMessages))
  203. for i, dbMessage := range dbMessages {
  204. messages[i], err = s.fromDBItem(dbMessage)
  205. if err != nil {
  206. return nil, err
  207. }
  208. }
  209. return messages, nil
  210. }
  211. func (s *service) fromDBItem(item db.Message) (Message, error) {
  212. parts, err := unmarshalParts([]byte(item.Parts))
  213. if err != nil {
  214. return Message{}, err
  215. }
  216. return Message{
  217. ID: item.ID,
  218. SessionID: item.SessionID,
  219. Role: MessageRole(item.Role),
  220. Parts: parts,
  221. Model: item.Model.String,
  222. Provider: item.Provider.String,
  223. CreatedAt: item.CreatedAt,
  224. UpdatedAt: item.UpdatedAt,
  225. IsSummaryMessage: item.IsSummaryMessage != 0,
  226. }, nil
  227. }
  // partType tags each serialized ContentPart so the decoder knows which
// concrete type to rebuild. The values end up in the stored parts JSON,
// so changing one would break decoding of existing rows.
type partType string

// Known part tags, sorted alphabetically.
const (
	binaryType     partType = "binary"
	finishType     partType = "finish"
	imageURLType   partType = "image_url"
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
)
  238. type partWrapper struct {
  239. Type partType `json:"type"`
  240. Data ContentPart `json:"data"`
  241. }
  242. func marshalParts(parts []ContentPart) ([]byte, error) {
  243. wrappedParts := make([]partWrapper, len(parts))
  244. for i, part := range parts {
  245. var typ partType
  246. switch part.(type) {
  247. case ReasoningContent:
  248. typ = reasoningType
  249. case TextContent:
  250. typ = textType
  251. case ImageURLContent:
  252. typ = imageURLType
  253. case BinaryContent:
  254. typ = binaryType
  255. case ToolCall:
  256. typ = toolCallType
  257. case ToolResult:
  258. typ = toolResultType
  259. case Finish:
  260. typ = finishType
  261. default:
  262. return nil, fmt.Errorf("unknown part type: %T", part)
  263. }
  264. wrappedParts[i] = partWrapper{
  265. Type: typ,
  266. Data: part,
  267. }
  268. }
  269. return json.Marshal(wrappedParts)
  270. }
  271. func unmarshalParts(data []byte) ([]ContentPart, error) {
  272. temp := []json.RawMessage{}
  273. if err := json.Unmarshal(data, &temp); err != nil {
  274. return nil, err
  275. }
  276. parts := make([]ContentPart, 0)
  277. for _, rawPart := range temp {
  278. var wrapper struct {
  279. Type partType `json:"type"`
  280. Data json.RawMessage `json:"data"`
  281. }
  282. if err := json.Unmarshal(rawPart, &wrapper); err != nil {
  283. return nil, err
  284. }
  285. switch wrapper.Type {
  286. case reasoningType:
  287. part := ReasoningContent{}
  288. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  289. return nil, err
  290. }
  291. parts = append(parts, part)
  292. case textType:
  293. part := TextContent{}
  294. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  295. return nil, err
  296. }
  297. parts = append(parts, part)
  298. case imageURLType:
  299. part := ImageURLContent{}
  300. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  301. return nil, err
  302. }
  303. parts = append(parts, part)
  304. case binaryType:
  305. part := BinaryContent{}
  306. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  307. return nil, err
  308. }
  309. parts = append(parts, part)
  310. case toolCallType:
  311. part := ToolCall{}
  312. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  313. return nil, err
  314. }
  315. parts = append(parts, part)
  316. case toolResultType:
  317. part := ToolResult{}
  318. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  319. return nil, err
  320. }
  321. parts = append(parts, part)
  322. case finishType:
  323. part := Finish{}
  324. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  325. return nil, err
  326. }
  327. parts = append(parts, part)
  328. default:
  329. return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
  330. }
  331. }
  332. return parts, nil
  333. }