message.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. package message
  2. import (
  3. "context"
  4. "database/sql"
  5. "encoding/json"
  6. "fmt"
  7. "time"
  8. "github.com/google/uuid"
  9. "github.com/opencode-ai/opencode/internal/db"
  10. "github.com/opencode-ai/opencode/internal/llm/models"
  11. "github.com/opencode-ai/opencode/internal/pubsub"
  12. )
  13. type CreateMessageParams struct {
  14. Role MessageRole
  15. Parts []ContentPart
  16. Model models.ModelID
  17. }
  18. type Service interface {
  19. pubsub.Suscriber[Message]
  20. Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
  21. Update(ctx context.Context, message Message) error
  22. Get(ctx context.Context, id string) (Message, error)
  23. List(ctx context.Context, sessionID string) ([]Message, error)
  24. Delete(ctx context.Context, id string) error
  25. DeleteSessionMessages(ctx context.Context, sessionID string) error
  26. }
  27. type service struct {
  28. *pubsub.Broker[Message]
  29. q db.Querier
  30. }
  31. func NewService(q db.Querier) Service {
  32. return &service{
  33. Broker: pubsub.NewBroker[Message](),
  34. q: q,
  35. }
  36. }
  37. func (s *service) Delete(ctx context.Context, id string) error {
  38. message, err := s.Get(ctx, id)
  39. if err != nil {
  40. return err
  41. }
  42. err = s.q.DeleteMessage(ctx, message.ID)
  43. if err != nil {
  44. return err
  45. }
  46. s.Publish(pubsub.DeletedEvent, message)
  47. return nil
  48. }
  49. func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
  50. if params.Role != Assistant {
  51. params.Parts = append(params.Parts, Finish{
  52. Reason: "stop",
  53. })
  54. }
  55. partsJSON, err := marshallParts(params.Parts)
  56. if err != nil {
  57. return Message{}, err
  58. }
  59. dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
  60. ID: uuid.New().String(),
  61. SessionID: sessionID,
  62. Role: string(params.Role),
  63. Parts: string(partsJSON),
  64. Model: sql.NullString{String: string(params.Model), Valid: true},
  65. })
  66. if err != nil {
  67. return Message{}, err
  68. }
  69. message, err := s.fromDBItem(dbMessage)
  70. if err != nil {
  71. return Message{}, err
  72. }
  73. s.Publish(pubsub.CreatedEvent, message)
  74. return message, nil
  75. }
  76. func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
  77. messages, err := s.List(ctx, sessionID)
  78. if err != nil {
  79. return err
  80. }
  81. for _, message := range messages {
  82. if message.SessionID == sessionID {
  83. err = s.Delete(ctx, message.ID)
  84. if err != nil {
  85. return err
  86. }
  87. }
  88. }
  89. return nil
  90. }
  91. func (s *service) Update(ctx context.Context, message Message) error {
  92. parts, err := marshallParts(message.Parts)
  93. if err != nil {
  94. return err
  95. }
  96. finishedAt := sql.NullInt64{}
  97. if f := message.FinishPart(); f != nil {
  98. finishedAt.Int64 = f.Time
  99. finishedAt.Valid = true
  100. }
  101. err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
  102. ID: message.ID,
  103. Parts: string(parts),
  104. FinishedAt: finishedAt,
  105. })
  106. if err != nil {
  107. return err
  108. }
  109. message.UpdatedAt = time.Now().Unix()
  110. s.Publish(pubsub.UpdatedEvent, message)
  111. return nil
  112. }
  113. func (s *service) Get(ctx context.Context, id string) (Message, error) {
  114. dbMessage, err := s.q.GetMessage(ctx, id)
  115. if err != nil {
  116. return Message{}, err
  117. }
  118. return s.fromDBItem(dbMessage)
  119. }
  120. func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
  121. dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
  122. if err != nil {
  123. return nil, err
  124. }
  125. messages := make([]Message, len(dbMessages))
  126. for i, dbMessage := range dbMessages {
  127. messages[i], err = s.fromDBItem(dbMessage)
  128. if err != nil {
  129. return nil, err
  130. }
  131. }
  132. return messages, nil
  133. }
  134. func (s *service) fromDBItem(item db.Message) (Message, error) {
  135. parts, err := unmarshallParts([]byte(item.Parts))
  136. if err != nil {
  137. return Message{}, err
  138. }
  139. return Message{
  140. ID: item.ID,
  141. SessionID: item.SessionID,
  142. Role: MessageRole(item.Role),
  143. Parts: parts,
  144. Model: models.ModelID(item.Model.String),
  145. CreatedAt: item.CreatedAt,
  146. UpdatedAt: item.UpdatedAt,
  147. }, nil
  148. }
  // partType is the discriminator written into the "type" field of each
  // encoded part. These strings are persisted with the message parts, so
  // changing one makes previously stored parts undecodable.
  type partType string

  // Known part discriminators, one per concrete ContentPart type.
  const (
  	reasoningType  partType = "reasoning"   // ReasoningContent
  	textType       partType = "text"        // TextContent
  	imageURLType   partType = "image_url"   // ImageURLContent
  	binaryType     partType = "binary"      // BinaryContent
  	toolCallType   partType = "tool_call"   // ToolCall
  	toolResultType partType = "tool_result" // ToolResult
  	finishType     partType = "finish"      // Finish
  )
  159. type partWrapper struct {
  160. Type partType `json:"type"`
  161. Data ContentPart `json:"data"`
  162. }
  163. func marshallParts(parts []ContentPart) ([]byte, error) {
  164. wrappedParts := make([]partWrapper, len(parts))
  165. for i, part := range parts {
  166. var typ partType
  167. switch part.(type) {
  168. case ReasoningContent:
  169. typ = reasoningType
  170. case TextContent:
  171. typ = textType
  172. case ImageURLContent:
  173. typ = imageURLType
  174. case BinaryContent:
  175. typ = binaryType
  176. case ToolCall:
  177. typ = toolCallType
  178. case ToolResult:
  179. typ = toolResultType
  180. case Finish:
  181. typ = finishType
  182. default:
  183. return nil, fmt.Errorf("unknown part type: %T", part)
  184. }
  185. wrappedParts[i] = partWrapper{
  186. Type: typ,
  187. Data: part,
  188. }
  189. }
  190. return json.Marshal(wrappedParts)
  191. }
  192. func unmarshallParts(data []byte) ([]ContentPart, error) {
  193. temp := []json.RawMessage{}
  194. if err := json.Unmarshal(data, &temp); err != nil {
  195. return nil, err
  196. }
  197. parts := make([]ContentPart, 0)
  198. for _, rawPart := range temp {
  199. var wrapper struct {
  200. Type partType `json:"type"`
  201. Data json.RawMessage `json:"data"`
  202. }
  203. if err := json.Unmarshal(rawPart, &wrapper); err != nil {
  204. return nil, err
  205. }
  206. switch wrapper.Type {
  207. case reasoningType:
  208. part := ReasoningContent{}
  209. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  210. return nil, err
  211. }
  212. parts = append(parts, part)
  213. case textType:
  214. part := TextContent{}
  215. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  216. return nil, err
  217. }
  218. parts = append(parts, part)
  219. case imageURLType:
  220. part := ImageURLContent{}
  221. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  222. return nil, err
  223. }
  224. case binaryType:
  225. part := BinaryContent{}
  226. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  227. return nil, err
  228. }
  229. parts = append(parts, part)
  230. case toolCallType:
  231. part := ToolCall{}
  232. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  233. return nil, err
  234. }
  235. parts = append(parts, part)
  236. case toolResultType:
  237. part := ToolResult{}
  238. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  239. return nil, err
  240. }
  241. parts = append(parts, part)
  242. case finishType:
  243. part := Finish{}
  244. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  245. return nil, err
  246. }
  247. parts = append(parts, part)
  248. default:
  249. return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
  250. }
  251. }
  252. return parts, nil
  253. }