message.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. package proto
  2. import (
  3. "encoding/base64"
  4. "encoding/json"
  5. "fmt"
  6. "slices"
  7. "time"
  8. "charm.land/catwalk/pkg/catwalk"
  9. )
// CreateMessageParams represents parameters for creating a message.
//
// NOTE(review): no custom JSON methods are defined for this type in this
// file, so Parts is not wrapped in the type-tagged envelope that
// Message.MarshalJSON produces via MarshalParts — confirm whether this type
// is ever decoded from JSON.
type CreateMessageParams struct {
	// Role is the role of the message sender.
	Role MessageRole `json:"role"`
	// Parts holds the content parts of the message.
	Parts []ContentPart `json:"parts"`
	// Model is serialized under the "model" key.
	Model string `json:"model"`
	// Provider is serialized under the "provider" key and omitted when empty.
	Provider string `json:"provider,omitempty"`
}
// Message represents a message in the proto layer.
//
// Message has custom MarshalJSON/UnmarshalJSON methods that encode Parts as
// a list of type-tagged envelopes; the remaining fields use the struct tags
// below.
type Message struct {
	// ID is serialized under the "id" key.
	ID string `json:"id"`
	// Role is the role of the message sender.
	Role MessageRole `json:"role"`
	// SessionID is serialized under the "session_id" key.
	SessionID string `json:"session_id"`
	// Parts holds the ordered content parts (text, reasoning, tool calls,
	// tool results, finish marker, ...).
	Parts []ContentPart `json:"parts"`
	// Model is serialized under the "model" key.
	Model string `json:"model"`
	// Provider is serialized under the "provider" key.
	Provider string `json:"provider"`
	// CreatedAt is presumably a Unix timestamp — TODO confirm unit with the
	// code that populates it.
	CreatedAt int64 `json:"created_at"`
	// UpdatedAt is presumably a Unix timestamp — TODO confirm unit with the
	// code that populates it.
	UpdatedAt int64 `json:"updated_at"`
}
  28. // MessageRole represents the role of a message sender.
  29. type MessageRole string
  30. const (
  31. Assistant MessageRole = "assistant"
  32. User MessageRole = "user"
  33. System MessageRole = "system"
  34. Tool MessageRole = "tool"
  35. )
  36. // MarshalText implements the [encoding.TextMarshaler] interface.
  37. func (r MessageRole) MarshalText() ([]byte, error) {
  38. return []byte(r), nil
  39. }
  40. // UnmarshalText implements the [encoding.TextUnmarshaler] interface.
  41. func (r *MessageRole) UnmarshalText(data []byte) error {
  42. *r = MessageRole(data)
  43. return nil
  44. }
  45. // FinishReason represents why a message generation finished.
  46. type FinishReason string
  47. const (
  48. FinishReasonEndTurn FinishReason = "end_turn"
  49. FinishReasonMaxTokens FinishReason = "max_tokens"
  50. FinishReasonToolUse FinishReason = "tool_use"
  51. FinishReasonCanceled FinishReason = "canceled"
  52. FinishReasonError FinishReason = "error"
  53. FinishReasonPermissionDenied FinishReason = "permission_denied"
  54. FinishReasonUnknown FinishReason = "unknown"
  55. )
  56. // MarshalText implements the [encoding.TextMarshaler] interface.
  57. func (fr FinishReason) MarshalText() ([]byte, error) {
  58. return []byte(fr), nil
  59. }
  60. // UnmarshalText implements the [encoding.TextUnmarshaler] interface.
  61. func (fr *FinishReason) UnmarshalText(data []byte) error {
  62. *fr = FinishReason(data)
  63. return nil
  64. }
