// session.ts
  1. import path from "path"
  2. import { App } from "../app/app"
  3. import { Identifier } from "../id/id"
  4. import { Storage } from "../storage/storage"
  5. import { Log } from "../util/log"
  6. import {
  7. convertToModelMessages,
  8. generateText,
  9. stepCountIs,
  10. streamText,
  11. tool,
  12. type Tool as AITool,
  13. type LanguageModelUsage,
  14. } from "ai"
  15. import { z, ZodSchema } from "zod"
  16. import { Decimal } from "decimal.js"
  17. import PROMPT_ANTHROPIC from "./prompt/anthropic.txt"
  18. import PROMPT_TITLE from "./prompt/title.txt"
  19. import PROMPT_SUMMARIZE from "./prompt/summarize.txt"
  20. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  21. import { Share } from "../share/share"
  22. import { Message } from "./message"
  23. import { Bus } from "../bus"
  24. import { Provider } from "../provider/provider"
  25. import { SessionContext } from "./context"
  26. import { ListTool } from "../tool/ls"
  27. import { MCP } from "../mcp"
  28. export namespace Session {
  /** Logger for this module, created with the service name "session". */
  const log = Log.create({ service: "session" })

  /** Share credentials stored on a session once `share()` has run for it. */
  const ShareInfo = z.object({
    secret: z.string(),
    url: z.string(),
  })

  /** Numeric timestamps; this file fills both from `Date.now()`. */
  const TimeInfo = z.object({
    created: z.number(),
    updated: z.number(),
  })

  /**
   * Schema for the session record persisted under "session/info/<id>".
   * Registered with the OpenAPI ref "session.info".
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      share: ShareInfo.optional(),
      title: z.string(),
      time: TimeInfo,
    })
    .openapi({ ref: "session.info" })

  /** Parsed (output) shape of the `Info` schema. */
  export type Info = z.output<typeof Info>
  /** Payload carried by the "session.updated" bus event: the full session record. */
  const UpdatedPayload = z.object({ info: Info })

  /** Bus events published by this module. */
  export const Event = {
    Updated: Bus.event("session.updated", UpdatedPayload),
  }
  /**
   * Module state registered with `App.state` under the key "session":
   * a map of session records and a map of message lists, both keyed by string.
   */
  const state = App.state("session", () => ({
    sessions: new Map<string, Info>(),
    messages: new Map<string, Message.Info[]>(),
  }))
  /**
   * Creates a new session, caches it, persists it under "session/info/<id>",
   * kicks off sharing in the background, and publishes `Event.Updated`.
   *
   * @returns The freshly created session record. Its `share` field is filled in
   *   later (by `share()` via `update()`) if background sharing succeeds.
   */
  export async function create() {
    // Sample the clock once so `created` and `updated` are guaranteed equal.
    const now = Date.now()
    const result: Info = {
      id: Identifier.descending("session"),
      title: "New Session - " + new Date().toISOString(),
      time: {
        created: now,
        updated: now,
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)

    // Fire-and-forget: share() stores the share info on the session itself
    // through update(), so no follow-up write is needed here. A failure must not
    // surface as an unhandled rejection, so it is logged instead.
    void share(result.id).catch((e: unknown) => {
      log.error("share failed", { sessionID: result.id, error: String(e) })
    })

    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Returns the session record for `id`.
   * Serves it from the in-memory map when present; otherwise reads
   * "session/info/<id>" from storage and stores the result in the map.
   */
  export async function get(id: string) {
    const cached = state().sessions.get(id)
    if (cached) return cached

    const loaded = await Storage.readJSON<Info>("session/info/" + id)
    state().sessions.set(id, loaded)
    return loaded as Info
  }
  /**
   * Returns the share info for a session, creating it if the session has none.
   * On creation, the share info is written onto the session via `update()` and
   * every existing message of the session is pushed through `Share.sync`.
   */
  export async function share(id: string) {
    const existing = (await get(id)).share
    if (existing) return existing

    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = created
    })

    const keyPrefix = "session/message/" + id + "/"
    const history = await messages(id)
    for (const message of history) {
      await Share.sync(keyPrefix + message.id, message)
    }
    return created
  }
  /**
   * Applies `editor` to the session in place, bumps `time.updated`, then
   * re-caches, persists, and publishes `Event.Updated`.
   *
   * @returns The mutated session, or `undefined` when `get(id)` yields nothing.
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const { sessions } = state()
    const current = await get(id)
    if (current) {
      editor(current)
      current.time.updated = Date.now()
      sessions.set(id, current)
      await Storage.writeJSON("session/info/" + id, current)
      Bus.publish(Event.Updated, { info: current })
      return current
    }
    return undefined
  }
  /**
   * Loads every message stored under "session/message/<sessionID>" and returns
   * them sorted ascending by `id`. Entries whose read rejects or yields a falsy
   * value are skipped.
   */
  export async function messages(sessionID: string) {
    const collected: Message.Info[] = []
    for await (const entry of Storage.list("session/message/" + sessionID)) {
      // Best-effort read: a rejected read resolves to undefined and is skipped.
      const parsed = await Storage.readJSON<Message.Info>(entry).catch(() => {})
      if (parsed) collected.push(parsed)
    }
    collected.sort((left, right) => (left.id > right.id ? 1 : -1))
    return collected
  }
  /**
   * Async generator over all sessions found under "session/info".
   * The session id is the entry's base name with the ".json" suffix removed.
   */
  export async function* list() {
    const entries = Storage.list("session/info")
    for await (const entry of entries) {
      yield get(path.basename(entry, ".json"))
    }
  }
  /**
   * Aborts the in-flight operation registered for `sessionID`, if any, and
   * removes its entry from `pending`.
   *
   * @returns `true` when a controller was found and aborted, `false` otherwise.
   */
  export function abort(sessionID: string) {
    const inflight = pending.get(sessionID)
    if (inflight) {
      inflight.abort()
      pending.delete(sessionID)
      return true
    }
    return false
  }
  /**
   * Persists a message under "session/message/<sessionID>/<messageID>" and
   * publishes `Message.Event.Updated` with it.
   */
  async function updateMessage(msg: Message.Info) {
    const key = "session/message/" + msg.metadata.sessionID + "/" + msg.id
    await Storage.writeJSON(key, msg)
    Bus.publish(Message.Event.Updated, { info: msg })
  }
  /**
   * Appends a user message to the session, streams the model's reply, and
   * persists the assistant message as it grows.
   *
   * Flow:
   *  1. If the last assistant message's input+output tokens exceed 90% of
   *     (context window - max output tokens), `summarize()` is run and the chat
   *     is restarted. A message that is itself a summary is exempt: its usage
   *     describes the conversation that was just summarized, not the current one.
   *  2. The session lock is taken (throws `BusyError` if already held).
   *  3. History is cut at the most recent summary; system messages are kept.
   *  4. For an empty history a system message is created and a title is
   *     generated in the background.
   *  5. Provider and MCP tools are wrapped so each call records timing (and
   *     failure) metadata on the assistant message under its tool call id.
   *  6. The UI message stream is folded into `next.parts`; the message is
   *     persisted after every chunk.
   *
   * @param input.sessionID  Session to chat in.
   * @param input.providerID Provider used for model and tool lookup.
   * @param input.modelID    Model used for this turn.
   * @param input.parts      Parts of the new user message.
   * @returns The completed assistant message.
   * @throws BusyError when the session is already locked.
   */
  export async function chat(input: {
    sessionID: string
    providerID: string
    modelID: string
    parts: Message.Part[]
  }): Promise<Message.Info> {
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")
    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)

    const lastAssistant = msgs.at(-1)?.metadata.assistant
    // Skip the size check when the last message is already a summary; otherwise
    // the summarization call's own (large) input usage could trigger a second,
    // pointless summarization.
    if (lastAssistant && lastAssistant.summary !== true) {
      const tokens = lastAssistant.tokens.input + lastAssistant.tokens.output
      const budget =
        (model.info.contextWindow - (model.info.maxOutputTokens ?? 0)) * 0.9
      if (tokens > budget) {
        await summarize({
          sessionID: input.sessionID,
          providerID: input.providerID,
          modelID: input.modelID,
        })
        return chat(input)
      }
    }

    // Released automatically (Symbol.dispose) when this function exits.
    using abort = lock(input.sessionID)

    const lastSummary = msgs.findLast(
      (msg) => msg.metadata.assistant?.summary === true,
    )
    if (lastSummary)
      msgs = msgs.filter(
        (msg) => msg.role === "system" || msg.id >= lastSummary.id,
      )

    if (msgs.length === 0) {
      const app = App.info()
      const system: Message.Info = {
        id: Identifier.ascending("message"),
        role: "system",
        parts: [
          {
            type: "text",
            text: PROMPT_ANTHROPIC,
          },
          {
            type: "text",
            text: `Here is some useful information about the environment you are running in:
<env>
Working directory: ${app.path.cwd}
Is directory a git repo: ${app.git ? "yes" : "no"}
Platform: ${process.platform}
Today's date: ${new Date().toISOString()}
</env>
<project>
${app.git ? await ListTool.execute({ path: app.path.cwd }, { sessionID: input.sessionID }).then((x) => x.output) : ""}
</project>
`,
          },
        ],
        metadata: {
          sessionID: input.sessionID,
          time: {
            created: Date.now(),
          },
          tool: {},
        },
      }
      const context = await SessionContext.find()
      if (context) {
        system.parts.push({
          type: "text",
          text: context,
        })
      }
      msgs.push(system)

      // Title generation runs in the background. A failure only costs the
      // title, so it is logged rather than left as an unhandled rejection.
      void generateText({
        maxOutputTokens: 80,
        messages: convertToModelMessages([
          {
            role: "system",
            parts: [
              {
                type: "text",
                text: PROMPT_TITLE,
              },
            ],
          },
          {
            role: "user",
            parts: input.parts,
          },
        ]),
        model: model.language,
      })
        .then((result) =>
          Session.update(input.sessionID, (draft) => {
            draft.title = result.text
          }),
        )
        .catch((e: unknown) => {
          log.error("title generation failed", {
            sessionID: input.sessionID,
            error: String(e),
          })
        })
      await updateMessage(system)
    }

    const msg: Message.Info = {
      role: "user",
      id: Identifier.ascending("message"),
      parts: input.parts,
      metadata: {
        time: {
          created: Date.now(),
        },
        sessionID: input.sessionID,
        tool: {},
      },
    }
    await updateMessage(msg)
    msgs.push(msg)

    const next: Message.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      metadata: {
        assistant: {
          cost: 0,
          tokens: {
            input: 0,
            output: 0,
            reasoning: 0,
          },
          modelID: input.modelID,
          providerID: input.providerID,
        },
        time: {
          created: Date.now(),
        },
        sessionID: input.sessionID,
        tool: {},
      },
    }
    await updateMessage(next)

    // Timing record for a tool call that started at `start` and ends now.
    const stamp = (start: number) => ({ start, end: Date.now() })

    const tools: Record<string, AITool> = {}
    for (const item of await Provider.tools(input.providerID)) {
      // Tool keys replace "." with "_" in the provider tool id.
      tools[item.id.replaceAll(".", "_")] = tool({
        id: item.id as any,
        description: item.description,
        parameters: item.parameters as ZodSchema,
        async execute(args, opts) {
          const start = Date.now()
          try {
            const result = await item.execute(args, {
              sessionID: input.sessionID,
            })
            next.metadata!.tool![opts.toolCallId] = {
              ...result.metadata,
              time: stamp(start),
            }
            return result.output
          } catch (e: unknown) {
            // Failures are handed back to the model as text instead of thrown.
            const message = String(e)
            next.metadata!.tool![opts.toolCallId] = {
              error: true,
              message,
              time: stamp(start),
            }
            return message
          }
        },
      })
    }

    // Fetch MCP tools once and reuse the result below.
    const mcpTools = await MCP.tools()
    for (const [key, item] of Object.entries(mcpTools)) {
      const execute = item.execute
      if (!execute) continue
      // Wrap a shallow copy rather than mutating the object MCP handed out, so
      // repeated chats can never stack wrappers on the same tool object.
      const wrapped: AITool = { ...item }
      wrapped.execute = async (args, opts) => {
        const start = Date.now()
        try {
          const result = await execute(args, opts)
          next.metadata!.tool![opts.toolCallId] = {
            ...result.metadata,
            time: stamp(start),
          }
          // Only text content is returned, joined by blank lines.
          return result.content
            .filter((x: any) => x.type === "text")
            .map((x: any) => x.text)
            .join("\n\n")
        } catch (e: unknown) {
          const message = String(e)
          next.metadata!.tool![opts.toolCallId] = {
            error: true,
            message,
            time: stamp(start),
          }
          return message
        }
      }
      tools[key] = wrapped
    }

    const result = streamText({
      onStepFinish: async (step) => {
        const assistant = next.metadata!.assistant!
        const usage = getUsage(step.usage, model.info)
        assistant.cost = usage.cost
        assistant.tokens = usage.tokens
        await updateMessage(next)
      },
      toolCallStreaming: false,
      abortSignal: abort.signal,
      maxRetries: 6,
      stopWhen: stepCountIs(1000),
      messages: convertToModelMessages(msgs),
      temperature: 0,
      tools: {
        // MCP tools without an execute function pass through unwrapped;
        // everything in `tools` takes precedence on key collisions.
        ...mcpTools,
        ...tools,
      },
      model: model.language,
    })

    // Text part currently being appended to; reset at each step start.
    let text: Message.TextPart | undefined
    const reader = result.toUIMessageStream().getReader()
    while (true) {
      // A read rejected with AbortError ends the loop quietly; anything else
      // propagates to the caller.
      const chunk = await reader.read().catch((e) => {
        if (e instanceof DOMException && e.name === "AbortError") {
          return
        }
        throw e
      })
      if (!chunk) break
      const { done, value } = chunk
      if (done) break
      l.info("part", {
        type: value.type,
      })
      switch (value.type) {
        case "start":
          break
        case "start-step":
          text = undefined
          next.parts.push({
            type: "step-start",
          })
          break
        case "text":
          if (!text) {
            text = value
            next.parts.push(value)
            break
          }
          text.text += value.text
          break
        case "tool-call":
          next.parts.push({
            type: "tool-invocation",
            toolInvocation: {
              state: "call",
              ...value,
              // hack until zod v4
              args: value.args as any,
            },
          })
          break
        case "tool-call-streaming-start":
          next.parts.push({
            type: "tool-invocation",
            toolInvocation: {
              state: "call",
              toolName: value.toolName,
              toolCallId: value.toolCallId,
              args: {},
            },
          })
          break
        case "tool-call-delta":
          break
        case "tool-result": {
          // Flip the matching invocation part from "call" to "result".
          const match = next.parts.find(
            (p) =>
              p.type === "tool-invocation" &&
              p.toolInvocation.toolCallId === value.toolCallId,
          )
          if (match && match.type === "tool-invocation") {
            match.toolInvocation = {
              args: match.toolInvocation.args,
              toolCallId: match.toolInvocation.toolCallId,
              toolName: match.toolInvocation.toolName,
              state: "result",
              result: value.result as string,
            }
          }
          break
        }
        case "finish":
          break
        case "finish-step":
          break
        case "error":
          log.error("error", value)
          break
        default:
          l.info("unhandled", {
            type: value.type,
          })
      }
      await updateMessage(next)
    }
    next.metadata!.time.completed = Date.now()
    await updateMessage(next)
    return next
  }
  /**
   * Generates a summary of the conversation since the last summary and stores
   * it as an assistant message flagged `summary: true`.
   *
   * The placeholder message is persisted before the model call. If that call
   * fails or is aborted, the `summary` flag is removed and the message is
   * persisted again before the error is rethrown, so that `chat()` never cuts
   * history at an empty summary.
   *
   * @param input.sessionID  Session to summarize.
   * @param input.providerID Provider used for model lookup.
   * @param input.modelID    Model used to produce the summary.
   * @throws BusyError when the session is already locked; rethrows any error
   *   from the model call.
   */
  export async function summarize(input: {
    sessionID: string
    providerID: string
    modelID: string
  }) {
    // Released automatically (Symbol.dispose) when this function exits.
    using abort = lock(input.sessionID)
    const msgs = await messages(input.sessionID)
    const lastSummary = msgs.findLast(
      (msg) => msg.metadata.assistant?.summary === true,
    )?.id
    // Summarize only non-system messages from the previous summary onwards.
    const filtered = msgs.filter(
      (msg) => msg.role !== "system" && (!lastSummary || msg.id >= lastSummary),
    )
    const model = await Provider.getModel(input.providerID, input.modelID)
    const next: Message.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      metadata: {
        tool: {},
        sessionID: input.sessionID,
        assistant: {
          summary: true,
          cost: 0,
          modelID: input.modelID,
          providerID: input.providerID,
          tokens: {
            input: 0,
            output: 0,
            reasoning: 0,
          },
        },
        time: {
          created: Date.now(),
        },
      },
    }
    await updateMessage(next)
    const assistant = next.metadata!.assistant!

    const result = await generateText({
      abortSignal: abort.signal,
      model: model.language,
      messages: convertToModelMessages([
        {
          role: "system",
          parts: [
            {
              type: "text",
              text: PROMPT_SUMMARIZE,
            },
          ],
        },
        ...filtered,
        {
          role: "user",
          parts: [
            {
              type: "text",
              text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
            },
          ],
        },
      ]),
    }).catch(async (e: unknown) => {
      // Un-flag the empty placeholder so it is not treated as a summary later.
      delete assistant.summary
      await updateMessage(next)
      throw e
    })

    next.parts.push({
      type: "text",
      text: result.text,
    })
    const usage = getUsage(result.usage, model.info)
    assistant.cost = usage.cost
    assistant.tokens = usage.tokens
    await updateMessage(next)
  }
  /** Abort controllers of sessions that currently hold the lock, keyed by session id. */
  const pending = new Map<string, AbortController>()

  /**
   * Takes the per-session lock.
   *
   * @returns A disposable exposing the controller's `signal`; disposing it
   *   removes the session's entry from `pending`.
   * @throws BusyError when the session already has an entry in `pending`.
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    if (pending.has(sessionID)) throw new BusyError(sessionID)

    const controller = new AbortController()
    pending.set(sessionID, controller)

    const release = () => {
      log.info("unlocking", { sessionID })
      pending.delete(sessionID)
    }
    return {
      signal: controller.signal,
      [Symbol.dispose]: release,
    }
  }
  /**
   * Normalizes SDK usage into token counts (missing values become 0) and
   * computes the cost as input tokens x `model.cost.input` plus output tokens x
   * `model.cost.output`, using `Decimal` arithmetic.
   * Reasoning tokens are reported but not included in the cost.
   */
  function getUsage(usage: LanguageModelUsage, model: Provider.Model) {
    const input = usage.inputTokens ?? 0
    const output = usage.outputTokens ?? 0
    const reasoning = usage.reasoningTokens ?? 0

    const inputCost = new Decimal(input).mul(model.cost.input)
    const outputCost = new Decimal(output).mul(model.cost.output)
    const cost = new Decimal(0).add(inputCost).add(outputCost).toNumber()

    return {
      cost,
      tokens: { input, output, reasoning },
    }
  }
  /** Thrown by `lock()` when a session already has an operation in flight. */
  export class BusyError extends Error {
    /** Id of the session that was found busy. */
    public readonly sessionID: string

    constructor(sessionID: string) {
      super(`Session ${sessionID} is busy`)
      this.sessionID = sessionID
    }
  }
  /**
   * Runs one chat turn whose user message is the initialize prompt, then calls
   * `App.initialize()`.
   */
  export async function initialize(input: {
    sessionID: string
    modelID: string
    providerID: string
  }) {
    const { sessionID, providerID, modelID } = input
    const parts: Message.Part[] = [{ type: "text", text: PROMPT_INITIALIZE }]
    await Session.chat({ sessionID, providerID, modelID, parts })
    await App.initialize()
  }
  592. }