message.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. package message
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/opencode-ai/opencode/internal/db"
	"github.com/opencode-ai/opencode/internal/llm/models"
	"github.com/opencode-ai/opencode/internal/pubsub"
)
// Event types passed to the message broker's Publish method by this
// package's service when a message is created, updated, or deleted.
const (
	// EventMessageCreated is published after a message has been stored.
	EventMessageCreated pubsub.EventType = "message_created"
	// EventMessageUpdated is published after a message has been updated.
	EventMessageUpdated pubsub.EventType = "message_updated"
	// EventMessageDeleted is published after a message has been deleted.
	EventMessageDeleted pubsub.EventType = "message_deleted"
)
// CreateMessageParams carries the caller-supplied fields for a new message.
type CreateMessageParams struct {
	// Role is stored as the message's role.
	Role MessageRole
	// Parts are the content parts; they are serialized to JSON for storage.
	Parts []ContentPart
	// Model is optional; an empty value is stored as SQL NULL.
	Model models.ModelID
}
// Service is the message persistence API. It also embeds
// pubsub.Subscriber[Message] so consumers can subscribe to message events.
// The notes below describe the implementation in this file.
type Service interface {
	pubsub.Subscriber[Message]
	// Create stores a new message in the given session and publishes
	// EventMessageCreated.
	Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
	// Update persists the parts and finish time of an existing message,
	// re-reads it, and publishes EventMessageUpdated.
	Update(ctx context.Context, message Message) (Message, error)
	// Get returns the message with the given ID.
	Get(ctx context.Context, id string) (Message, error)
	// List returns the messages stored for a session.
	List(ctx context.Context, sessionID string) ([]Message, error)
	// ListAfter lists a session's messages relative to timestampMillis,
	// which is truncated to whole seconds before being passed to the query.
	ListAfter(ctx context.Context, sessionID string, timestampMillis int64) ([]Message, error)
	// Delete removes a single message and publishes EventMessageDeleted.
	Delete(ctx context.Context, id string) error
	// DeleteSessionMessages removes every message of a session and publishes
	// EventMessageDeleted for each one that could be converted.
	DeleteSessionMessages(ctx context.Context, sessionID string) error
}
// service is the database-backed implementation of Service.
type service struct {
	// db is the generated query layer used for all message persistence.
	db *db.Queries
	// broker publishes created/updated/deleted message events.
	broker *pubsub.Broker[Message]
	// mu is taken exclusively by the mutating methods and shared by the
	// read-only ones (Get, List, ListAfter).
	mu sync.RWMutex
}
// globalMessageService is the package-level service instance. It is assigned
// by InitService and handed out by GetService; it is nil until then.
var globalMessageService *service
  42. func InitService(dbConn *sql.DB) error {
  43. if globalMessageService != nil {
  44. return fmt.Errorf("message service already initialized")
  45. }
  46. queries := db.New(dbConn)
  47. broker := pubsub.NewBroker[Message]()
  48. globalMessageService = &service{
  49. db: queries,
  50. broker: broker,
  51. }
  52. return nil
  53. }
  54. func GetService() Service {
  55. if globalMessageService == nil {
  56. panic("message service not initialized. Call message.InitService() first.")
  57. }
  58. return globalMessageService
  59. }
  60. func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
  61. s.mu.Lock()
  62. defer s.mu.Unlock()
  63. isFinished := false
  64. for _, p := range params.Parts {
  65. if _, ok := p.(Finish); ok {
  66. isFinished = true
  67. break
  68. }
  69. }
  70. if params.Role == User && !isFinished {
  71. params.Parts = append(params.Parts, Finish{Reason: FinishReasonEndTurn, Time: time.Now().UnixMilli()})
  72. }
  73. partsJSON, err := marshallParts(params.Parts)
  74. if err != nil {
  75. return Message{}, fmt.Errorf("failed to marshal message parts: %w", err)
  76. }
  77. dbMsgParams := db.CreateMessageParams{
  78. ID: uuid.New().String(),
  79. SessionID: sessionID,
  80. Role: string(params.Role),
  81. Parts: string(partsJSON),
  82. Model: sql.NullString{String: string(params.Model), Valid: params.Model != ""},
  83. }
  84. dbMessage, err := s.db.CreateMessage(ctx, dbMsgParams)
  85. if err != nil {
  86. return Message{}, fmt.Errorf("db.CreateMessage: %w", err)
  87. }
  88. message, err := s.fromDBItem(dbMessage)
  89. if err != nil {
  90. return Message{}, fmt.Errorf("failed to convert DB message: %w", err)
  91. }
  92. s.broker.Publish(EventMessageCreated, message)
  93. return message, nil
  94. }
  95. func (s *service) Update(ctx context.Context, message Message) (Message, error) {
  96. s.mu.Lock()
  97. defer s.mu.Unlock()
  98. if message.ID == "" {
  99. return Message{}, fmt.Errorf("cannot update message with empty ID")
  100. }
  101. partsJSON, err := marshallParts(message.Parts)
  102. if err != nil {
  103. return Message{}, fmt.Errorf("failed to marshal message parts for update: %w", err)
  104. }
  105. var dbFinishedAt sql.NullInt64
  106. finishPart := message.FinishPart()
  107. if finishPart != nil && finishPart.Time > 0 {
  108. dbFinishedAt = sql.NullInt64{
  109. Int64: finishPart.Time / 1000, // Convert Milliseconds from Go struct to Seconds for DB
  110. Valid: true,
  111. }
  112. }
  113. // UpdatedAt is handled by the DB trigger (strftime('%s', 'now'))
  114. err = s.db.UpdateMessage(ctx, db.UpdateMessageParams{
  115. ID: message.ID,
  116. Parts: string(partsJSON),
  117. FinishedAt: dbFinishedAt,
  118. })
  119. if err != nil {
  120. return Message{}, fmt.Errorf("db.UpdateMessage: %w", err)
  121. }
  122. dbUpdatedMessage, err := s.db.GetMessage(ctx, message.ID)
  123. if err != nil {
  124. return Message{}, fmt.Errorf("failed to fetch message after update: %w", err)
  125. }
  126. updatedMessage, err := s.fromDBItem(dbUpdatedMessage)
  127. if err != nil {
  128. return Message{}, fmt.Errorf("failed to convert updated DB message: %w", err)
  129. }
  130. s.broker.Publish(EventMessageUpdated, updatedMessage)
  131. return updatedMessage, nil
  132. }
  133. func (s *service) Get(ctx context.Context, id string) (Message, error) {
  134. s.mu.RLock()
  135. defer s.mu.RUnlock()
  136. dbMessage, err := s.db.GetMessage(ctx, id)
  137. if err != nil {
  138. if err == sql.ErrNoRows {
  139. return Message{}, fmt.Errorf("message with ID '%s' not found", id)
  140. }
  141. return Message{}, fmt.Errorf("db.GetMessage: %w", err)
  142. }
  143. return s.fromDBItem(dbMessage)
  144. }
  145. func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
  146. s.mu.RLock()
  147. defer s.mu.RUnlock()
  148. dbMessages, err := s.db.ListMessagesBySession(ctx, sessionID)
  149. if err != nil {
  150. return nil, fmt.Errorf("db.ListMessagesBySession: %w", err)
  151. }
  152. messages := make([]Message, len(dbMessages))
  153. for i, dbMsg := range dbMessages {
  154. msg, convErr := s.fromDBItem(dbMsg)
  155. if convErr != nil {
  156. return nil, fmt.Errorf("failed to convert DB message at index %d: %w", i, convErr)
  157. }
  158. messages[i] = msg
  159. }
  160. return messages, nil
  161. }
  162. func (s *service) ListAfter(ctx context.Context, sessionID string, timestampMillis int64) ([]Message, error) {
  163. s.mu.RLock()
  164. defer s.mu.RUnlock()
  165. timestampSeconds := timestampMillis / 1000 // Convert to seconds for DB query
  166. dbMessages, err := s.db.ListMessagesBySessionAfter(ctx, db.ListMessagesBySessionAfterParams{
  167. SessionID: sessionID,
  168. CreatedAt: timestampSeconds,
  169. })
  170. if err != nil {
  171. return nil, fmt.Errorf("db.ListMessagesBySessionAfter: %w", err)
  172. }
  173. messages := make([]Message, len(dbMessages))
  174. for i, dbMsg := range dbMessages {
  175. msg, convErr := s.fromDBItem(dbMsg)
  176. if convErr != nil {
  177. return nil, fmt.Errorf("failed to convert DB message at index %d (ListAfter): %w", i, convErr)
  178. }
  179. messages[i] = msg
  180. }
  181. return messages, nil
  182. }
  183. func (s *service) Delete(ctx context.Context, id string) error {
  184. s.mu.Lock()
  185. messageToPublish, err := s.getServiceForPublish(ctx, id)
  186. s.mu.Unlock()
  187. if err != nil {
  188. // If error was due to not found, it's not a critical failure for deletion intent
  189. if strings.Contains(err.Error(), "not found") {
  190. return nil // Or return the error if strictness is required
  191. }
  192. return err
  193. }
  194. s.mu.Lock()
  195. defer s.mu.Unlock()
  196. err = s.db.DeleteMessage(ctx, id)
  197. if err != nil {
  198. return fmt.Errorf("db.DeleteMessage: %w", err)
  199. }
  200. if messageToPublish != nil {
  201. s.broker.Publish(EventMessageDeleted, *messageToPublish)
  202. }
  203. return nil
  204. }
  205. func (s *service) getServiceForPublish(ctx context.Context, id string) (*Message, error) {
  206. dbMsg, err := s.db.GetMessage(ctx, id)
  207. if err != nil {
  208. return nil, err
  209. }
  210. msg, convErr := s.fromDBItem(dbMsg)
  211. if convErr != nil {
  212. return nil, fmt.Errorf("failed to convert DB message for publishing: %w", convErr)
  213. }
  214. return &msg, nil
  215. }
  216. func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
  217. s.mu.Lock()
  218. defer s.mu.Unlock()
  219. messagesToDelete, err := s.db.ListMessagesBySession(ctx, sessionID)
  220. if err != nil {
  221. return fmt.Errorf("failed to list messages for deletion: %w", err)
  222. }
  223. err = s.db.DeleteSessionMessages(ctx, sessionID)
  224. if err != nil {
  225. return fmt.Errorf("db.DeleteSessionMessages: %w", err)
  226. }
  227. for _, dbMsg := range messagesToDelete {
  228. msg, convErr := s.fromDBItem(dbMsg)
  229. if convErr == nil {
  230. s.broker.Publish(EventMessageDeleted, msg)
  231. } else {
  232. slog.Error("Failed to convert DB message for delete event publishing", "id", dbMsg.ID, "error", convErr)
  233. }
  234. }
  235. return nil
  236. }
  237. func (s *service) Subscribe(ctx context.Context) <-chan pubsub.Event[Message] {
  238. return s.broker.Subscribe(ctx)
  239. }
  240. func (s *service) fromDBItem(item db.Message) (Message, error) {
  241. parts, err := unmarshallParts([]byte(item.Parts))
  242. if err != nil {
  243. return Message{}, fmt.Errorf("unmarshallParts for message ID %s: %w. Raw parts: %s", item.ID, err, item.Parts)
  244. }
  245. // DB stores created_at, updated_at, finished_at as Unix seconds.
  246. // Go struct Message stores them as Unix milliseconds.
  247. createdAtMillis := item.CreatedAt * 1000
  248. updatedAtMillis := item.UpdatedAt * 1000
  249. msg := Message{
  250. ID: item.ID,
  251. SessionID: item.SessionID,
  252. Role: MessageRole(item.Role),
  253. Parts: parts,
  254. Model: models.ModelID(item.Model.String),
  255. CreatedAt: createdAtMillis,
  256. UpdatedAt: updatedAtMillis,
  257. }
  258. // Ensure Finish part in msg.Parts reflects the item.FinishedAt state
  259. // if item.FinishedAt is the source of truth for the "overall message finished time".
  260. // The `unmarshallParts` should already create a Finish part if it's in the JSON.
  261. // This logic reconciles the DB column with the JSON parts.
  262. var existingFinishPart *Finish
  263. var finishPartIndex = -1
  264. for i, p := range msg.Parts {
  265. if fp, ok := p.(Finish); ok {
  266. existingFinishPart = &fp
  267. finishPartIndex = i
  268. break
  269. }
  270. }
  271. if item.FinishedAt.Valid && item.FinishedAt.Int64 > 0 {
  272. dbFinishTimeMillis := item.FinishedAt.Int64 * 1000
  273. if existingFinishPart != nil {
  274. // If a Finish part exists from JSON, update its time if DB's time is different.
  275. // This assumes DB `finished_at` is the ultimate source of truth for when the message truly finished.
  276. if existingFinishPart.Time != dbFinishTimeMillis {
  277. slog.Debug("Aligning Finish part time with DB finished_at", "message_id", msg.ID, "json_finish_time", existingFinishPart.Time, "db_finish_time", dbFinishTimeMillis)
  278. existingFinishPart.Time = dbFinishTimeMillis
  279. msg.Parts[finishPartIndex] = *existingFinishPart
  280. }
  281. } else {
  282. // If no Finish part in JSON but DB says it's finished, add one.
  283. // We might not know the original FinishReason here, so use a sensible default or leave it to be set by Update.
  284. // This scenario should be less common if `Update` always ensures a Finish part for finished messages.
  285. slog.Debug("Synthesizing Finish part from DB finished_at", "message_id", msg.ID)
  286. msg.Parts = append(msg.Parts, Finish{Reason: FinishReasonEndTurn, Time: dbFinishTimeMillis})
  287. }
  288. }
  289. return msg, nil
  290. }
  291. func Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
  292. return GetService().Create(ctx, sessionID, params)
  293. }
  294. func Update(ctx context.Context, message Message) (Message, error) {
  295. return GetService().Update(ctx, message)
  296. }
  297. func Get(ctx context.Context, id string) (Message, error) {
  298. return GetService().Get(ctx, id)
  299. }
  300. func List(ctx context.Context, sessionID string) ([]Message, error) {
  301. return GetService().List(ctx, sessionID)
  302. }
  303. func ListAfter(ctx context.Context, sessionID string, timestampMillis int64) ([]Message, error) {
  304. return GetService().ListAfter(ctx, sessionID, timestampMillis)
  305. }
  306. func Delete(ctx context.Context, id string) error {
  307. return GetService().Delete(ctx, id)
  308. }
  309. func DeleteSessionMessages(ctx context.Context, sessionID string) error {
  310. return GetService().DeleteSessionMessages(ctx, sessionID)
  311. }
  312. func SubscribeToEvents(ctx context.Context) <-chan pubsub.Event[Message] {
  313. return GetService().Subscribe(ctx)
  314. }