// ContentPart is a part of a message's content.
//
// The interface is sealed: its only method is unexported, so only types
// declared in this package can implement it.
type ContentPart interface {
	// isPart is a marker method with no behavior.
	isPart()
}
  69. // ReasoningContent represents the reasoning/thinking part of a message.
  70. type ReasoningContent struct {
  71. Thinking string `json:"thinking"`
  72. Signature string `json:"signature"`
  73. StartedAt int64 `json:"started_at,omitempty"`
  74. FinishedAt int64 `json:"finished_at,omitempty"`
  75. }
  76. // String returns the thinking content as a string.
  77. func (tc ReasoningContent) String() string {
  78. return tc.Thinking
  79. }
  80. func (ReasoningContent) isPart() {}
  81. // TextContent represents a text part of a message.
  82. type TextContent struct {
  83. Text string `json:"text"`
  84. }
  85. // String returns the text content as a string.
  86. func (tc TextContent) String() string {
  87. return tc.Text
  88. }
  89. func (TextContent) isPart() {}
  90. // ImageURLContent represents an image URL part of a message.
  91. type ImageURLContent struct {
  92. URL string `json:"url"`
  93. Detail string `json:"detail,omitempty"`
  94. }
  95. // String returns the image URL as a string.
  96. func (iuc ImageURLContent) String() string {
  97. return iuc.URL
  98. }
  99. func (ImageURLContent) isPart() {}
  100. // BinaryContent represents binary data in a message.
  101. type BinaryContent struct {
  102. Path string
  103. MIMEType string
  104. Data []byte
  105. }
  106. // String returns a base64-encoded string of the binary data.
  107. func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
  108. base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
  109. if p == catwalk.InferenceProviderOpenAI {
  110. return "data:" + bc.MIMEType + ";base64," + base64Encoded
  111. }
  112. return base64Encoded
  113. }
  114. func (BinaryContent) isPart() {}
// ToolCall represents a tool call in a message.
type ToolCall struct {
	// ID identifies the call; Message methods look tool calls up by it.
	ID string `json:"id"`
	// Name is serialized under the "name" key.
	Name string `json:"name"`
	// Input is the call input, accumulated as a string by
	// Message.AppendToolCallInput.
	Input string `json:"input"`
	// Type is serialized under the "type" key and omitted when empty.
	Type string `json:"type,omitempty"`
	// Finished is set by Message.FinishToolCall; omitted from JSON when false.
	Finished bool `json:"finished,omitempty"`
}

// isPart marks ToolCall as a ContentPart.
func (ToolCall) isPart() {}
// ToolResult represents the result of a tool call.
type ToolResult struct {
	// ToolCallID is serialized under the "tool_call_id" key; presumably it
	// matches ToolCall.ID — confirm with callers.
	ToolCallID string `json:"tool_call_id"`
	// Name is serialized under the "name" key.
	Name string `json:"name"`
	// Content is serialized under the "content" key.
	Content string `json:"content"`
	// Metadata is serialized under the "metadata" key.
	Metadata string `json:"metadata"`
	// IsError is serialized under the "is_error" key.
	IsError bool `json:"is_error"`
}

// isPart marks ToolResult as a ContentPart.
func (ToolResult) isPart() {}
// Finish represents the end of a message generation.
type Finish struct {
	// Reason records why generation finished.
	Reason FinishReason `json:"reason"`
	// Time is a Unix timestamp in seconds (Message.AddFinish sets it from
	// time.Now().Unix()).
	Time int64 `json:"time"`
	// Message is serialized under the "message" key and omitted when empty.
	Message string `json:"message,omitempty"`
	// Details is serialized under the "details" key and omitted when empty.
	Details string `json:"details,omitempty"`
}

