// processor.ts
  1. import type { ModelsDev } from "@/provider/models"
  2. import { MessageV2 } from "./message-v2"
  3. import { type StreamTextResult, type Tool as AITool, APICallError } from "ai"
  4. import { Log } from "@/util/log"
  5. import { Identifier } from "@/id/id"
  6. import { Session } from "."
  7. import { Agent } from "@/agent/agent"
  8. import { Permission } from "@/permission"
  9. import { Snapshot } from "@/snapshot"
  10. import { SessionSummary } from "./summary"
  11. import { Bus } from "@/bus"
  12. import { SessionRetry } from "./retry"
  13. import { SessionStatus } from "./status"
  /**
   * Consumes the event stream of a model call for a single assistant message
   * and mirrors each event into persisted message parts (reasoning, text, tool
   * calls, step markers and file patches).
   */
  export namespace SessionProcessor {
    /** How many trailing parts are inspected when looking for a repeated, identical tool call. */
    const DOOM_LOOP_THRESHOLD = 3
    const log = Log.create({ service: "session.processor" })

    /** The processor object returned by {@link create}. */
    export type Info = Awaited<ReturnType<typeof create>>
    /** Outcome of `Info["process"]`: `"continue"` or `"stop"`. */
    export type Result = Awaited<ReturnType<Info["process"]>>

    /**
     * Builds a processor bound to one assistant message.
     *
     * @param input.assistantMessage Message that is mutated in place (finish
     *   reason, cost, tokens, error, completion time) and saved through
     *   `Session.updateMessage`.
     * @param input.sessionID Session used for status updates and for the
     *   step-start / patch parts.
     * @param input.providerID Forwarded to `MessageV2.fromError` when a thrown
     *   value is converted into a message error.
     * @param input.model Forwarded to `Session.getUsage` to compute cost/tokens.
     * @param input.abort Checked before every stream event and handed to the
     *   retry sleep.
     * @returns An object exposing the message, a tool-call lookup and `process`.
     */
    export function create(input: {
      assistantMessage: MessageV2.Assistant
      sessionID: string
      providerID: string
      model: ModelsDev.Model
      abort: AbortSignal
    }) {
      // Tool parts that are currently pending or running, keyed by tool call id.
      const toolcalls: Record<string, MessageV2.ToolPart> = {}
      // Snapshot taken at "start-step"; consumed and cleared at "finish-step".
      let snapshot: string | undefined
      // Set once a tool fails with Permission.RejectedError; forces a "stop" result.
      let blocked = false
      // Retry counter. It lives outside process(), so it is never reset by this code.
      let attempt = 0

      const result = {
        /** The assistant message this processor writes to. */
        get message() {
          return input.assistantMessage
        },

        /** Returns the in-flight tool part registered for `toolCallID`, if any. */
        partFromToolCall(toolCallID: string) {
          return toolcalls[toolCallID]
        },

        /**
         * Runs `fn`, consumes its `fullStream` and persists every event.
         *
         * Errors that `SessionRetry.retryable` recognises cause the whole
         * stream to be requested again after a delay; any other error is stored
         * on the assistant message and published on the bus.
         *
         * @param fn Factory that starts a fresh stream; called once per attempt.
         * @returns `"stop"` when a permission was rejected or the message ended
         *   with an error, otherwise `"continue"`.
         */
        async process(fn: () => StreamTextResult<Record<string, AITool>, never>): Promise<"continue" | "stop"> {
          log.info("process")
          while (true) {
            try {
              let currentText: MessageV2.TextPart | undefined
              const reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
              const stream = fn()
              for await (const value of stream.fullStream) {
                input.abort.throwIfAborted()
                switch (value.type) {
                  case "start":
                    SessionStatus.set(input.sessionID, { type: "busy" })
                    break

                  case "reasoning-start":
                    // A duplicate start for a known id is ignored entirely.
                    if (value.id in reasoningMap) {
                      continue
                    }
                    reasoningMap[value.id] = {
                      id: Identifier.ascending("part"),
                      messageID: input.assistantMessage.id,
                      sessionID: input.assistantMessage.sessionID,
                      type: "reasoning",
                      text: "",
                      time: {
                        start: Date.now(),
                      },
                      metadata: value.providerMetadata,
                    }
                    break

                  case "reasoning-delta":
                    if (value.id in reasoningMap) {
                      const part = reasoningMap[value.id]
                      part.text += value.text
                      if (value.providerMetadata) part.metadata = value.providerMetadata
                      // Nothing is persisted while the accumulated text is still empty.
                      if (part.text) await Session.updatePart({ part, delta: value.text })
                    }
                    break

                  case "reasoning-end":
                    if (value.id in reasoningMap) {
                      const part = reasoningMap[value.id]
                      part.text = part.text.trimEnd()
                      part.time = {
                        ...part.time,
                        end: Date.now(),
                      }
                      if (value.providerMetadata) part.metadata = value.providerMetadata
                      await Session.updatePart(part)
                      delete reasoningMap[value.id]
                    }
                    break

                  case "tool-input-start": {
                    // Reuse the part id when this call id was already registered.
                    const part = await Session.updatePart({
                      id: toolcalls[value.id]?.id ?? Identifier.ascending("part"),
                      messageID: input.assistantMessage.id,
                      sessionID: input.assistantMessage.sessionID,
                      type: "tool",
                      tool: value.toolName,
                      callID: value.id,
                      state: {
                        status: "pending",
                        input: {},
                        raw: "",
                      },
                    })
                    toolcalls[value.id] = part as MessageV2.ToolPart
                    break
                  }

                  case "tool-input-delta":
                    break

                  case "tool-input-end":
                    break

                  case "tool-call": {
                    const match = toolcalls[value.toolCallId]
                    if (match) {
                      const part = await Session.updatePart({
                        ...match,
                        tool: value.toolName,
                        state: {
                          status: "running",
                          input: value.input,
                          time: {
                            start: Date.now(),
                          },
                        },
                        metadata: value.providerMetadata,
                      })
                      toolcalls[value.toolCallId] = part as MessageV2.ToolPart

                      // Doom-loop guard: the trailing DOOM_LOOP_THRESHOLD parts of this
                      // message must all be non-pending calls of the same tool with
                      // JSON-identical input.
                      const parts = await MessageV2.parts(input.assistantMessage.id)
                      const recent = parts.slice(-DOOM_LOOP_THRESHOLD)
                      const serializedInput = JSON.stringify(value.input)
                      const looping =
                        recent.length === DOOM_LOOP_THRESHOLD &&
                        recent.every(
                          (p) =>
                            p.type === "tool" &&
                            p.tool === value.toolName &&
                            p.state.status !== "pending" &&
                            JSON.stringify(p.state.input) === serializedInput,
                        )
                      if (looping) {
                        const permission = await Agent.get(input.assistantMessage.mode).then((x) => x.permission)
                        if (permission.doom_loop === "ask") {
                          await Permission.ask({
                            type: "doom_loop",
                            pattern: value.toolName,
                            sessionID: input.assistantMessage.sessionID,
                            messageID: input.assistantMessage.id,
                            callID: value.toolCallId,
                            title: `Possible doom loop: "${value.toolName}" called ${DOOM_LOOP_THRESHOLD} times with identical arguments`,
                            metadata: {
                              tool: value.toolName,
                              input: value.input,
                            },
                          })
                        } else if (permission.doom_loop === "deny") {
                          throw new Permission.RejectedError(
                            input.assistantMessage.sessionID,
                            "doom_loop",
                            value.toolCallId,
                            {
                              tool: value.toolName,
                              input: value.input,
                            },
                            `You seem to be stuck in a doom loop, please stop repeating the same action`,
                          )
                        }
                      }
                    }
                    break
                  }

                  case "tool-result": {
                    const match = toolcalls[value.toolCallId]
                    if (match && match.state.status === "running") {
                      await Session.updatePart({
                        ...match,
                        state: {
                          status: "completed",
                          input: value.input,
                          output: value.output.output,
                          metadata: value.output.metadata,
                          title: value.output.title,
                          time: {
                            start: match.state.time.start,
                            end: Date.now(),
                          },
                          attachments: value.output.attachments,
                        },
                      })
                      delete toolcalls[value.toolCallId]
                    }
                    break
                  }

                  case "tool-error": {
                    const match = toolcalls[value.toolCallId]
                    if (match && match.state.status === "running") {
                      const rejected = value.error instanceof Permission.RejectedError
                      await Session.updatePart({
                        ...match,
                        state: {
                          status: "error",
                          input: value.input,
                          // String() instead of .toString(): a nullish error value must not throw here.
                          error: String(value.error),
                          metadata: value.error instanceof Permission.RejectedError ? value.error.metadata : undefined,
                          time: {
                            start: match.state.time.start,
                            end: Date.now(),
                          },
                        },
                      })
                      if (rejected) {
                        blocked = true
                      }
                      delete toolcalls[value.toolCallId]
                    }
                    break
                  }

                  case "error":
                    throw value.error

                  case "start-step":
                    snapshot = await Snapshot.track()
                    await Session.updatePart({
                      id: Identifier.ascending("part"),
                      messageID: input.assistantMessage.id,
                      sessionID: input.sessionID,
                      snapshot,
                      type: "step-start",
                    })
                    break

                  case "finish-step": {
                    const usage = Session.getUsage({
                      model: input.model,
                      usage: value.usage,
                      metadata: value.providerMetadata,
                    })
                    // Cost accumulates across steps; tokens and finish reason are overwritten.
                    input.assistantMessage.finish = value.finishReason
                    input.assistantMessage.cost += usage.cost
                    input.assistantMessage.tokens = usage.tokens
                    await Session.updatePart({
                      id: Identifier.ascending("part"),
                      reason: value.finishReason,
                      snapshot: await Snapshot.track(),
                      messageID: input.assistantMessage.id,
                      sessionID: input.assistantMessage.sessionID,
                      type: "step-finish",
                      tokens: usage.tokens,
                      cost: usage.cost,
                    })
                    await Session.updateMessage(input.assistantMessage)
                    if (snapshot) {
                      // Record a patch part only when the step's snapshot reports files.
                      const patch = await Snapshot.patch(snapshot)
                      if (patch.files.length) {
                        await Session.updatePart({
                          id: Identifier.ascending("part"),
                          messageID: input.assistantMessage.id,
                          sessionID: input.sessionID,
                          type: "patch",
                          hash: patch.hash,
                          files: patch.files,
                        })
                      }
                      snapshot = undefined
                    }
                    // Not awaited in the original either; the stream is not blocked on it.
                    SessionSummary.summarize({
                      sessionID: input.sessionID,
                      messageID: input.assistantMessage.parentID,
                    })
                    break
                  }

                  case "text-start":
                    currentText = {
                      id: Identifier.ascending("part"),
                      messageID: input.assistantMessage.id,
                      sessionID: input.assistantMessage.sessionID,
                      type: "text",
                      text: "",
                      time: {
                        start: Date.now(),
                      },
                      metadata: value.providerMetadata,
                    }
                    break

                  case "text-delta":
                    if (currentText) {
                      currentText.text += value.text
                      if (value.providerMetadata) currentText.metadata = value.providerMetadata
                      // Nothing is persisted while the accumulated text is still empty.
                      if (currentText.text)
                        await Session.updatePart({
                          part: currentText,
                          delta: value.text,
                        })
                    }
                    break

                  case "text-end":
                    if (currentText) {
                      const endedAt = Date.now()
                      currentText.text = currentText.text.trimEnd()
                      // Keep the start recorded at "text-start" (as "reasoning-end" does)
                      // instead of overwriting it with the end timestamp.
                      currentText.time = {
                        start: currentText.time?.start ?? endedAt,
                        end: endedAt,
                      }
                      if (value.providerMetadata) currentText.metadata = value.providerMetadata
                      await Session.updatePart(currentText)
                    }
                    currentText = undefined
                    break

                  case "finish":
                    break

                  default:
                    log.info("unhandled", {
                      ...value,
                    })
                    continue
                }
              }
            } catch (e) {
              log.error("process", {
                error: e,
              })
              const error = MessageV2.fromError(e, { providerID: input.providerID })
              // retryable() yields the status message for a retry, or undefined to give up.
              const retry = SessionRetry.retryable(error)
              if (retry !== undefined) {
                attempt++
                const delay = SessionRetry.delay(attempt, error.name === "APIError" ? error : undefined)
                SessionStatus.set(input.sessionID, {
                  type: "retry",
                  attempt,
                  message: retry,
                  next: Date.now() + delay,
                })
                // A rejected sleep is deliberately swallowed (presumably it rejects on
                // abort - confirm in SessionRetry); the next loop iteration then decides.
                await SessionRetry.sleep(delay, input.abort).catch(() => {})
                continue
              }
              input.assistantMessage.error = error
              Bus.publish(Session.Event.Error, {
                sessionID: input.assistantMessage.sessionID,
                error: input.assistantMessage.error,
              })
            }

            // Any tool part that never reached a terminal state is closed out as an error.
            const p = await MessageV2.parts(input.assistantMessage.id)
            for (const part of p) {
              if (part.type === "tool" && part.state.status !== "completed" && part.state.status !== "error") {
                const now = Date.now()
                // A running call keeps its real start time; a pending one has none yet.
                const startedAt = part.state.status === "running" ? part.state.time.start : now
                await Session.updatePart({
                  ...part,
                  state: {
                    ...part.state,
                    status: "error",
                    error: "Tool execution aborted",
                    time: {
                      start: startedAt,
                      end: now,
                    },
                  },
                })
              }
            }
            input.assistantMessage.time.completed = Date.now()
            await Session.updateMessage(input.assistantMessage)
            if (blocked) return "stop"
            if (input.assistantMessage.error) return "stop"
            return "continue"
          }
        },
      }
      return result
    }
  }