message.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. package message
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "time"
  8. "github.com/charmbracelet/crush/internal/db"
  9. "github.com/charmbracelet/crush/internal/pubsub"
  10. "github.com/google/uuid"
  11. )
  12. type CreateMessageParams struct {
  13. Role MessageRole
  14. Parts []ContentPart
  15. Model string
  16. Provider string
  17. IsSummaryMessage bool
  18. HookOutputs []HookOutput
  19. }
  20. type Service interface {
  21. pubsub.Suscriber[Message]
  22. Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
  23. Update(ctx context.Context, message Message) error
  24. Get(ctx context.Context, id string) (Message, error)
  25. List(ctx context.Context, sessionID string) ([]Message, error)
  26. Delete(ctx context.Context, id string) error
  27. DeleteSessionMessages(ctx context.Context, sessionID string) error
  28. }
  29. type service struct {
  30. *pubsub.Broker[Message]
  31. q db.Querier
  32. }
  33. func NewService(q db.Querier) Service {
  34. return &service{
  35. Broker: pubsub.NewBroker[Message](),
  36. q: q,
  37. }
  38. }
  39. func (s *service) Delete(ctx context.Context, id string) error {
  40. message, err := s.Get(ctx, id)
  41. if err != nil {
  42. return err
  43. }
  44. err = s.q.DeleteMessage(ctx, message.ID)
  45. if err != nil {
  46. return err
  47. }
  48. s.Publish(pubsub.DeletedEvent, message)
  49. return nil
  50. }
  51. func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
  52. if params.Role != Assistant {
  53. params.Parts = append(params.Parts, Finish{
  54. Reason: "stop",
  55. })
  56. }
  57. partsJSON, err := marshallParts(params.Parts)
  58. if err != nil {
  59. return Message{}, err
  60. }
  61. isSummary := int64(0)
  62. if params.IsSummaryMessage {
  63. isSummary = 1
  64. }
  65. hookOutputsJSON, err := json.Marshal(params.HookOutputs)
  66. if err != nil {
  67. return Message{}, err
  68. }
  69. dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
  70. ID: uuid.New().String(),
  71. SessionID: sessionID,
  72. Role: string(params.Role),
  73. Parts: string(partsJSON),
  74. Model: sql.NullString{String: string(params.Model), Valid: true},
  75. Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
  76. IsSummaryMessage: isSummary,
  77. HookOutputs: string(hookOutputsJSON),
  78. })
  79. if err != nil {
  80. return Message{}, err
  81. }
  82. message, err := s.fromDBItem(dbMessage)
  83. if err != nil {
  84. return Message{}, err
  85. }
  86. s.Publish(pubsub.CreatedEvent, message)
  87. return message, nil
  88. }
  89. func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
  90. messages, err := s.List(ctx, sessionID)
  91. if err != nil {
  92. return err
  93. }
  94. for _, message := range messages {
  95. if message.SessionID == sessionID {
  96. err = s.Delete(ctx, message.ID)
  97. if err != nil {
  98. return err
  99. }
  100. }
  101. }
  102. return nil
  103. }
  104. func (s *service) Update(ctx context.Context, message Message) error {
  105. parts, err := marshallParts(message.Parts)
  106. if err != nil {
  107. return err
  108. }
  109. finishedAt := sql.NullInt64{}
  110. if f := message.FinishPart(); f != nil {
  111. finishedAt.Int64 = f.Time
  112. finishedAt.Valid = true
  113. }
  114. hookOutputsJSON, err := json.Marshal(message.HookOutputs)
  115. if err != nil {
  116. return err
  117. }
  118. err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
  119. ID: message.ID,
  120. Parts: string(parts),
  121. FinishedAt: finishedAt,
  122. HookOutputs: string(hookOutputsJSON),
  123. })
  124. if err != nil {
  125. return err
  126. }
  127. message.UpdatedAt = time.Now().Unix()
  128. s.Publish(pubsub.UpdatedEvent, message)
  129. return nil
  130. }
  131. func (s *service) Get(ctx context.Context, id string) (Message, error) {
  132. dbMessage, err := s.q.GetMessage(ctx, id)
  133. if err != nil {
  134. return Message{}, err
  135. }
  136. return s.fromDBItem(dbMessage)
  137. }
  138. func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
  139. dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
  140. if err != nil {
  141. return nil, err
  142. }
  143. messages := make([]Message, len(dbMessages))
  144. for i, dbMessage := range dbMessages {
  145. messages[i], err = s.fromDBItem(dbMessage)
  146. if err != nil {
  147. return nil, err
  148. }
  149. }
  150. return messages, nil
  151. }
  152. func (s *service) fromDBItem(item db.Message) (Message, error) {
  153. parts, err := unmarshallParts([]byte(item.Parts))
  154. if err != nil {
  155. return Message{}, err
  156. }
  157. var hookOutputs []HookOutput
  158. if item.HookOutputs != "" {
  159. if err := json.Unmarshal([]byte(item.HookOutputs), &hookOutputs); err != nil {
  160. return Message{}, err
  161. }
  162. }
  163. return Message{
  164. ID: item.ID,
  165. SessionID: item.SessionID,
  166. Role: MessageRole(item.Role),
  167. Parts: parts,
  168. Model: item.Model.String,
  169. Provider: item.Provider.String,
  170. CreatedAt: item.CreatedAt,
  171. UpdatedAt: item.UpdatedAt,
  172. IsSummaryMessage: item.IsSummaryMessage != 0,
  173. HookOutputs: hookOutputs,
  174. }, nil
  175. }
  176. type partType string
  177. const (
  178. reasoningType partType = "reasoning"
  179. textType partType = "text"
  180. imageURLType partType = "image_url"
  181. binaryType partType = "binary"
  182. toolCallType partType = "tool_call"
  183. toolResultType partType = "tool_result"
  184. finishType partType = "finish"
  185. )
  186. type partWrapper struct {
  187. Type partType `json:"type"`
  188. Data ContentPart `json:"data"`
  189. }
  190. func marshallParts(parts []ContentPart) ([]byte, error) {
  191. wrappedParts := make([]partWrapper, len(parts))
  192. for i, part := range parts {
  193. var typ partType
  194. switch part.(type) {
  195. case ReasoningContent:
  196. typ = reasoningType
  197. case TextContent:
  198. typ = textType
  199. case ImageURLContent:
  200. typ = imageURLType
  201. case BinaryContent:
  202. typ = binaryType
  203. case ToolCall:
  204. typ = toolCallType
  205. case ToolResult:
  206. typ = toolResultType
  207. case Finish:
  208. typ = finishType
  209. default:
  210. return nil, fmt.Errorf("unknown part type: %T", part)
  211. }
  212. wrappedParts[i] = partWrapper{
  213. Type: typ,
  214. Data: part,
  215. }
  216. }
  217. return json.Marshal(wrappedParts)
  218. }
  219. func unmarshallParts(data []byte) ([]ContentPart, error) {
  220. temp := []json.RawMessage{}
  221. if err := json.Unmarshal(data, &temp); err != nil {
  222. return nil, err
  223. }
  224. parts := make([]ContentPart, 0)
  225. for _, rawPart := range temp {
  226. var wrapper struct {
  227. Type partType `json:"type"`
  228. Data json.RawMessage `json:"data"`
  229. }
  230. if err := json.Unmarshal(rawPart, &wrapper); err != nil {
  231. return nil, err
  232. }
  233. switch wrapper.Type {
  234. case reasoningType:
  235. part := ReasoningContent{}
  236. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  237. return nil, err
  238. }
  239. parts = append(parts, part)
  240. case textType:
  241. part := TextContent{}
  242. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  243. return nil, err
  244. }
  245. parts = append(parts, part)
  246. case imageURLType:
  247. part := ImageURLContent{}
  248. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  249. return nil, err
  250. }
  251. parts = append(parts, part)
  252. case binaryType:
  253. part := BinaryContent{}
  254. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  255. return nil, err
  256. }
  257. parts = append(parts, part)
  258. case toolCallType:
  259. part := ToolCall{}
  260. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  261. return nil, err
  262. }
  263. parts = append(parts, part)
  264. case toolResultType:
  265. part := ToolResult{}
  266. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  267. return nil, err
  268. }
  269. parts = append(parts, part)
  270. case finishType:
  271. part := Finish{}
  272. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  273. return nil, err
  274. }
  275. parts = append(parts, part)
  276. default:
  277. return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
  278. }
  279. }
  280. return parts, nil
  281. }