// isPart marks Finish as a ContentPart.
func (Finish) isPart() {}
  141. // MarshalJSON implements the [json.Marshaler] interface.
  142. func (m Message) MarshalJSON() ([]byte, error) {
  143. parts, err := MarshalParts(m.Parts)
  144. if err != nil {
  145. return nil, err
  146. }
  147. type Alias Message
  148. return json.Marshal(&struct {
  149. Parts json.RawMessage `json:"parts"`
  150. *Alias
  151. }{
  152. Parts: json.RawMessage(parts),
  153. Alias: (*Alias)(&m),
  154. })
  155. }
  156. // UnmarshalJSON implements the [json.Unmarshaler] interface.
  157. func (m *Message) UnmarshalJSON(data []byte) error {
  158. type Alias Message
  159. aux := &struct {
  160. Parts json.RawMessage `json:"parts"`
  161. *Alias
  162. }{
  163. Alias: (*Alias)(m),
  164. }
  165. if err := json.Unmarshal(data, &aux); err != nil {
  166. return err
  167. }
  168. parts, err := UnmarshalParts([]byte(aux.Parts))
  169. if err != nil {
  170. return err
  171. }
  172. m.Parts = parts
  173. return nil
  174. }
  175. // Content returns the first text content part.
  176. func (m *Message) Content() TextContent {
  177. for _, part := range m.Parts {
  178. if c, ok := part.(TextContent); ok {
  179. return c
  180. }
  181. }
  182. return TextContent{}
  183. }
  184. // ReasoningContent returns the first reasoning content part.
  185. func (m *Message) ReasoningContent() ReasoningContent {
  186. for _, part := range m.Parts {
  187. if c, ok := part.(ReasoningContent); ok {
  188. return c
  189. }
  190. }
  191. return ReasoningContent{}
  192. }
  193. // ImageURLContent returns all image URL content parts.
  194. func (m *Message) ImageURLContent() []ImageURLContent {
  195. imageURLContents := make([]ImageURLContent, 0)
  196. for _, part := range m.Parts {
  197. if c, ok := part.(ImageURLContent); ok {
  198. imageURLContents = append(imageURLContents, c)
  199. }
  200. }
  201. return imageURLContents
  202. }
  203. // BinaryContent returns all binary content parts.
  204. func (m *Message) BinaryContent() []BinaryContent {
  205. binaryContents := make([]BinaryContent, 0)
  206. for _, part := range m.Parts {
  207. if c, ok := part.(BinaryContent); ok {
  208. binaryContents = append(binaryContents, c)
  209. }
  210. }
  211. return binaryContents
  212. }
  213. // ToolCalls returns all tool call parts.
  214. func (m *Message) ToolCalls() []ToolCall {
  215. toolCalls := make([]ToolCall, 0)
  216. for _, part := range m.Parts {
  217. if c, ok := part.(ToolCall); ok {
  218. toolCalls = append(toolCalls, c)
  219. }
  220. }
  221. return toolCalls
  222. }
  223. // ToolResults returns all tool result parts.
  224. func (m *Message) ToolResults() []ToolResult {
  225. toolResults := make([]ToolResult, 0)
  226. for _, part := range m.Parts {
  227. if c, ok := part.(ToolResult); ok {
  228. toolResults = append(toolResults, c)
  229. }
  230. }
  231. return toolResults
  232. }
  233. // IsFinished returns true if the message has a finish part.
  234. func (m *Message) IsFinished() bool {
  235. for _, part := range m.Parts {
  236. if _, ok := part.(Finish); ok {
  237. return true
  238. }
  239. }
  240. return false
  241. }
  242. // FinishPart returns the finish part if present.
  243. func (m *Message) FinishPart() *Finish {
  244. for _, part := range m.Parts {
  245. if c, ok := part.(Finish); ok {
  246. return &c
  247. }
  248. }
  249. return nil
  250. }
  251. // FinishReason returns the finish reason if present.
  252. func (m *Message) FinishReason() FinishReason {
  253. for _, part := range m.Parts {
  254. if c, ok := part.(Finish); ok {
  255. return c.Reason
  256. }
  257. }
  258. return ""
  259. }
  260. // IsThinking returns true if the message is currently in a thinking state.
  261. func (m *Message) IsThinking() bool {
  262. return m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished()
  263. }
  264. // AppendContent appends text to the text content part.
  265. func (m *Message) AppendContent(delta string) {
  266. found := false
  267. for i, part := range m.Parts {
  268. if c, ok := part.(TextContent); ok {
  269. m.Parts[i] = TextContent{Text: c.Text + delta}
  270. found = true
  271. }
  272. }
  273. if !found {
  274. m.Parts = append(m.Parts, TextContent{Text: delta})
  275. }
  276. }
  277. // AppendReasoningContent appends text to the reasoning content part.
  278. func (m *Message) AppendReasoningContent(delta string) {
  279. found := false
  280. for i, part := range m.Parts {
  281. if c, ok := part.(ReasoningContent); ok {
  282. m.Parts[i] = ReasoningContent{
  283. Thinking: c.Thinking + delta,
  284. Signature: c.Signature,
  285. StartedAt: c.StartedAt,
  286. FinishedAt: c.FinishedAt,
  287. }
  288. found = true
  289. }
  290. }
  291. if !found {
  292. m.Parts = append(m.Parts, ReasoningContent{
  293. Thinking: delta,
  294. StartedAt: time.Now().Unix(),
  295. })
  296. }
  297. }
  298. // AppendReasoningSignature appends a signature to the reasoning content part.
  299. func (m *Message) AppendReasoningSignature(signature string) {
  300. for i, part := range m.Parts {
  301. if c, ok := part.(ReasoningContent); ok {
  302. m.Parts[i] = ReasoningContent{
  303. Thinking: c.Thinking,
  304. Signature: c.Signature + signature,
  305. StartedAt: c.StartedAt,
  306. FinishedAt: c.FinishedAt,
  307. }
  308. return
  309. }
  310. }
  311. m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
  312. }
  313. // FinishThinking marks the reasoning content as finished.
  314. func (m *Message) FinishThinking() {
  315. for i, part := range m.Parts {
  316. if c, ok := part.(ReasoningContent); ok {
  317. if c.FinishedAt == 0 {
  318. m.Parts[i] = ReasoningContent{
  319. Thinking: c.Thinking,
  320. Signature: c.Signature,
  321. StartedAt: c.StartedAt,
  322. FinishedAt: time.Now().Unix(),
  323. }
  324. }
  325. return
  326. }
  327. }
  328. }
  329. // ThinkingDuration returns the duration of the thinking phase.
  330. func (m *Message) ThinkingDuration() time.Duration {
  331. reasoning := m.ReasoningContent()
  332. if reasoning.StartedAt == 0 {
  333. return 0
  334. }
  335. endTime := reasoning.FinishedAt
  336. if endTime == 0 {
  337. endTime = time.Now().Unix()
  338. }
  339. return time.Duration(endTime-reasoning.StartedAt) * time.Second
  340. }
  341. // FinishToolCall marks a tool call as finished.
  342. func (m *Message) FinishToolCall(toolCallID string) {
  343. for i, part := range m.Parts {
  344. if c, ok := part.(ToolCall); ok {
  345. if c.ID == toolCallID {
  346. m.Parts[i] = ToolCall{
  347. ID: c.ID,
  348. Name: c.Name,
  349. Input: c.Input,
  350. Type: c.Type,
  351. Finished: true,
  352. }
  353. return
  354. }
  355. }
  356. }
  357. }
  358. // AppendToolCallInput appends input to a tool call.
  359. func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
  360. for i, part := range m.Parts {
  361. if c, ok := part.(ToolCall); ok {
  362. if c.ID == toolCallID {
  363. m.Parts[i] = ToolCall{
  364. ID: c.ID,
  365. Name: c.Name,
  366. Input: c.Input + inputDelta,
  367. Type: c.Type,
  368. Finished: c.Finished,
  369. }
  370. return
  371. }
  372. }
  373. }
  374. }
  375. // AddToolCall adds or updates a tool call.
  376. func (m *Message) AddToolCall(tc ToolCall) {
  377. for i, part := range m.Parts {
  378. if c, ok := part.(ToolCall); ok {
  379. if c.ID == tc.ID {
  380. m.Parts[i] = tc
  381. return
  382. }
  383. }
  384. }
  385. m.Parts = append(m.Parts, tc)
  386. }
  387. // SetToolCalls replaces all tool call parts.
  388. func (m *Message) SetToolCalls(tc []ToolCall) {
  389. parts := make([]ContentPart, 0)
  390. for _, part := range m.Parts {
  391. if _, ok := part.(ToolCall); ok {
  392. continue
  393. }
  394. parts = append(parts, part)
  395. }
  396. m.Parts = parts
  397. for _, toolCall := range tc {
  398. m.Parts = append(m.Parts, toolCall)
  399. }
  400. }
  401. // AddToolResult adds a tool result.
  402. func (m *Message) AddToolResult(tr ToolResult) {
  403. m.Parts = append(m.Parts, tr)
  404. }
  405. // SetToolResults adds multiple tool results.
  406. func (m *Message) SetToolResults(tr []ToolResult) {
  407. for _, toolResult := range tr {
  408. m.Parts = append(m.Parts, toolResult)
  409. }
  410. }
  411. // AddFinish adds a finish part to the message.
  412. func (m *Message) AddFinish(reason FinishReason, message, details string) {
  413. for i, part := range m.Parts {
  414. if _, ok := part.(Finish); ok {
  415. m.Parts = slices.Delete(m.Parts, i, i+1)
  416. break
  417. }
  418. }
  419. m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
  420. }
  421. // AddImageURL adds an image URL part to the message.
  422. func (m *Message) AddImageURL(url, detail string) {
  423. m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
  424. }
  425. // AddBinary adds a binary content part to the message.
  426. func (m *Message) AddBinary(mimeType string, data []byte) {
  427. m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
  428. }
