message.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. package message
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "time"
  8. "github.com/google/uuid"
  9. "github.com/opencode-ai/opencode/internal/db"
  10. "github.com/opencode-ai/opencode/internal/llm/models"
  11. "github.com/opencode-ai/opencode/internal/pubsub"
  12. )
  13. type CreateMessageParams struct {
  14. Role MessageRole
  15. Parts []ContentPart
  16. Model models.ModelID
  17. }
  18. type Service interface {
  19. pubsub.Suscriber[Message]
  20. Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
  21. Update(ctx context.Context, message Message) error
  22. Get(ctx context.Context, id string) (Message, error)
  23. List(ctx context.Context, sessionID string) ([]Message, error)
  24. ListAfter(ctx context.Context, sessionID string, timestamp int64) ([]Message, error)
  25. Delete(ctx context.Context, id string) error
  26. DeleteSessionMessages(ctx context.Context, sessionID string) error
  27. }
  28. type service struct {
  29. *pubsub.Broker[Message]
  30. q db.Querier
  31. }
  32. func NewService(q db.Querier) Service {
  33. return &service{
  34. Broker: pubsub.NewBroker[Message](),
  35. q: q,
  36. }
  37. }
  38. func (s *service) Delete(ctx context.Context, id string) error {
  39. message, err := s.Get(ctx, id)
  40. if err != nil {
  41. return err
  42. }
  43. err = s.q.DeleteMessage(ctx, message.ID)
  44. if err != nil {
  45. return err
  46. }
  47. s.Publish(pubsub.DeletedEvent, message)
  48. return nil
  49. }
  50. func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
  51. if params.Role != Assistant {
  52. params.Parts = append(params.Parts, Finish{
  53. Reason: "stop",
  54. })
  55. }
  56. partsJSON, err := marshallParts(params.Parts)
  57. if err != nil {
  58. return Message{}, err
  59. }
  60. dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
  61. ID: uuid.New().String(),
  62. SessionID: sessionID,
  63. Role: string(params.Role),
  64. Parts: string(partsJSON),
  65. Model: sql.NullString{String: string(params.Model), Valid: true},
  66. })
  67. if err != nil {
  68. return Message{}, err
  69. }
  70. message, err := s.fromDBItem(dbMessage)
  71. if err != nil {
  72. return Message{}, err
  73. }
  74. s.Publish(pubsub.CreatedEvent, message)
  75. return message, nil
  76. }
  77. func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
  78. messages, err := s.List(ctx, sessionID)
  79. if err != nil {
  80. return err
  81. }
  82. for _, message := range messages {
  83. if message.SessionID == sessionID {
  84. err = s.Delete(ctx, message.ID)
  85. if err != nil {
  86. return err
  87. }
  88. }
  89. }
  90. return nil
  91. }
  92. func (s *service) Update(ctx context.Context, message Message) error {
  93. parts, err := marshallParts(message.Parts)
  94. if err != nil {
  95. return err
  96. }
  97. finishedAt := sql.NullInt64{}
  98. if f := message.FinishPart(); f != nil {
  99. finishedAt.Int64 = f.Time
  100. finishedAt.Valid = true
  101. }
  102. err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
  103. ID: message.ID,
  104. Parts: string(parts),
  105. FinishedAt: finishedAt,
  106. })
  107. if err != nil {
  108. return err
  109. }
  110. message.UpdatedAt = time.Now().Unix()
  111. s.Publish(pubsub.UpdatedEvent, message)
  112. return nil
  113. }
  114. func (s *service) Get(ctx context.Context, id string) (Message, error) {
  115. dbMessage, err := s.q.GetMessage(ctx, id)
  116. if err != nil {
  117. return Message{}, err
  118. }
  119. return s.fromDBItem(dbMessage)
  120. }
  121. func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
  122. dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
  123. if err != nil {
  124. return nil, err
  125. }
  126. messages := make([]Message, len(dbMessages))
  127. for i, dbMessage := range dbMessages {
  128. messages[i], err = s.fromDBItem(dbMessage)
  129. if err != nil {
  130. return nil, err
  131. }
  132. }
  133. return messages, nil
  134. }
  135. func (s *service) ListAfter(ctx context.Context, sessionID string, timestamp int64) ([]Message, error) {
  136. dbMessages, err := s.q.ListMessagesBySessionAfter(ctx, db.ListMessagesBySessionAfterParams{
  137. SessionID: sessionID,
  138. CreatedAt: timestamp,
  139. })
  140. if err != nil {
  141. return nil, err
  142. }
  143. messages := make([]Message, len(dbMessages))
  144. for i, dbMessage := range dbMessages {
  145. messages[i], err = s.fromDBItem(dbMessage)
  146. if err != nil {
  147. return nil, err
  148. }
  149. }
  150. return messages, nil
  151. }
  152. func (s *service) fromDBItem(item db.Message) (Message, error) {
  153. parts, err := unmarshallParts([]byte(item.Parts))
  154. if err != nil {
  155. return Message{}, err
  156. }
  157. return Message{
  158. ID: item.ID,
  159. SessionID: item.SessionID,
  160. Role: MessageRole(item.Role),
  161. Parts: parts,
  162. Model: models.ModelID(item.Model.String),
  163. CreatedAt: item.CreatedAt,
  164. UpdatedAt: item.UpdatedAt,
  165. }, nil
  166. }
  // partType is the discriminator written under the "type" key of each
  // serialized content part. These strings end up in the stored parts JSON,
  // so changing a value would make existing rows undecodable.
  type partType string

  // Discriminator values, one per concrete ContentPart type handled by
  // marshallParts and unmarshallParts.
  const (
  	reasoningType  partType = "reasoning"
  	textType       partType = "text"
  	imageURLType   partType = "image_url"
  	binaryType     partType = "binary"
  	toolCallType   partType = "tool_call"
  	toolResultType partType = "tool_result"
  	finishType     partType = "finish"
  )
  177. type partWrapper struct {
  178. Type partType `json:"type"`
  179. Data ContentPart `json:"data"`
  180. }
  181. func marshallParts(parts []ContentPart) ([]byte, error) {
  182. wrappedParts := make([]partWrapper, len(parts))
  183. for i, part := range parts {
  184. var typ partType
  185. switch part.(type) {
  186. case ReasoningContent:
  187. typ = reasoningType
  188. case TextContent:
  189. typ = textType
  190. case ImageURLContent:
  191. typ = imageURLType
  192. case BinaryContent:
  193. typ = binaryType
  194. case ToolCall:
  195. typ = toolCallType
  196. case ToolResult:
  197. typ = toolResultType
  198. case Finish:
  199. typ = finishType
  200. default:
  201. return nil, fmt.Errorf("unknown part type: %T", part)
  202. }
  203. wrappedParts[i] = partWrapper{
  204. Type: typ,
  205. Data: part,
  206. }
  207. }
  208. return json.Marshal(wrappedParts)
  209. }
  210. func unmarshallParts(data []byte) ([]ContentPart, error) {
  211. temp := []json.RawMessage{}
  212. if err := json.Unmarshal(data, &temp); err != nil {
  213. return nil, err
  214. }
  215. parts := make([]ContentPart, 0)
  216. for _, rawPart := range temp {
  217. var wrapper struct {
  218. Type partType `json:"type"`
  219. Data json.RawMessage `json:"data"`
  220. }
  221. if err := json.Unmarshal(rawPart, &wrapper); err != nil {
  222. return nil, err
  223. }
  224. switch wrapper.Type {
  225. case reasoningType:
  226. part := ReasoningContent{}
  227. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  228. return nil, err
  229. }
  230. parts = append(parts, part)
  231. case textType:
  232. part := TextContent{}
  233. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  234. return nil, err
  235. }
  236. parts = append(parts, part)
  237. case imageURLType:
  238. part := ImageURLContent{}
  239. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  240. return nil, err
  241. }
  242. case binaryType:
  243. part := BinaryContent{}
  244. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  245. return nil, err
  246. }
  247. parts = append(parts, part)
  248. case toolCallType:
  249. part := ToolCall{}
  250. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  251. return nil, err
  252. }
  253. parts = append(parts, part)
  254. case toolResultType:
  255. part := ToolResult{}
  256. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  257. return nil, err
  258. }
  259. parts = append(parts, part)
  260. case finishType:
  261. part := Finish{}
  262. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  263. return nil, err
  264. }
  265. parts = append(parts, part)
  266. default:
  267. return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
  268. }
  269. }
  270. return parts, nil
  271. }