message.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. package message
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/sst/opencode/internal/db"
	"github.com/sst/opencode/internal/llm/models"
	"github.com/sst/opencode/internal/pubsub"
)
  16. type Message struct {
  17. ID string
  18. Role MessageRole
  19. SessionID string
  20. Parts []ContentPart
  21. Model models.ModelID
  22. CreatedAt time.Time
  23. UpdatedAt time.Time
  24. }
  25. const (
  26. EventMessageCreated pubsub.EventType = "message_created"
  27. EventMessageUpdated pubsub.EventType = "message_updated"
  28. EventMessageDeleted pubsub.EventType = "message_deleted"
  29. )
  30. type CreateMessageParams struct {
  31. Role MessageRole
  32. Parts []ContentPart
  33. Model models.ModelID
  34. }
  35. type Service interface {
  36. pubsub.Subscriber[Message]
  37. Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
  38. Update(ctx context.Context, message Message) (Message, error)
  39. Get(ctx context.Context, id string) (Message, error)
  40. List(ctx context.Context, sessionID string) ([]Message, error)
  41. ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error)
  42. Delete(ctx context.Context, id string) error
  43. DeleteSessionMessages(ctx context.Context, sessionID string) error
  44. }
  45. type service struct {
  46. db *db.Queries
  47. broker *pubsub.Broker[Message]
  48. mu sync.RWMutex
  49. }
  50. var globalMessageService *service
  51. func InitService(dbConn *sql.DB) error {
  52. if globalMessageService != nil {
  53. return fmt.Errorf("message service already initialized")
  54. }
  55. queries := db.New(dbConn)
  56. broker := pubsub.NewBroker[Message]()
  57. globalMessageService = &service{
  58. db: queries,
  59. broker: broker,
  60. }
  61. return nil
  62. }
  63. func GetService() Service {
  64. if globalMessageService == nil {
  65. panic("message service not initialized. Call message.InitService() first.")
  66. }
  67. return globalMessageService
  68. }
  69. func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
  70. s.mu.Lock()
  71. defer s.mu.Unlock()
  72. isFinished := false
  73. for _, p := range params.Parts {
  74. if _, ok := p.(Finish); ok {
  75. isFinished = true
  76. break
  77. }
  78. }
  79. if params.Role == User && !isFinished {
  80. params.Parts = append(params.Parts, Finish{Reason: FinishReasonEndTurn, Time: time.Now()})
  81. }
  82. partsJSON, err := marshallParts(params.Parts)
  83. if err != nil {
  84. return Message{}, fmt.Errorf("failed to marshal message parts: %w", err)
  85. }
  86. dbMsgParams := db.CreateMessageParams{
  87. ID: uuid.New().String(),
  88. SessionID: sessionID,
  89. Role: string(params.Role),
  90. Parts: string(partsJSON),
  91. Model: sql.NullString{String: string(params.Model), Valid: params.Model != ""},
  92. }
  93. dbMessage, err := s.db.CreateMessage(ctx, dbMsgParams)
  94. if err != nil {
  95. return Message{}, fmt.Errorf("db.CreateMessage: %w", err)
  96. }
  97. message, err := s.fromDBItem(dbMessage)
  98. if err != nil {
  99. return Message{}, fmt.Errorf("failed to convert DB message: %w", err)
  100. }
  101. s.broker.Publish(EventMessageCreated, message)
  102. return message, nil
  103. }
  104. func (s *service) Update(ctx context.Context, message Message) (Message, error) {
  105. s.mu.Lock()
  106. defer s.mu.Unlock()
  107. if message.ID == "" {
  108. return Message{}, fmt.Errorf("cannot update message with empty ID")
  109. }
  110. partsJSON, err := marshallParts(message.Parts)
  111. if err != nil {
  112. return Message{}, fmt.Errorf("failed to marshal message parts for update: %w", err)
  113. }
  114. var dbFinishedAt sql.NullString
  115. finishPart := message.FinishPart()
  116. if finishPart != nil && !finishPart.Time.IsZero() {
  117. dbFinishedAt = sql.NullString{
  118. String: finishPart.Time.UTC().Format(time.RFC3339Nano),
  119. Valid: true,
  120. }
  121. }
  122. // UpdatedAt is handled by the DB trigger (strftime('%s', 'now'))
  123. err = s.db.UpdateMessage(ctx, db.UpdateMessageParams{
  124. ID: message.ID,
  125. Parts: string(partsJSON),
  126. FinishedAt: dbFinishedAt,
  127. })
  128. if err != nil {
  129. return Message{}, fmt.Errorf("db.UpdateMessage: %w", err)
  130. }
  131. dbUpdatedMessage, err := s.db.GetMessage(ctx, message.ID)
  132. if err != nil {
  133. return Message{}, fmt.Errorf("failed to fetch message after update: %w", err)
  134. }
  135. updatedMessage, err := s.fromDBItem(dbUpdatedMessage)
  136. if err != nil {
  137. return Message{}, fmt.Errorf("failed to convert updated DB message: %w", err)
  138. }
  139. s.broker.Publish(EventMessageUpdated, updatedMessage)
  140. return updatedMessage, nil
  141. }
  142. func (s *service) Get(ctx context.Context, id string) (Message, error) {
  143. s.mu.RLock()
  144. defer s.mu.RUnlock()
  145. dbMessage, err := s.db.GetMessage(ctx, id)
  146. if err != nil {
  147. if err == sql.ErrNoRows {
  148. return Message{}, fmt.Errorf("message with ID '%s' not found", id)
  149. }
  150. return Message{}, fmt.Errorf("db.GetMessage: %w", err)
  151. }
  152. return s.fromDBItem(dbMessage)
  153. }
  154. func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
  155. s.mu.RLock()
  156. defer s.mu.RUnlock()
  157. dbMessages, err := s.db.ListMessagesBySession(ctx, sessionID)
  158. if err != nil {
  159. return nil, fmt.Errorf("db.ListMessagesBySession: %w", err)
  160. }
  161. messages := make([]Message, len(dbMessages))
  162. for i, dbMsg := range dbMessages {
  163. msg, convErr := s.fromDBItem(dbMsg)
  164. if convErr != nil {
  165. return nil, fmt.Errorf("failed to convert DB message at index %d: %w", i, convErr)
  166. }
  167. messages[i] = msg
  168. }
  169. return messages, nil
  170. }
  171. func (s *service) ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) {
  172. s.mu.RLock()
  173. defer s.mu.RUnlock()
  174. dbMessages, err := s.db.ListMessagesBySessionAfter(ctx, db.ListMessagesBySessionAfterParams{
  175. SessionID: sessionID,
  176. CreatedAt: timestamp.Format(time.RFC3339Nano),
  177. })
  178. if err != nil {
  179. return nil, fmt.Errorf("db.ListMessagesBySessionAfter: %w", err)
  180. }
  181. messages := make([]Message, len(dbMessages))
  182. for i, dbMsg := range dbMessages {
  183. msg, convErr := s.fromDBItem(dbMsg)
  184. if convErr != nil {
  185. return nil, fmt.Errorf("failed to convert DB message at index %d (ListAfter): %w", i, convErr)
  186. }
  187. messages[i] = msg
  188. }
  189. return messages, nil
  190. }
  191. func (s *service) Delete(ctx context.Context, id string) error {
  192. s.mu.Lock()
  193. messageToPublish, err := s.getServiceForPublish(ctx, id)
  194. s.mu.Unlock()
  195. if err != nil {
  196. // If error was due to not found, it's not a critical failure for deletion intent
  197. if strings.Contains(err.Error(), "not found") {
  198. return nil // Or return the error if strictness is required
  199. }
  200. return err
  201. }
  202. s.mu.Lock()
  203. defer s.mu.Unlock()
  204. err = s.db.DeleteMessage(ctx, id)
  205. if err != nil {
  206. return fmt.Errorf("db.DeleteMessage: %w", err)
  207. }
  208. if messageToPublish != nil {
  209. s.broker.Publish(EventMessageDeleted, *messageToPublish)
  210. }
  211. return nil
  212. }
  213. func (s *service) getServiceForPublish(ctx context.Context, id string) (*Message, error) {
  214. dbMsg, err := s.db.GetMessage(ctx, id)
  215. if err != nil {
  216. return nil, err
  217. }
  218. msg, convErr := s.fromDBItem(dbMsg)
  219. if convErr != nil {
  220. return nil, fmt.Errorf("failed to convert DB message for publishing: %w", convErr)
  221. }
  222. return &msg, nil
  223. }
  224. func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
  225. s.mu.Lock()
  226. defer s.mu.Unlock()
  227. messagesToDelete, err := s.db.ListMessagesBySession(ctx, sessionID)
  228. if err != nil {
  229. return fmt.Errorf("failed to list messages for deletion: %w", err)
  230. }
  231. err = s.db.DeleteSessionMessages(ctx, sessionID)
  232. if err != nil {
  233. return fmt.Errorf("db.DeleteSessionMessages: %w", err)
  234. }
  235. for _, dbMsg := range messagesToDelete {
  236. msg, convErr := s.fromDBItem(dbMsg)
  237. if convErr == nil {
  238. s.broker.Publish(EventMessageDeleted, msg)
  239. } else {
  240. slog.Error("Failed to convert DB message for delete event publishing", "id", dbMsg.ID, "error", convErr)
  241. }
  242. }
  243. return nil
  244. }
  245. func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
  246. return s.broker.Subscribe(ctx)
  247. }
  248. func (s *service) fromDBItem(item db.Message) (Message, error) {
  249. parts, err := unmarshallParts([]byte(item.Parts))
  250. if err != nil {
  251. return Message{}, fmt.Errorf("unmarshallParts for message ID %s: %w. Raw parts: %s", item.ID, err, item.Parts)
  252. }
  253. // Parse timestamps from ISO strings
  254. createdAt, err := time.Parse(time.RFC3339Nano, item.CreatedAt)
  255. if err != nil {
  256. slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
  257. createdAt = time.Now() // Fallback
  258. }
  259. updatedAt, err := time.Parse(time.RFC3339Nano, item.UpdatedAt)
  260. if err != nil {
  261. slog.Error("Failed to parse created_at", "value", item.CreatedAt, "error", err)
  262. updatedAt = time.Now() // Fallback
  263. }
  264. msg := Message{
  265. ID: item.ID,
  266. SessionID: item.SessionID,
  267. Role: MessageRole(item.Role),
  268. Parts: parts,
  269. Model: models.ModelID(item.Model.String),
  270. CreatedAt: createdAt,
  271. UpdatedAt: updatedAt,
  272. }
  273. return msg, nil
  274. }
  275. func Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
  276. return GetService().Create(ctx, sessionID, params)
  277. }
  278. func Update(ctx context.Context, message Message) (Message, error) {
  279. return GetService().Update(ctx, message)
  280. }
  281. func Get(ctx context.Context, id string) (Message, error) {
  282. return GetService().Get(ctx, id)
  283. }
  284. func List(ctx context.Context, sessionID string) ([]Message, error) {
  285. return GetService().List(ctx, sessionID)
  286. }
  287. func ListAfter(ctx context.Context, sessionID string, timestamp time.Time) ([]Message, error) {
  288. return GetService().ListAfter(ctx, sessionID, timestamp)
  289. }
  290. func Delete(ctx context.Context, id string) error {
  291. return GetService().Delete(ctx, id)
  292. }
  293. func DeleteSessionMessages(ctx context.Context, sessionID string) error {
  294. return GetService().DeleteSessionMessages(ctx, sessionID)
  295. }
  296. func Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
  297. return GetService().Subscribe(ctx)
  298. }
  299. type partType string
  300. const (
  301. reasoningType partType = "reasoning"
  302. textType partType = "text"
  303. imageURLType partType = "image_url"
  304. binaryType partType = "binary"
  305. toolCallType partType = "tool_call"
  306. toolResultType partType = "tool_result"
  307. finishType partType = "finish"
  308. )
  309. type partWrapper struct {
  310. Type partType `json:"type"`
  311. Data json.RawMessage `json:"data"`
  312. }
  313. func marshallParts(parts []ContentPart) ([]byte, error) {
  314. wrappedParts := make([]json.RawMessage, len(parts))
  315. for i, part := range parts {
  316. var typ partType
  317. var dataBytes []byte
  318. var err error
  319. switch p := part.(type) {
  320. case ReasoningContent:
  321. typ = reasoningType
  322. dataBytes, err = json.Marshal(p)
  323. case TextContent:
  324. typ = textType
  325. dataBytes, err = json.Marshal(p)
  326. case *TextContent:
  327. typ = textType
  328. dataBytes, err = json.Marshal(p)
  329. case ImageURLContent:
  330. typ = imageURLType
  331. dataBytes, err = json.Marshal(p)
  332. case BinaryContent:
  333. typ = binaryType
  334. dataBytes, err = json.Marshal(p)
  335. case ToolCall:
  336. typ = toolCallType
  337. dataBytes, err = json.Marshal(p)
  338. case ToolResult:
  339. typ = toolResultType
  340. dataBytes, err = json.Marshal(p)
  341. case Finish:
  342. typ = finishType
  343. var dbFinish DBFinish
  344. dbFinish.Reason = p.Reason
  345. dbFinish.Time = p.Time.UnixMilli()
  346. dataBytes, err = json.Marshal(dbFinish)
  347. default:
  348. return nil, fmt.Errorf("unknown part type for marshalling: %T", part)
  349. }
  350. if err != nil {
  351. return nil, fmt.Errorf("failed to marshal part data for type %s: %w", typ, err)
  352. }
  353. wrapper := struct {
  354. Type partType `json:"type"`
  355. Data json.RawMessage `json:"data"`
  356. }{Type: typ, Data: dataBytes}
  357. wrappedBytes, err := json.Marshal(wrapper)
  358. if err != nil {
  359. return nil, fmt.Errorf("failed to marshal part wrapper for type %s: %w", typ, err)
  360. }
  361. wrappedParts[i] = wrappedBytes
  362. }
  363. return json.Marshal(wrappedParts)
  364. }
  365. func unmarshallParts(data []byte) ([]ContentPart, error) {
  366. var rawMessages []json.RawMessage
  367. if err := json.Unmarshal(data, &rawMessages); err != nil {
  368. return nil, fmt.Errorf("failed to unmarshal parts data as array: %w. Data: %s", err, string(data))
  369. }
  370. parts := make([]ContentPart, 0, len(rawMessages))
  371. for _, rawPart := range rawMessages {
  372. var wrapper partWrapper
  373. if err := json.Unmarshal(rawPart, &wrapper); err != nil {
  374. // Fallback for old format where parts might be just TextContent string
  375. var text string
  376. if errText := json.Unmarshal(rawPart, &text); errText == nil {
  377. parts = append(parts, TextContent{Text: text})
  378. continue
  379. }
  380. return nil, fmt.Errorf("failed to unmarshal part wrapper: %w. Raw part: %s", err, string(rawPart))
  381. }
  382. switch wrapper.Type {
  383. case reasoningType:
  384. var p ReasoningContent
  385. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  386. return nil, fmt.Errorf("unmarshal ReasoningContent: %w. Data: %s", err, string(wrapper.Data))
  387. }
  388. parts = append(parts, p)
  389. case textType:
  390. var p TextContent
  391. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  392. return nil, fmt.Errorf("unmarshal TextContent: %w. Data: %s", err, string(wrapper.Data))
  393. }
  394. parts = append(parts, p)
  395. case imageURLType:
  396. var p ImageURLContent
  397. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  398. return nil, fmt.Errorf("unmarshal ImageURLContent: %w. Data: %s", err, string(wrapper.Data))
  399. }
  400. parts = append(parts, p)
  401. case binaryType:
  402. var p BinaryContent
  403. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  404. return nil, fmt.Errorf("unmarshal BinaryContent: %w. Data: %s", err, string(wrapper.Data))
  405. }
  406. parts = append(parts, p)
  407. case toolCallType:
  408. var p ToolCall
  409. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  410. return nil, fmt.Errorf("unmarshal ToolCall: %w. Data: %s", err, string(wrapper.Data))
  411. }
  412. parts = append(parts, p)
  413. case toolResultType:
  414. var p ToolResult
  415. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  416. return nil, fmt.Errorf("unmarshal ToolResult: %w. Data: %s", err, string(wrapper.Data))
  417. }
  418. parts = append(parts, p)
  419. case finishType:
  420. var p DBFinish
  421. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  422. return nil, fmt.Errorf("unmarshal Finish: %w. Data: %s", err, string(wrapper.Data))
  423. }
  424. parts = append(parts, Finish{Reason: FinishReason(p.Reason), Time: time.UnixMilli(p.Time)})
  425. default:
  426. slog.Warn("Unknown part type during unmarshalling, attempting to parse as TextContent", "type", wrapper.Type, "data", string(wrapper.Data))
  427. // Fallback: if type is unknown or empty, try to parse data as TextContent directly
  428. var p TextContent
  429. if err := json.Unmarshal(wrapper.Data, &p); err == nil {
  430. parts = append(parts, p)
  431. } else {
  432. // If that also fails, log it but continue if possible, or return error
  433. slog.Error("Failed to unmarshal unknown part type and fallback to TextContent failed", "type", wrapper.Type, "data", string(wrapper.Data), "error", err)
  434. // Depending on strictness, you might return an error here:
  435. // return nil, fmt.Errorf("unknown part type '%s' and failed fallback: %w", wrapper.Type, err)
  436. }
  437. }
  438. }
  439. return parts, nil
  440. }