| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- package message
- import (
- "context"
- "database/sql"
- "encoding/json"
- "fmt"
- "time"
- "github.com/charmbracelet/crush/internal/db"
- "github.com/charmbracelet/crush/internal/pubsub"
- "github.com/google/uuid"
- )
- type CreateMessageParams struct {
- Role MessageRole
- Parts []ContentPart
- Model string
- Provider string
- IsSummaryMessage bool
- }
- type Service interface {
- pubsub.Subscriber[Message]
- Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error)
- Update(ctx context.Context, message Message) error
- Get(ctx context.Context, id string) (Message, error)
- List(ctx context.Context, sessionID string) ([]Message, error)
- ListUserMessages(ctx context.Context, sessionID string) ([]Message, error)
- ListAllUserMessages(ctx context.Context) ([]Message, error)
- Delete(ctx context.Context, id string) error
- DeleteSessionMessages(ctx context.Context, sessionID string) error
- }
- type service struct {
- *pubsub.Broker[Message]
- q db.Querier
- }
- func NewService(q db.Querier) Service {
- return &service{
- Broker: pubsub.NewBroker[Message](),
- q: q,
- }
- }
- func (s *service) Delete(ctx context.Context, id string) error {
- message, err := s.Get(ctx, id)
- if err != nil {
- return err
- }
- err = s.q.DeleteMessage(ctx, message.ID)
- if err != nil {
- return err
- }
- // Clone the message before publishing to avoid race conditions with
- // concurrent modifications to the Parts slice.
- s.Publish(pubsub.DeletedEvent, message.Clone())
- return nil
- }
- func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
- if params.Role != Assistant {
- params.Parts = append(params.Parts, Finish{
- Reason: "stop",
- })
- }
- partsJSON, err := marshalParts(params.Parts)
- if err != nil {
- return Message{}, err
- }
- isSummary := int64(0)
- if params.IsSummaryMessage {
- isSummary = 1
- }
- dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{
- ID: uuid.New().String(),
- SessionID: sessionID,
- Role: string(params.Role),
- Parts: string(partsJSON),
- Model: sql.NullString{String: string(params.Model), Valid: true},
- Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""},
- IsSummaryMessage: isSummary,
- })
- if err != nil {
- return Message{}, err
- }
- message, err := s.fromDBItem(dbMessage)
- if err != nil {
- return Message{}, err
- }
- // Clone the message before publishing to avoid race conditions with
- // concurrent modifications to the Parts slice.
- s.Publish(pubsub.CreatedEvent, message.Clone())
- return message, nil
- }
- func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) error {
- messages, err := s.List(ctx, sessionID)
- if err != nil {
- return err
- }
- for _, message := range messages {
- if message.SessionID == sessionID {
- err = s.Delete(ctx, message.ID)
- if err != nil {
- return err
- }
- }
- }
- return nil
- }
- func (s *service) Update(ctx context.Context, message Message) error {
- parts, err := marshalParts(message.Parts)
- if err != nil {
- return err
- }
- finishedAt := sql.NullInt64{}
- if f := message.FinishPart(); f != nil {
- finishedAt.Int64 = f.Time
- finishedAt.Valid = true
- }
- err = s.q.UpdateMessage(ctx, db.UpdateMessageParams{
- ID: message.ID,
- Parts: string(parts),
- FinishedAt: finishedAt,
- })
- if err != nil {
- return err
- }
- message.UpdatedAt = time.Now().Unix()
- // Clone the message before publishing to avoid race conditions with
- // concurrent modifications to the Parts slice.
- s.Publish(pubsub.UpdatedEvent, message.Clone())
- return nil
- }
- func (s *service) Get(ctx context.Context, id string) (Message, error) {
- dbMessage, err := s.q.GetMessage(ctx, id)
- if err != nil {
- return Message{}, err
- }
- return s.fromDBItem(dbMessage)
- }
- func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) {
- dbMessages, err := s.q.ListMessagesBySession(ctx, sessionID)
- if err != nil {
- return nil, err
- }
- messages := make([]Message, len(dbMessages))
- for i, dbMessage := range dbMessages {
- messages[i], err = s.fromDBItem(dbMessage)
- if err != nil {
- return nil, err
- }
- }
- return messages, nil
- }
- func (s *service) ListUserMessages(ctx context.Context, sessionID string) ([]Message, error) {
- dbMessages, err := s.q.ListUserMessagesBySession(ctx, sessionID)
- if err != nil {
- return nil, err
- }
- messages := make([]Message, len(dbMessages))
- for i, dbMessage := range dbMessages {
- messages[i], err = s.fromDBItem(dbMessage)
- if err != nil {
- return nil, err
- }
- }
- return messages, nil
- }
- func (s *service) ListAllUserMessages(ctx context.Context) ([]Message, error) {
- dbMessages, err := s.q.ListAllUserMessages(ctx)
- if err != nil {
- return nil, err
- }
- messages := make([]Message, len(dbMessages))
- for i, dbMessage := range dbMessages {
- messages[i], err = s.fromDBItem(dbMessage)
- if err != nil {
- return nil, err
- }
- }
- return messages, nil
- }
- func (s *service) fromDBItem(item db.Message) (Message, error) {
- parts, err := unmarshalParts([]byte(item.Parts))
- if err != nil {
- return Message{}, err
- }
- return Message{
- ID: item.ID,
- SessionID: item.SessionID,
- Role: MessageRole(item.Role),
- Parts: parts,
- Model: item.Model.String,
- Provider: item.Provider.String,
- CreatedAt: item.CreatedAt,
- UpdatedAt: item.UpdatedAt,
- IsSummaryMessage: item.IsSummaryMessage != 0,
- }, nil
- }
- type partType string
- const (
- reasoningType partType = "reasoning"
- textType partType = "text"
- imageURLType partType = "image_url"
- binaryType partType = "binary"
- toolCallType partType = "tool_call"
- toolResultType partType = "tool_result"
- finishType partType = "finish"
- )
- type partWrapper struct {
- Type partType `json:"type"`
- Data ContentPart `json:"data"`
- }
- func marshalParts(parts []ContentPart) ([]byte, error) {
- wrappedParts := make([]partWrapper, len(parts))
- for i, part := range parts {
- var typ partType
- switch part.(type) {
- case ReasoningContent:
- typ = reasoningType
- case TextContent:
- typ = textType
- case ImageURLContent:
- typ = imageURLType
- case BinaryContent:
- typ = binaryType
- case ToolCall:
- typ = toolCallType
- case ToolResult:
- typ = toolResultType
- case Finish:
- typ = finishType
- default:
- return nil, fmt.Errorf("unknown part type: %T", part)
- }
- wrappedParts[i] = partWrapper{
- Type: typ,
- Data: part,
- }
- }
- return json.Marshal(wrappedParts)
- }
- func unmarshalParts(data []byte) ([]ContentPart, error) {
- temp := []json.RawMessage{}
- if err := json.Unmarshal(data, &temp); err != nil {
- return nil, err
- }
- parts := make([]ContentPart, 0)
- for _, rawPart := range temp {
- var wrapper struct {
- Type partType `json:"type"`
- Data json.RawMessage `json:"data"`
- }
- if err := json.Unmarshal(rawPart, &wrapper); err != nil {
- return nil, err
- }
- switch wrapper.Type {
- case reasoningType:
- part := ReasoningContent{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- case textType:
- part := TextContent{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- case imageURLType:
- part := ImageURLContent{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- case binaryType:
- part := BinaryContent{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- case toolCallType:
- part := ToolCall{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- case toolResultType:
- part := ToolResult{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- case finishType:
- part := Finish{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- default:
- return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
- }
- }
- return parts, nil
- }
|