// partType is the discriminator written to the "type" field of each
// serialized content part.
type partType string

// Discriminator values, one per content part kind handled by marshallParts
// and unmarshallParts.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)
// partWrapper is the JSON envelope of a single stored content part: a type
// tag plus the part's own JSON, kept raw so unmarshallParts can decode it
// into the concrete type selected by Type.
type partWrapper struct {
	Type partType        `json:"type"`
	Data json.RawMessage `json:"data"`
}
  329. func marshallParts(parts []ContentPart) ([]byte, error) {
  330. wrappedParts := make([]json.RawMessage, len(parts))
  331. for i, part := range parts {
  332. var typ partType
  333. var dataBytes []byte
  334. var err error
  335. switch p := part.(type) {
  336. case ReasoningContent:
  337. typ = reasoningType
  338. dataBytes, err = json.Marshal(p)
  339. case TextContent:
  340. typ = textType
  341. dataBytes, err = json.Marshal(p)
  342. case *TextContent:
  343. typ = textType
  344. dataBytes, err = json.Marshal(p)
  345. case ImageURLContent:
  346. typ = imageURLType
  347. dataBytes, err = json.Marshal(p)
  348. case BinaryContent:
  349. typ = binaryType
  350. dataBytes, err = json.Marshal(p)
  351. case ToolCall:
  352. typ = toolCallType
  353. dataBytes, err = json.Marshal(p)
  354. case ToolResult:
  355. typ = toolResultType
  356. dataBytes, err = json.Marshal(p)
  357. case Finish:
  358. typ = finishType
  359. dataBytes, err = json.Marshal(p)
  360. default:
  361. return nil, fmt.Errorf("unknown part type for marshalling: %T", part)
  362. }
  363. if err != nil {
  364. return nil, fmt.Errorf("failed to marshal part data for type %s: %w", typ, err)
  365. }
  366. wrapper := struct {
  367. Type partType `json:"type"`
  368. Data json.RawMessage `json:"data"`
  369. }{Type: typ, Data: dataBytes}
  370. wrappedBytes, err := json.Marshal(wrapper)
  371. if err != nil {
  372. return nil, fmt.Errorf("failed to marshal part wrapper for type %s: %w", typ, err)
  373. }
  374. wrappedParts[i] = wrappedBytes
  375. }
  376. return json.Marshal(wrappedParts)
  377. }
  378. func unmarshallParts(data []byte) ([]ContentPart, error) {
  379. var rawMessages []json.RawMessage
  380. if err := json.Unmarshal(data, &rawMessages); err != nil {
  381. // Handle case where 'parts' might be a single object if not an array initially
  382. // This was a fallback, if your DB always stores an array, this might not be needed.
  383. var singleRawMessage json.RawMessage
  384. if errSingle := json.Unmarshal(data, &singleRawMessage); errSingle == nil {
  385. rawMessages = []json.RawMessage{singleRawMessage}
  386. } else {
  387. return nil, fmt.Errorf("failed to unmarshal parts data as array: %w. Data: %s", err, string(data))
  388. }
  389. }
  390. parts := make([]ContentPart, 0, len(rawMessages))
  391. for _, rawPart := range rawMessages {
  392. var wrapper partWrapper
  393. if err := json.Unmarshal(rawPart, &wrapper); err != nil {
  394. // Fallback for old format where parts might be just TextContent string
  395. var text string
  396. if errText := json.Unmarshal(rawPart, &text); errText == nil {
  397. parts = append(parts, TextContent{Text: text})
  398. continue
  399. }
  400. return nil, fmt.Errorf("failed to unmarshal part wrapper: %w. Raw part: %s", err, string(rawPart))
  401. }
  402. switch wrapper.Type {
  403. case reasoningType:
  404. var p ReasoningContent
  405. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  406. return nil, fmt.Errorf("unmarshal ReasoningContent: %w. Data: %s", err, string(wrapper.Data))
  407. }
  408. parts = append(parts, p)
  409. case textType:
  410. var p TextContent
  411. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  412. return nil, fmt.Errorf("unmarshal TextContent: %w. Data: %s", err, string(wrapper.Data))
  413. }
  414. parts = append(parts, p)
  415. case imageURLType:
  416. var p ImageURLContent
  417. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  418. return nil, fmt.Errorf("unmarshal ImageURLContent: %w. Data: %s", err, string(wrapper.Data))
  419. }
  420. parts = append(parts, p)
  421. case binaryType:
  422. var p BinaryContent
  423. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  424. return nil, fmt.Errorf("unmarshal BinaryContent: %w. Data: %s", err, string(wrapper.Data))
  425. }
  426. parts = append(parts, p)
  427. case toolCallType:
  428. var p ToolCall
  429. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  430. return nil, fmt.Errorf("unmarshal ToolCall: %w. Data: %s", err, string(wrapper.Data))
  431. }
  432. parts = append(parts, p)
  433. case toolResultType:
  434. var p ToolResult
  435. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  436. return nil, fmt.Errorf("unmarshal ToolResult: %w. Data: %s", err, string(wrapper.Data))
  437. }
  438. parts = append(parts, p)
  439. case finishType:
  440. var p Finish
  441. if err := json.Unmarshal(wrapper.Data, &p); err != nil {
  442. return nil, fmt.Errorf("unmarshal Finish: %w. Data: %s", err, string(wrapper.Data))
  443. }
  444. parts = append(parts, p)
  445. default:
  446. slog.Warn("Unknown part type during unmarshalling, attempting to parse as TextContent", "type", wrapper.Type, "data", string(wrapper.Data))
  447. // Fallback: if type is unknown or empty, try to parse data as TextContent directly
  448. var p TextContent
  449. if err := json.Unmarshal(wrapper.Data, &p); err == nil {
  450. parts = append(parts, p)
  451. } else {
  452. // If that also fails, log it but continue if possible, or return error
  453. slog.Error("Failed to unmarshal unknown part type and fallback to TextContent failed", "type", wrapper.Type, "data", string(wrapper.Data), "error", err)
  454. // Depending on strictness, you might return an error here:
  455. // return nil, fmt.Errorf("unknown part type '%s' and failed fallback: %w", wrapper.Type, err)
  456. }
  457. }
  458. }
  459. return parts, nil
  460. }