content.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. package message
  2. import (
  3. "encoding/base64"
  4. "slices"
  5. "time"
  6. "github.com/charmbracelet/catwalk/pkg/catwalk"
  7. )
  // MessageRole identifies the author of a Message.
  type MessageRole string

  // Message roles. The string values are presumably persisted and compared
  // elsewhere, so treat them as stable identifiers.
  const (
  	// Assistant is the role of messages produced by the model.
  	Assistant = MessageRole("assistant")
  	// User is the role of messages written by the user.
  	User = MessageRole("user")
  	// System is the role of system messages.
  	System = MessageRole("system")
  	// Tool is the role of tool-originated messages (presumably those
  	// holding ToolResult parts — confirm against callers).
  	Tool = MessageRole("tool")
  )
  // FinishReason records why generation of a message stopped.
  type FinishReason string

  // Finish reasons stored in a Finish part.
  const (
  	// FinishReasonEndTurn means the model ended its turn normally.
  	FinishReasonEndTurn = FinishReason("end_turn")
  	// FinishReasonMaxTokens means output stopped at the token limit.
  	FinishReasonMaxTokens = FinishReason("max_tokens")
  	// FinishReasonToolUse means the model stopped to request tool calls.
  	FinishReasonToolUse = FinishReason("tool_use")
  	// FinishReasonCanceled means generation was canceled.
  	FinishReasonCanceled = FinishReason("canceled")
  	// FinishReasonError means generation ended with an error.
  	FinishReasonError = FinishReason("error")
  	// FinishReasonPermissionDenied means a required permission was refused.
  	FinishReasonPermissionDenied = FinishReason("permission_denied")
  	// FinishReasonUnknown is a fallback that should never actually occur.
  	FinishReasonUnknown = FinishReason("unknown")
  )
  // ContentPart is one piece of a Message's content. The unexported isPart
  // marker seals the interface: only types declared in this package can
  // satisfy it.
  type ContentPart interface {
  	isPart()
  }
  // ReasoningContent holds the model's reasoning ("thinking") text plus an
  // opaque signature string (presumably issued by the provider — confirm).
  // StartedAt and FinishedAt are Unix-second timestamps (this package sets
  // them from time.Now().Unix()); zero means "not recorded".
  type ReasoningContent struct {
  	Thinking   string `json:"thinking"`
  	Signature  string `json:"signature"`
  	StartedAt  int64  `json:"started_at,omitempty"`
  	FinishedAt int64  `json:"finished_at,omitempty"`
  }

  // String returns the reasoning text only.
  func (rc ReasoningContent) String() string {
  	thinking := rc.Thinking
  	return thinking
  }

  func (ReasoningContent) isPart() {}
  // TextContent is a plain-text part of a message.
  type TextContent struct {
  	Text string `json:"text"`
  }

  // String returns the text unchanged.
  func (t TextContent) String() string {
  	text := t.Text
  	return text
  }

  func (TextContent) isPart() {}
  // ImageURLContent references an image by URL. Detail is an optional hint
  // (presumably a provider "detail" level — confirm) omitted from JSON when
  // empty.
  type ImageURLContent struct {
  	URL    string `json:"url"`
  	Detail string `json:"detail,omitempty"`
  }

  // String returns the image URL; Detail is not included.
  func (img ImageURLContent) String() string {
  	url := img.URL
  	return url
  }

  func (ImageURLContent) isPart() {}
  54. type BinaryContent struct {
  55. Path string
  56. MIMEType string
  57. Data []byte
  58. }
  59. func (bc BinaryContent) String(p catwalk.InferenceProvider) string {
  60. base64Encoded := base64.StdEncoding.EncodeToString(bc.Data)
  61. if p == catwalk.InferenceProviderOpenAI {
  62. return "data:" + bc.MIMEType + ";base64," + base64Encoded
  63. }
  64. return base64Encoded
  65. }
  66. func (BinaryContent) isPart() {}
  // ToolCall is a tool invocation requested by the model. Input accumulates
  // the argument text as it streams in (presumably JSON — confirm), and
  // Finished is set once the call has been fully received.
  type ToolCall struct {
  	ID       string `json:"id"`
  	Name     string `json:"name"`
  	Input    string `json:"input"`
  	Type     string `json:"type"`
  	Finished bool   `json:"finished"`
  }

  func (ToolCall) isPart() {}
  // ToolResult is the outcome of a tool invocation, tied to its ToolCall by
  // ToolCallID. Metadata is an opaque string (format not defined here), and
  // IsError flags a failed execution.
  type ToolResult struct {
  	ToolCallID string `json:"tool_call_id"`
  	Name       string `json:"name"`
  	Content    string `json:"content"`
  	Metadata   string `json:"metadata"`
  	IsError    bool   `json:"is_error"`
  }

  func (ToolResult) isPart() {}
  83. type Finish struct {
  84. Reason FinishReason `json:"reason"`
  85. Time int64 `json:"time"`
  86. Message string `json:"message,omitempty"`
  87. Details string `json:"details,omitempty"`
  88. }
  89. func (Finish) isPart() {}
  90. type Message struct {
  91. ID string
  92. Role MessageRole
  93. SessionID string
  94. Parts []ContentPart
  95. Model string
  96. Provider string
  97. CreatedAt int64
  98. UpdatedAt int64
  99. }
  100. func (m *Message) Content() TextContent {
  101. for _, part := range m.Parts {
  102. if c, ok := part.(TextContent); ok {
  103. return c
  104. }
  105. }
  106. return TextContent{}
  107. }
  108. func (m *Message) ReasoningContent() ReasoningContent {
  109. for _, part := range m.Parts {
  110. if c, ok := part.(ReasoningContent); ok {
  111. return c
  112. }
  113. }
  114. return ReasoningContent{}
  115. }
  116. func (m *Message) ImageURLContent() []ImageURLContent {
  117. imageURLContents := make([]ImageURLContent, 0)
  118. for _, part := range m.Parts {
  119. if c, ok := part.(ImageURLContent); ok {
  120. imageURLContents = append(imageURLContents, c)
  121. }
  122. }
  123. return imageURLContents
  124. }
  125. func (m *Message) BinaryContent() []BinaryContent {
  126. binaryContents := make([]BinaryContent, 0)
  127. for _, part := range m.Parts {
  128. if c, ok := part.(BinaryContent); ok {
  129. binaryContents = append(binaryContents, c)
  130. }
  131. }
  132. return binaryContents
  133. }
  134. func (m *Message) ToolCalls() []ToolCall {
  135. toolCalls := make([]ToolCall, 0)
  136. for _, part := range m.Parts {
  137. if c, ok := part.(ToolCall); ok {
  138. toolCalls = append(toolCalls, c)
  139. }
  140. }
  141. return toolCalls
  142. }
  143. func (m *Message) ToolResults() []ToolResult {
  144. toolResults := make([]ToolResult, 0)
  145. for _, part := range m.Parts {
  146. if c, ok := part.(ToolResult); ok {
  147. toolResults = append(toolResults, c)
  148. }
  149. }
  150. return toolResults
  151. }
  152. func (m *Message) IsFinished() bool {
  153. for _, part := range m.Parts {
  154. if _, ok := part.(Finish); ok {
  155. return true
  156. }
  157. }
  158. return false
  159. }
  160. func (m *Message) FinishPart() *Finish {
  161. for _, part := range m.Parts {
  162. if c, ok := part.(Finish); ok {
  163. return &c
  164. }
  165. }
  166. return nil
  167. }
  168. func (m *Message) FinishReason() FinishReason {
  169. for _, part := range m.Parts {
  170. if c, ok := part.(Finish); ok {
  171. return c.Reason
  172. }
  173. }
  174. return ""
  175. }
  176. func (m *Message) IsThinking() bool {
  177. if m.ReasoningContent().Thinking != "" && m.Content().Text == "" && !m.IsFinished() {
  178. return true
  179. }
  180. return false
  181. }
  182. func (m *Message) AppendContent(delta string) {
  183. found := false
  184. for i, part := range m.Parts {
  185. if c, ok := part.(TextContent); ok {
  186. m.Parts[i] = TextContent{Text: c.Text + delta}
  187. found = true
  188. }
  189. }
  190. if !found {
  191. m.Parts = append(m.Parts, TextContent{Text: delta})
  192. }
  193. }
  194. func (m *Message) AppendReasoningContent(delta string) {
  195. found := false
  196. for i, part := range m.Parts {
  197. if c, ok := part.(ReasoningContent); ok {
  198. m.Parts[i] = ReasoningContent{
  199. Thinking: c.Thinking + delta,
  200. Signature: c.Signature,
  201. StartedAt: c.StartedAt,
  202. FinishedAt: c.FinishedAt,
  203. }
  204. found = true
  205. }
  206. }
  207. if !found {
  208. m.Parts = append(m.Parts, ReasoningContent{
  209. Thinking: delta,
  210. StartedAt: time.Now().Unix(),
  211. })
  212. }
  213. }
  214. func (m *Message) AppendReasoningSignature(signature string) {
  215. for i, part := range m.Parts {
  216. if c, ok := part.(ReasoningContent); ok {
  217. m.Parts[i] = ReasoningContent{
  218. Thinking: c.Thinking,
  219. Signature: c.Signature + signature,
  220. StartedAt: c.StartedAt,
  221. FinishedAt: c.FinishedAt,
  222. }
  223. return
  224. }
  225. }
  226. m.Parts = append(m.Parts, ReasoningContent{Signature: signature})
  227. }
  228. func (m *Message) FinishThinking() {
  229. for i, part := range m.Parts {
  230. if c, ok := part.(ReasoningContent); ok {
  231. if c.FinishedAt == 0 {
  232. m.Parts[i] = ReasoningContent{
  233. Thinking: c.Thinking,
  234. Signature: c.Signature,
  235. StartedAt: c.StartedAt,
  236. FinishedAt: time.Now().Unix(),
  237. }
  238. }
  239. return
  240. }
  241. }
  242. }
  243. func (m *Message) ThinkingDuration() time.Duration {
  244. reasoning := m.ReasoningContent()
  245. if reasoning.StartedAt == 0 {
  246. return 0
  247. }
  248. endTime := reasoning.FinishedAt
  249. if endTime == 0 {
  250. endTime = time.Now().Unix()
  251. }
  252. return time.Duration(endTime-reasoning.StartedAt) * time.Second
  253. }
  254. func (m *Message) FinishToolCall(toolCallID string) {
  255. for i, part := range m.Parts {
  256. if c, ok := part.(ToolCall); ok {
  257. if c.ID == toolCallID {
  258. m.Parts[i] = ToolCall{
  259. ID: c.ID,
  260. Name: c.Name,
  261. Input: c.Input,
  262. Type: c.Type,
  263. Finished: true,
  264. }
  265. return
  266. }
  267. }
  268. }
  269. }
  270. func (m *Message) AppendToolCallInput(toolCallID string, inputDelta string) {
  271. for i, part := range m.Parts {
  272. if c, ok := part.(ToolCall); ok {
  273. if c.ID == toolCallID {
  274. m.Parts[i] = ToolCall{
  275. ID: c.ID,
  276. Name: c.Name,
  277. Input: c.Input + inputDelta,
  278. Type: c.Type,
  279. Finished: c.Finished,
  280. }
  281. return
  282. }
  283. }
  284. }
  285. }
  286. func (m *Message) AddToolCall(tc ToolCall) {
  287. for i, part := range m.Parts {
  288. if c, ok := part.(ToolCall); ok {
  289. if c.ID == tc.ID {
  290. m.Parts[i] = tc
  291. return
  292. }
  293. }
  294. }
  295. m.Parts = append(m.Parts, tc)
  296. }
  297. func (m *Message) SetToolCalls(tc []ToolCall) {
  298. // remove any existing tool call part it could have multiple
  299. parts := make([]ContentPart, 0)
  300. for _, part := range m.Parts {
  301. if _, ok := part.(ToolCall); ok {
  302. continue
  303. }
  304. parts = append(parts, part)
  305. }
  306. m.Parts = parts
  307. for _, toolCall := range tc {
  308. m.Parts = append(m.Parts, toolCall)
  309. }
  310. }
  311. func (m *Message) AddToolResult(tr ToolResult) {
  312. m.Parts = append(m.Parts, tr)
  313. }
  314. func (m *Message) SetToolResults(tr []ToolResult) {
  315. for _, toolResult := range tr {
  316. m.Parts = append(m.Parts, toolResult)
  317. }
  318. }
  319. func (m *Message) AddFinish(reason FinishReason, message, details string) {
  320. // remove any existing finish part
  321. for i, part := range m.Parts {
  322. if _, ok := part.(Finish); ok {
  323. m.Parts = slices.Delete(m.Parts, i, i+1)
  324. break
  325. }
  326. }
  327. m.Parts = append(m.Parts, Finish{Reason: reason, Time: time.Now().Unix(), Message: message, Details: details})
  328. }
  329. func (m *Message) AddImageURL(url, detail string) {
  330. m.Parts = append(m.Parts, ImageURLContent{URL: url, Detail: detail})
  331. }
  332. func (m *Message) AddBinary(mimeType string, data []byte) {
  333. m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
  334. }