// partType is the discriminator written next to each serialized content
// part so that UnmarshalParts can pick the concrete Go type to decode into.
type partType string

// Discriminator values, one per concrete ContentPart implementation handled
// by MarshalParts and UnmarshalParts.
const (
	reasoningType  partType = "reasoning"
	textType       partType = "text"
	imageURLType   partType = "image_url"
	binaryType     partType = "binary"
	toolCallType   partType = "tool_call"
	toolResultType partType = "tool_result"
	finishType     partType = "finish"
)

// partWrapper is the JSON envelope MarshalParts emits for a single part: the
// discriminator under "type" and the part itself under "data".
type partWrapper struct {
	Type partType    `json:"type"`
	Data ContentPart `json:"data"`
}
  443. // MarshalParts marshals content parts to JSON.
  444. func MarshalParts(parts []ContentPart) ([]byte, error) {
  445. wrappedParts := make([]partWrapper, len(parts))
  446. for i, part := range parts {
  447. var typ partType
  448. switch part.(type) {
  449. case ReasoningContent:
  450. typ = reasoningType
  451. case TextContent:
  452. typ = textType
  453. case ImageURLContent:
  454. typ = imageURLType
  455. case BinaryContent:
  456. typ = binaryType
  457. case ToolCall:
  458. typ = toolCallType
  459. case ToolResult:
  460. typ = toolResultType
  461. case Finish:
  462. typ = finishType
  463. default:
  464. return nil, fmt.Errorf("unknown part type: %T", part)
  465. }
  466. wrappedParts[i] = partWrapper{
  467. Type: typ,
  468. Data: part,
  469. }
  470. }
  471. return json.Marshal(wrappedParts)
  472. }
  473. // UnmarshalParts unmarshals content parts from JSON.
  474. func UnmarshalParts(data []byte) ([]ContentPart, error) {
  475. temp := []json.RawMessage{}
  476. if err := json.Unmarshal(data, &temp); err != nil {
  477. return nil, err
  478. }
  479. parts := make([]ContentPart, 0)
  480. for _, rawPart := range temp {
  481. var wrapper struct {
  482. Type partType `json:"type"`
  483. Data json.RawMessage `json:"data"`
  484. }
  485. if err := json.Unmarshal(rawPart, &wrapper); err != nil {
  486. return nil, err
  487. }
  488. switch wrapper.Type {
  489. case reasoningType:
  490. part := ReasoningContent{}
  491. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  492. return nil, err
  493. }
  494. parts = append(parts, part)
  495. case textType:
  496. part := TextContent{}
  497. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  498. return nil, err
  499. }
  500. parts = append(parts, part)
  501. case imageURLType:
  502. part := ImageURLContent{}
  503. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  504. return nil, err
  505. }
  506. parts = append(parts, part)
  507. case binaryType:
  508. part := BinaryContent{}
  509. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  510. return nil, err
  511. }
  512. parts = append(parts, part)
  513. case toolCallType:
  514. part := ToolCall{}
  515. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  516. return nil, err
  517. }
  518. parts = append(parts, part)
  519. case toolResultType:
  520. part := ToolResult{}
  521. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  522. return nil, err
  523. }
  524. parts = append(parts, part)
  525. case finishType:
  526. part := Finish{}
  527. if err := json.Unmarshal(wrapper.Data, &part); err != nil {
  528. return nil, err
  529. }
  530. parts = append(parts, part)
  531. default:
  532. return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
  533. }
  534. }
  535. return parts, nil
  536. }
  537. // Attachment represents a file attachment.
  538. type Attachment struct {
  539. FilePath string `json:"file_path"`
  540. FileName string `json:"file_name"`
  541. MimeType string `json:"mime_type"`
  542. Content []byte `json:"content"`
  543. }
  544. // MarshalJSON implements the [json.Marshaler] interface.
  545. func (a Attachment) MarshalJSON() ([]byte, error) {
  546. type Alias Attachment
  547. return json.Marshal(&struct {
  548. Content string `json:"content"`
  549. *Alias
  550. }{
  551. Content: base64.StdEncoding.EncodeToString(a.Content),
  552. Alias: (*Alias)(&a),
  553. })
  554. }
  555. // UnmarshalJSON implements the [json.Unmarshaler] interface.
  556. func (a *Attachment) UnmarshalJSON(data []byte) error {
  557. type Alias Attachment
  558. aux := &struct {
  559. Content string `json:"content"`
  560. *Alias
  561. }{
  562. Alias: (*Alias)(a),
  563. }
  564. if err := json.Unmarshal(data, &aux); err != nil {
  565. return err
  566. }
  567. content, err := base64.StdEncoding.DecodeString(aux.Content)
  568. if err != nil {
  569. return err
  570. }
  571. a.Content = content
  572. return nil
  573. }