message-v2.ts 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. import z from "zod"
  2. import { Bus } from "../bus"
  3. import { NamedError } from "../util/error"
  4. import { Message } from "./message"
  5. import { convertToModelMessages, type ModelMessage, type UIMessage } from "ai"
  6. import { Identifier } from "../id/id"
  7. import { LSP } from "../lsp"
  8. export namespace MessageV2 {
  /** Named error "MessageOutputLengthError"; its data schema is an empty object. */
  export const OutputLengthError = NamedError.create("MessageOutputLengthError", z.object({}))

  /** Named error "MessageAbortedError"; its data schema is an empty object. */
  export const AbortedError = NamedError.create("MessageAbortedError", z.object({}))

  /** Named error "ProviderAuthError"; its data carries a `providerID` and a `message`. */
  export const AuthError = NamedError.create(
    "ProviderAuthError",
    z.object({ providerID: z.string(), message: z.string() }),
  )
  /**
   * Tool state with `status: "pending"`; it has no other fields.
   * Registered under the OpenAPI ref "ToolStatePending".
   */
  export const ToolStatePending = z.object({ status: z.literal("pending") }).openapi({ ref: "ToolStatePending" })
  /** Static type inferred from the {@link ToolStatePending} schema. */
  export type ToolStatePending = z.infer<typeof ToolStatePending>
  // Field definitions for the "running" tool state.
  const toolStateRunningFields = {
    status: z.literal("running"),
    // Unlike the completed/error states, the input is not constrained to a record here.
    input: z.any(),
    title: z.string().optional(),
    metadata: z.record(z.any()).optional(),
    // Only a start time is recorded in this state.
    time: z.object({ start: z.number() }),
  }

  /**
   * Tool state with `status: "running"`: the input, an optional title and metadata,
   * and the start time. Registered under the OpenAPI ref "ToolStateRunning".
   */
  export const ToolStateRunning = z.object(toolStateRunningFields).openapi({ ref: "ToolStateRunning" })
  /** Static type inferred from the {@link ToolStateRunning} schema. */
  export type ToolStateRunning = z.infer<typeof ToolStateRunning>
  // Field definitions for the "completed" tool state; every field is required.
  const toolStateCompletedFields = {
    status: z.literal("completed"),
    input: z.record(z.any()),
    output: z.string(),
    title: z.string(),
    metadata: z.record(z.any()),
    time: z.object({ start: z.number(), end: z.number() }),
  }

  /**
   * Tool state with `status: "completed"`: input record, string output, title,
   * metadata record and start/end times. Registered under the OpenAPI ref
   * "ToolStateCompleted".
   */
  export const ToolStateCompleted = z.object(toolStateCompletedFields).openapi({ ref: "ToolStateCompleted" })
  /** Static type inferred from the {@link ToolStateCompleted} schema. */
  export type ToolStateCompleted = z.infer<typeof ToolStateCompleted>
  // Field definitions for the "error" tool state.
  const toolStateErrorFields = {
    status: z.literal("error"),
    input: z.record(z.any()),
    error: z.string(),
    time: z.object({ start: z.number(), end: z.number() }),
  }

  /**
   * Tool state with `status: "error"`: input record, error string and start/end
   * times. Registered under the OpenAPI ref "ToolStateError".
   */
  export const ToolStateError = z.object(toolStateErrorFields).openapi({ ref: "ToolStateError" })
  /** Static type inferred from the {@link ToolStateError} schema. */
  export type ToolStateError = z.infer<typeof ToolStateError>
  /**
   * Union of the four tool state schemas, discriminated on the `status` field.
   * Registered under the OpenAPI ref "ToolState".
   */
  export const ToolState = z
    .discriminatedUnion("status", [
      ToolStatePending,
      ToolStateRunning,
      ToolStateCompleted,
      ToolStateError,
    ])
    .openapi({ ref: "ToolState" })
  /** Identity fields extended by every part schema: the part id plus its session and message ids. */
  const PartBase = z.object({ id: z.string(), sessionID: z.string(), messageID: z.string() })
  // Fields added on top of PartBase for a snapshot part.
  const snapshotPartFields = {
    type: z.literal("snapshot"),
    snapshot: z.string(),
  }

  /**
   * Part with `type: "snapshot"` carrying a `snapshot` string.
   * Registered under the OpenAPI ref "SnapshotPart".
   */
  export const SnapshotPart = PartBase.extend(snapshotPartFields).openapi({ ref: "SnapshotPart" })
  /** Static type inferred from the {@link SnapshotPart} schema. */
  export type SnapshotPart = z.infer<typeof SnapshotPart>
  // Fields added on top of PartBase for a patch part.
  const patchPartFields = {
    type: z.literal("patch"),
    hash: z.string(),
    files: z.string().array(),
  }

  /**
   * Part with `type: "patch"` carrying a `hash` string and a list of file strings.
   * Registered under the OpenAPI ref "PatchPart".
   */
  export const PatchPart = PartBase.extend(patchPartFields).openapi({ ref: "PatchPart" })
  /** Static type inferred from the {@link PatchPart} schema. */
  export type PatchPart = z.infer<typeof PatchPart>
  // Optional timing block for a text part; `end` may be absent even when `start` is set.
  const textPartTime = z
    .object({
      start: z.number(),
      end: z.number().optional(),
    })
    .optional()

  /**
   * Part with `type: "text"` carrying the text itself, an optional `synthetic`
   * flag and optional timing. Registered under the OpenAPI ref "TextPart".
   */
  export const TextPart = PartBase.extend({
    type: z.literal("text"),
    text: z.string(),
    // NOTE(review): presumably marks text not authored directly by the user/model; confirm with callers.
    synthetic: z.boolean().optional(),
    time: textPartTime,
  }).openapi({ ref: "TextPart" })
  /** Static type inferred from the {@link TextPart} schema. */
  export type TextPart = z.infer<typeof TextPart>
  // Fields added on top of PartBase for a tool part.
  const toolPartFields = {
    type: z.literal("tool"),
    callID: z.string(),
    tool: z.string(),
    state: ToolState,
  }

  /**
   * Part with `type: "tool"`: the call id, the tool name and the current
   * {@link ToolState}. Registered under the OpenAPI ref "ToolPart".
   */
  export const ToolPart = PartBase.extend(toolPartFields).openapi({ ref: "ToolPart" })
  /** Static type inferred from the {@link ToolPart} schema. */
  export type ToolPart = z.infer<typeof ToolPart>
  // Text span schema: a string value with integer `start` and `end` positions.
  const filePartSourceText = z
    .object({
      value: z.string(),
      start: z.number().int(),
      end: z.number().int(),
    })
    .openapi({ ref: "FilePartSourceText" })

  /** Common base of the file-part source schemas: a single `text` span. */
  const FilePartSourceBase = z.object({ text: filePartSourceText })
  /**
   * File-part source with `type: "file"`: the base `text` span plus a `path`.
   * Registered under the OpenAPI ref "FileSource".
   */
  export const FileSource = FilePartSourceBase.extend({ type: z.literal("file"), path: z.string() }).openapi({
    ref: "FileSource",
  })
  // Fields added on top of FilePartSourceBase for a symbol source.
  const symbolSourceFields = {
    type: z.literal("symbol"),
    path: z.string(),
    range: LSP.Range,
    name: z.string(),
    // Integer symbol kind; NOTE(review): presumably an LSP SymbolKind value, confirm.
    kind: z.number().int(),
  }

  /**
   * File-part source with `type: "symbol"`: the base `text` span plus path,
   * range, name and integer kind. Registered under the OpenAPI ref "SymbolSource".
   */
  export const SymbolSource = FilePartSourceBase.extend(symbolSourceFields).openapi({ ref: "SymbolSource" })
  /**
   * Union of {@link FileSource} and {@link SymbolSource}, discriminated on `type`.
   * Registered under the OpenAPI ref "FilePartSource".
   */
  export const FilePartSource = z
    .discriminatedUnion("type", [FileSource, SymbolSource])
    .openapi({ ref: "FilePartSource" })
  // Fields added on top of PartBase for a file part.
  const filePartFields = {
    type: z.literal("file"),
    mime: z.string(),
    filename: z.string().optional(),
    url: z.string(),
    source: FilePartSource.optional(),
  }

  /**
   * Part with `type: "file"`: MIME type and URL, with an optional filename and
   * optional {@link FilePartSource}. Registered under the OpenAPI ref "FilePart".
   */
  export const FilePart = PartBase.extend(filePartFields).openapi({ ref: "FilePart" })
  /** Static type inferred from the {@link FilePart} schema. */
  export type FilePart = z.infer<typeof FilePart>
  /**
   * Part with `type: "step-start"`; it has no fields beyond the part base.
   * Registered under the OpenAPI ref "StepStartPart".
   */
  export const StepStartPart = PartBase.extend({ type: z.literal("step-start") }).openapi({ ref: "StepStartPart" })
  /** Static type inferred from the {@link StepStartPart} schema. */
  export type StepStartPart = z.infer<typeof StepStartPart>
  // Token counters recorded on a step-finish part, including cache read/write counts.
  const stepFinishTokens = z.object({
    input: z.number(),
    output: z.number(),
    reasoning: z.number(),
    cache: z.object({ read: z.number(), write: z.number() }),
  })

  /**
   * Part with `type: "step-finish"` carrying a numeric `cost` and token counters.
   * Registered under the OpenAPI ref "StepFinishPart".
   */
  export const StepFinishPart = PartBase.extend({
    type: z.literal("step-finish"),
    cost: z.number(),
    tokens: stepFinishTokens,
  }).openapi({ ref: "StepFinishPart" })
  /** Static type inferred from the {@link StepFinishPart} schema. */
  export type StepFinishPart = z.infer<typeof StepFinishPart>
  /** Identity fields extended by both message schemas: the message id and its session id. */
  const Base = z.object({ id: z.string(), sessionID: z.string() })
  // Fields added on top of Base for a user message.
  const userMessageFields = {
    role: z.literal("user"),
    time: z.object({ created: z.number() }),
  }

  /**
   * Message with `role: "user"` and a creation time.
   * Registered under the OpenAPI ref "UserMessage".
   */
  export const User = Base.extend(userMessageFields).openapi({ ref: "UserMessage" })
  /** Static type inferred from the {@link User} schema. */
  export type User = z.infer<typeof User>
  /**
   * Union of all part schemas, discriminated on the `type` field.
   * Registered under the OpenAPI ref "Part".
   */
  export const Part = z
    .discriminatedUnion("type", [
      TextPart,
      FilePart,
      ToolPart,
      StepStartPart,
      StepFinishPart,
      SnapshotPart,
      PatchPart,
    ])
    .openapi({ ref: "Part" })
  /** Static type inferred from the {@link Part} schema. */
  export type Part = z.infer<typeof Part>
  // Optional error attached to an assistant message, discriminated on the error `name`.
  const assistantError = z
    .discriminatedUnion("name", [
      AuthError.Schema,
      NamedError.Unknown.Schema,
      OutputLengthError.Schema,
      AbortedError.Schema,
    ])
    .optional()

  // Token counters recorded on an assistant message, including cache read/write counts.
  const assistantTokens = z.object({
    input: z.number(),
    output: z.number(),
    reasoning: z.number(),
    cache: z.object({ read: z.number(), write: z.number() }),
  })

  /**
   * Message with `role: "assistant"`: timing, optional error, system strings,
   * model/provider ids, mode, paths, optional summary flag, cost and token
   * counters. Registered under the OpenAPI ref "AssistantMessage".
   */
  export const Assistant = Base.extend({
    role: z.literal("assistant"),
    // `completed` is optional; only `created` is required.
    time: z.object({
      created: z.number(),
      completed: z.number().optional(),
    }),
    error: assistantError,
    system: z.string().array(),
    modelID: z.string(),
    providerID: z.string(),
    mode: z.string(),
    path: z.object({ cwd: z.string(), root: z.string() }),
    summary: z.boolean().optional(),
    cost: z.number(),
    tokens: assistantTokens,
  }).openapi({ ref: "AssistantMessage" })
  /** Static type inferred from the {@link Assistant} schema. */
  export type Assistant = z.infer<typeof Assistant>
  /**
   * Union of the {@link User} and {@link Assistant} message schemas, discriminated
   * on `role`. Registered under the OpenAPI ref "Message".
   */
  export const Info = z
    .discriminatedUnion("role", [User, Assistant])
    .openapi({ ref: "Message" })
  /** Static type inferred from the {@link Info} schema. */
  export type Info = z.infer<typeof Info>
  /**
   * Bus event definitions for messages and their parts. Each entry is created
   * with `Bus.event(name, payloadSchema)`.
   */
  export const Event = {
    /** "message.updated": payload holds the message `info`. */
    Updated: Bus.event("message.updated", z.object({ info: Info })),
    /** "message.removed": payload identifies the message by session and message id. */
    Removed: Bus.event("message.removed", z.object({ sessionID: z.string(), messageID: z.string() })),
    /** "message.part.updated": payload holds the `part`. */
    PartUpdated: Bus.event("message.part.updated", z.object({ part: Part })),
    /** "message.part.removed": payload identifies the part by message and part id. */
    PartRemoved: Bus.event("message.part.removed", z.object({ messageID: z.string(), partID: z.string() })),
  }
  /**
   * Converts a legacy v1 message into the v2 shape: a message `info` record plus
   * a flat list of `parts`.
   *
   * - Assistant messages keep their text, step-start and tool-invocation parts.
   * - User messages keep their text and file parts.
   * - Parts of any other type are dropped.
   *
   * Every source part is given a new id from `Identifier.ascending("part")`.
   *
   * Tool invocations look up their per-call metadata in `v1.metadata.tool`. When
   * that entry (or one of its fields) is missing, fallbacks are filled in so the
   * resulting state still satisfies the tool state schemas: times default to the
   * message creation time and the title defaults to an empty string.
   *
   * @param v1 - The v1 message to convert.
   * @returns An object with the v2 `info` and its `parts`.
   * @throws Error if the role is neither "assistant" nor "user", if an assistant
   *   message has no assistant metadata, or if a tool invocation is in an
   *   unknown state.
   */
  export function fromV1(v1: Message.Info) {
    // Identity fields shared by every part converted from this message. An id is
    // allocated for each source part, including parts that end up being dropped.
    const partBase = () => ({
      id: Identifier.ascending("part"),
      messageID: v1.id,
      sessionID: v1.metadata.sessionID,
    })

    if (v1.role === "assistant") {
      // Guard once instead of asserting non-null on every field access, so a
      // malformed message fails with a descriptive error.
      const assistant = v1.metadata.assistant
      if (!assistant) throw new Error("assistant message is missing assistant metadata")

      const created = v1.metadata.time.created
      const info: Assistant = {
        id: v1.id,
        sessionID: v1.metadata.sessionID,
        role: "assistant",
        time: {
          created,
          completed: v1.metadata.time.completed,
        },
        cost: assistant.cost,
        path: assistant.path,
        summary: assistant.summary,
        tokens: assistant.tokens,
        modelID: assistant.modelID,
        providerID: assistant.providerID,
        system: assistant.system,
        // The mode is always set to "build" for converted messages.
        mode: "build",
        error: v1.metadata.error,
      }
      const parts = v1.parts.flatMap((part): Part[] => {
        const base = partBase()
        if (part.type === "text") {
          return [{ ...base, type: "text", text: part.text }]
        }
        if (part.type === "step-start") {
          return [{ ...base, type: "step-start" }]
        }
        if (part.type === "tool-invocation") {
          return [
            {
              ...base,
              type: "tool",
              callID: part.toolInvocation.toolCallId,
              tool: part.toolInvocation.toolName,
              state: (() => {
                if (part.toolInvocation.state === "partial-call") {
                  return {
                    status: "pending",
                  }
                }
                // The metadata entry for this call may be absent, hence the `?? {}`.
                // Everything besides `title` and `time` becomes the v2 metadata.
                const { title, time, ...metadata } = v1.metadata.tool[part.toolInvocation.toolCallId] ?? {}
                if (part.toolInvocation.state === "call") {
                  return {
                    status: "running",
                    input: part.toolInvocation.args,
                    time: {
                      // The running state requires a numeric start time.
                      start: time?.start ?? created,
                    },
                  }
                }
                if (part.toolInvocation.state === "result") {
                  return {
                    status: "completed",
                    input: part.toolInvocation.args,
                    output: part.toolInvocation.result,
                    // The completed state requires a title and start/end times.
                    title: title ?? "",
                    time: time ?? { start: created, end: created },
                    metadata,
                  }
                }
                throw new Error("unknown tool invocation state")
              })(),
            },
          ]
        }
        return []
      })
      return {
        info,
        parts,
      }
    }

    if (v1.role === "user") {
      const info: User = {
        id: v1.id,
        sessionID: v1.metadata.sessionID,
        role: "user",
        time: {
          created: v1.metadata.time.created,
        },
      }
      const parts = v1.parts.flatMap((part): Part[] => {
        const base = partBase()
        if (part.type === "text") {
          return [{ ...base, type: "text", text: part.text }]
        }
        if (part.type === "file") {
          return [
            {
              ...base,
              type: "file",
              mime: part.mediaType,
              filename: part.filename,
              url: part.url,
            },
          ]
        }
        return []
      })
      return { info, parts }
    }

    throw new Error("unknown message type")
  }
  /**
   * Maps the parts of a user message to UI message parts.
   * Keeps text parts and file parts whose MIME type is not "text/plain".
   */
  function toUserUIParts(parts: Part[]): UIMessage["parts"] {
    return parts.flatMap((part): UIMessage["parts"] => {
      if (part.type === "text")
        return [
          {
            type: "text",
            text: part.text,
          },
        ]
      // text/plain files are converted into text parts, ignore them
      if (part.type === "file" && part.mime !== "text/plain")
        return [
          {
            type: "file",
            url: part.url,
            mediaType: part.mime,
            filename: part.filename,
          },
        ]
      return []
    })
  }

  /**
   * Maps the parts of an assistant message to UI message parts.
   * Keeps text parts, step-start parts, and tool parts whose state is
   * "completed" or "error"; pending and running tool calls are dropped.
   */
  function toAssistantUIParts(parts: Part[]): UIMessage["parts"] {
    return parts.flatMap((part): UIMessage["parts"] => {
      if (part.type === "text")
        return [
          {
            type: "text",
            text: part.text,
          },
        ]
      if (part.type === "step-start")
        return [
          {
            type: "step-start",
          },
        ]
      if (part.type === "tool") {
        if (part.state.status === "completed")
          return [
            {
              type: ("tool-" + part.tool) as `tool-${string}`,
              state: "output-available",
              toolCallId: part.callID,
              input: part.state.input,
              output: part.state.output,
            },
          ]
        if (part.state.status === "error")
          return [
            {
              type: ("tool-" + part.tool) as `tool-${string}`,
              state: "output-error",
              toolCallId: part.callID,
              input: part.state.input,
              errorText: part.state.error,
            },
          ]
      }
      return []
    })
  }

  /**
   * Converts stored messages into model messages by building UI messages and
   * passing them to `convertToModelMessages`.
   *
   * A message is skipped when it has no parts, and also when none of its parts
   * survive conversion (for example an assistant message holding only pending
   * tool calls), so no empty-content message is handed to the converter.
   *
   * @param input - Messages with their parts, in the order they should appear.
   * @returns The model messages produced by `convertToModelMessages`.
   */
  export function toModelMessage(
    input: {
      info: Info
      parts: Part[]
    }[],
  ): ModelMessage[] {
    const result: UIMessage[] = []
    for (const msg of input) {
      if (msg.parts.length === 0) continue
      if (msg.info.role === "user") {
        const parts = toUserUIParts(msg.parts)
        // Conversion can drop every part; skip instead of emitting an empty message.
        if (parts.length === 0) continue
        result.push({
          id: msg.info.id,
          role: "user",
          parts,
        })
      }
      if (msg.info.role === "assistant") {
        const parts = toAssistantUIParts(msg.parts)
        // Conversion can drop every part; skip instead of emitting an empty message.
        if (parts.length === 0) continue
        result.push({
          id: msg.info.id,
          role: "assistant",
          parts,
        })
      }
    }
    return convertToModelMessages(result)
  }
  478. }