// index.ts (19 KB)
  1. import path from "path"
  2. import { App } from "../app/app"
  3. import { Identifier } from "../id/id"
  4. import { Storage } from "../storage/storage"
  5. import { Log } from "../util/log"
  6. import {
  7. convertToModelMessages,
  8. generateText,
  9. LoadAPIKeyError,
  10. stepCountIs,
  11. streamText,
  12. tool,
  13. type Tool as AITool,
  14. type LanguageModelUsage,
  15. } from "ai"
  16. import { z, ZodSchema } from "zod"
  17. import { Decimal } from "decimal.js"
  18. import PROMPT_ANTHROPIC from "./prompt/anthropic.txt"
  19. import PROMPT_ANTHROPIC_SPOOF from "./prompt/anthropic_spoof.txt"
  20. import PROMPT_TITLE from "./prompt/title.txt"
  21. import PROMPT_SUMMARIZE from "./prompt/summarize.txt"
  22. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  23. import { Share } from "../share/share"
  24. import { Message } from "./message"
  25. import { Bus } from "../bus"
  26. import { Provider } from "../provider/provider"
  27. import { SessionContext } from "./context"
  28. import { ListTool } from "../tool/ls"
  29. import { MCP } from "../mcp"
  30. import { NamedError } from "../util/error"
  31. export namespace Session {
  /** Logger for this namespace, created with the service name "session". */
  const log = Log.create({ service: "session" })

  /**
   * Zod schema describing a session record.
   *
   * - `id`: identifier validated by `Identifier.schema("session")`.
   * - `share`: optional sharing details (`secret` and `url`), absent until the
   *   session has been shared.
   * - `title`: session title.
   * - `time`: numeric `created` / `updated` timestamps (the code in this
   *   namespace fills them with `Date.now()`).
   *
   * Registered for OpenAPI under the ref "session.info".
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      share: z.object({ secret: z.string(), url: z.string() }).optional(),
      title: z.string(),
      time: z.object({ created: z.number(), updated: z.number() }),
    })
    .openapi({ ref: "session.info" })

  /** Static type of a parsed {@link Info} value. */
  export type Info = z.output<typeof Info>
  /**
   * Bus events published by this namespace.
   *
   * - `Updated` ("session.updated"): carries the session `info` after it was
   *   created or modified.
   * - `Error` ("session.error"): carries an `error` whose shape is the same as
   *   the `error` field of a message's metadata.
   */
  export const Event = {
    Updated: Bus.event("session.updated", z.object({ info: Info })),
    Error: Bus.event(
      "session.error",
      z.object({ error: Message.Info.shape.metadata.shape.error }),
    ),
  }
  /**
   * State registered with `App.state` under the key "session": a map of
   * session records by id and a map of message lists by session id, both
   * starting out empty.
   */
  const state = App.state("session", () => ({
    sessions: new Map<string, Info>(),
    messages: new Map<string, Message.Info[]>(),
  }))
  /**
   * Creates a new session, caches it, persists it under "session/info/<id>",
   * and publishes `Event.Updated`.
   *
   * Sharing is started in the background and is not awaited. A failure while
   * sharing (or while storing the share details) is logged instead of being
   * left as an unhandled promise rejection.
   *
   * @returns The newly created session record.
   */
  export async function create() {
    const result: Info = {
      id: Identifier.descending("session"),
      title: "New Session - " + new Date().toISOString(),
      time: {
        created: Date.now(),
        updated: Date.now(),
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)

    // Fire-and-forget: returning the update promise lets the trailing catch
    // observe both a failed share and a failed update.
    void share(result.id)
      .then((shared) =>
        update(result.id, (draft) => {
          draft.share = shared
        }),
      )
      .catch((error) => {
        log.error("failed to share session", {
          sessionID: result.id,
          error,
        })
      })

    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Looks up a session by id, preferring the in-memory cache and falling back
   * to reading "session/info/<id>" from storage. Whatever storage returns is
   * put into the cache before being returned.
   *
   * @param id Session identifier.
   * @returns The session record.
   */
  export async function get(id: string) {
    const cached = state().sessions.get(id)
    if (cached) return cached

    const loaded = await Storage.readJSON<Info>("session/info/" + id)
    state().sessions.set(id, loaded)
    return loaded as Info
  }
  /**
   * Returns the share details of a session, creating them if the session has
   * none yet. On creation the details are stored on the session and every
   * existing message of the session is passed to `Share.sync`, one at a time.
   *
   * @param id Session identifier.
   * @returns The existing or newly created share details.
   */
  export async function share(id: string) {
    const existing = (await get(id)).share
    if (existing) return existing

    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = created
    })

    const history = await messages(id)
    for (const message of history) {
      const key = "session/message/" + id + "/" + message.id
      await Share.sync(key, message)
    }
    return created
  }
  /**
   * Applies an in-place edit to a session, bumps `time.updated`, re-caches and
   * persists the record, and publishes `Event.Updated`.
   *
   * @param id Session identifier.
   * @param editor Callback that mutates the session record.
   * @returns The updated session, or `undefined` when no session was found.
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const cache = state().sessions
    const target = await get(id)
    if (!target) return

    editor(target)
    target.time.updated = Date.now()
    cache.set(id, target)
    await Storage.writeJSON("session/info/" + id, target)
    Bus.publish(Event.Updated, { info: target })
    return target
  }
  /**
   * Loads every stored message of a session from "session/message/<sessionID>"
   * and returns them sorted by ascending `id`. Entries that fail to read, or
   * read as a falsy value, are skipped.
   *
   * @param sessionID Session whose messages are loaded.
   * @returns The session's messages ordered by id.
   */
  export async function messages(sessionID: string) {
    const collected: Message.Info[] = []
    for await (const entry of Storage.list("session/message/" + sessionID)) {
      // A failed read resolves to undefined and is dropped below.
      const parsed = await Storage.readJSON<Message.Info>(entry).catch(() => {})
      if (parsed) collected.push(parsed)
    }
    return collected.sort((a, b) => (a.id > b.id ? 1 : -1))
  }
  /**
   * Async generator over all sessions found under "session/info". The session
   * id is taken from each entry's base name with the ".json" suffix removed.
   */
  export async function* list() {
    const entries = Storage.list("session/info")
    for await (const entry of entries) {
      yield get(path.basename(entry, ".json"))
    }
  }
  /**
   * Aborts the in-flight operation registered for a session, if there is one,
   * and removes its controller from the pending map.
   *
   * @param sessionID Session whose operation should be aborted.
   * @returns `true` when a controller was found and aborted, `false` otherwise.
   */
  export function abort(sessionID: string) {
    const active = pending.get(sessionID)
    if (active) {
      active.abort()
      pending.delete(sessionID)
      return true
    }
    return false
  }
  /**
   * Persists a message under "session/message/<sessionID>/<messageID>" and
   * then publishes `Message.Event.Updated` with it.
   *
   * @param msg Message to store and announce.
   */
  async function updateMessage(msg: Message.Info) {
    const key = "session/message/" + msg.metadata.sessionID + "/" + msg.id
    await Storage.writeJSON(key, msg)
    Bus.publish(Message.Event.Updated, { info: msg })
  }
  /**
   * Sends a user message to the model and streams the assistant reply into a
   * new assistant message, persisting and publishing updates as it goes.
   *
   * Flow:
   * 1. If the previous assistant message (other than a summary) used more than
   *    90% of the model's usable context, the session is summarized first and
   *    the chat is restarted.
   * 2. The session lock is taken for the rest of the call.
   * 3. Only system messages and messages from the latest summary onward are
   *    sent to the model.
   * 4. On the first message of a session, system prompts are created and a
   *    title is generated in the background.
   * 5. The user message and an empty assistant message are stored, tools are
   *    wired up, and the model response is streamed.
   * 6. Tool invocations that never produced a result are marked as aborted.
   *
   * @param input.sessionID Session to chat in.
   * @param input.providerID Provider used to resolve the model and its tools.
   * @param input.modelID Model to use.
   * @param input.parts Parts of the user message.
   * @returns The assistant message produced by this call.
   * @throws BusyError if another operation currently holds the session lock.
   */
  export async function chat(input: {
    sessionID: string
    providerID: string
    modelID: string
    parts: Message.Part[]
  }) {
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")
    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)
    const previous = msgs.at(-1)
    // A summary message records the token usage of the summarize call itself
    // (roughly the whole old conversation), not the size of the context that
    // will be sent next, so it must not trigger another summarization.
    if (
      previous?.metadata.assistant &&
      previous.metadata.assistant.summary !== true
    ) {
      const tokens =
        previous.metadata.assistant.tokens.input +
        previous.metadata.assistant.tokens.output
      if (
        tokens >
        (model.info.limit.context - (model.info.limit.output ?? 0)) * 0.9
      ) {
        await summarize({
          sessionID: input.sessionID,
          providerID: input.providerID,
          modelID: input.modelID,
        })
        return chat(input)
      }
    }

    // Held until this function exits; released through Symbol.dispose.
    using abort = lock(input.sessionID)

    const lastSummary = msgs.findLast(
      (msg) => msg.metadata.assistant?.summary === true,
    )
    if (lastSummary)
      msgs = msgs.filter(
        (msg) => msg.role === "system" || msg.id >= lastSummary.id,
      )

    // First message of the session: create the system prompts and a title.
    if (msgs.length === 0) {
      const app = App.info()
      if (input.providerID === "anthropic") {
        const claude: Message.Info = {
          id: Identifier.ascending("message"),
          role: "system",
          parts: [
            {
              type: "text",
              text: PROMPT_ANTHROPIC_SPOOF.trim(),
            },
          ],
          metadata: {
            sessionID: input.sessionID,
            time: {
              created: Date.now(),
            },
            tool: {},
          },
        }
        await updateMessage(claude)
        msgs.push(claude)
      }
      const system: Message.Info = {
        id: Identifier.ascending("message"),
        role: "system",
        parts: [
          {
            type: "text",
            text: PROMPT_ANTHROPIC,
          },
          {
            type: "text",
            text: [
              `Here is some useful information about the environment you are running in:`,
              `<env>`,
              `Working directory: ${app.path.cwd}`,
              `Is directory a git repo: ${app.git ? "yes" : "no"}`,
              `Platform: ${process.platform}`,
              `Today's date: ${new Date().toISOString()}`,
              `</env>`,
              `<project>`,
              `${app.git ? await ListTool.execute({ path: app.path.cwd, ignore: [] }, { sessionID: input.sessionID, abort: abort.signal }).then((x) => x.output) : ""}`,
              `</project>`,
            ].join("\n"),
          },
        ],
        metadata: {
          sessionID: input.sessionID,
          time: {
            created: Date.now(),
          },
          tool: {},
        },
      }
      const context = await SessionContext.find()
      if (context) {
        system.parts.push({
          type: "text",
          text: context,
        })
      }
      msgs.push(system)

      // Title generation runs in the background; failures are ignored on
      // purpose so they cannot break the chat.
      generateText({
        maxOutputTokens: 80,
        messages: convertToModelMessages([
          {
            role: "system",
            parts: [
              {
                type: "text",
                text: PROMPT_ANTHROPIC_SPOOF.trim(),
              },
            ],
          },
          {
            role: "system",
            parts: [
              {
                type: "text",
                text: PROMPT_TITLE,
              },
            ],
          },
          {
            role: "user",
            parts: input.parts,
          },
        ]),
        model: model.language,
      })
        .then((result) => {
          return Session.update(input.sessionID, (draft) => {
            draft.title = result.text
          })
        })
        .catch(() => {})
      await updateMessage(system)
    }

    const msg: Message.Info = {
      role: "user",
      id: Identifier.ascending("message"),
      parts: input.parts,
      metadata: {
        time: {
          created: Date.now(),
        },
        sessionID: input.sessionID,
        tool: {},
      },
    }
    await updateMessage(msg)
    msgs.push(msg)

    // Assistant message that the stream callbacks below fill in incrementally.
    const next: Message.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      metadata: {
        assistant: {
          cost: 0,
          tokens: {
            input: 0,
            output: 0,
            reasoning: 0,
          },
          modelID: input.modelID,
          providerID: input.providerID,
        },
        time: {
          created: Date.now(),
        },
        sessionID: input.sessionID,
        tool: {},
      },
    }
    await updateMessage(next)

    // Provider tools: wrapped so timing and metadata (or the error) of every
    // call is recorded on the assistant message under its tool call id.
    const tools: Record<string, AITool> = {}
    for (const item of await Provider.tools(input.providerID)) {
      tools[item.id.replaceAll(".", "_")] = tool({
        id: item.id as any,
        description: item.description,
        parameters: item.parameters as ZodSchema,
        async execute(args, opts) {
          const start = Date.now()
          try {
            const result = await item.execute(args, {
              sessionID: input.sessionID,
              abort: abort.signal,
            })
            next.metadata!.tool![opts.toolCallId] = {
              ...result.metadata,
              time: {
                start,
                end: Date.now(),
              },
            }
            return result.output
          } catch (e: unknown) {
            // String() also copes with a thrown null/undefined.
            const message = String(e)
            next.metadata!.tool![opts.toolCallId] = {
              error: true,
              message,
              time: {
                start,
                end: Date.now(),
              },
            }
            // The error text is handed back to the model as the tool output.
            return message
          }
        },
      })
    }

    // MCP tools: same bookkeeping; only the text parts of the result content
    // are returned, joined by blank lines.
    for (const [key, item] of Object.entries(await MCP.tools())) {
      const execute = item.execute
      if (!execute) continue
      item.execute = async (args, opts) => {
        const start = Date.now()
        try {
          const result = await execute(args, opts)
          next.metadata!.tool![opts.toolCallId] = {
            ...result.metadata,
            time: {
              start,
              end: Date.now(),
            },
          }
          return result.content
            .filter((x: any) => x.type === "text")
            .map((x: any) => x.text)
            .join("\n\n")
        } catch (e: unknown) {
          const message = String(e)
          next.metadata!.tool![opts.toolCallId] = {
            error: true,
            message,
            time: {
              start,
              end: Date.now(),
            },
          }
          return message
        }
      }
      tools[key] = item
    }

    // Text part currently being streamed; reset after every step.
    let text: Message.TextPart | undefined
    const result = streamText({
      onStepFinish: async (step) => {
        log.info("step finish", {
          finishReason: step.finishReason,
        })
        const assistant = next.metadata!.assistant!
        const usage = getUsage(step.usage, model.info)
        assistant.cost += usage.cost
        assistant.tokens = usage.tokens
        await updateMessage(next)
        if (text) {
          Bus.publish(Message.Event.PartUpdated, {
            part: text,
          })
        }
        text = undefined
      },
      async onChunk(input) {
        const value = input.chunk
        l.info("part", {
          type: value.type,
        })
        switch (value.type) {
          case "text":
            if (!text) {
              // First text chunk of the step becomes the part itself.
              text = value
              next.parts.push(value)
            } else {
              text.text += value.text
            }
            break
          case "tool-call":
            next.parts.push({
              type: "tool-invocation",
              toolInvocation: {
                state: "call",
                ...value,
                // hack until zod v4
                args: value.args as any,
              },
            })
            Bus.publish(Message.Event.PartUpdated, {
              part: next.parts[next.parts.length - 1],
            })
            break
          case "tool-call-streaming-start":
            next.parts.push({
              type: "tool-invocation",
              toolInvocation: {
                state: "partial-call",
                toolName: value.toolName,
                toolCallId: value.toolCallId,
                args: {},
              },
            })
            Bus.publish(Message.Event.PartUpdated, {
              part: next.parts[next.parts.length - 1],
            })
            break
          case "tool-call-delta":
            break
          case "tool-result": {
            const match = next.parts.find(
              (p) =>
                p.type === "tool-invocation" &&
                p.toolInvocation.toolCallId === value.toolCallId,
            )
            if (match && match.type === "tool-invocation") {
              match.toolInvocation = {
                args: value.args,
                toolCallId: value.toolCallId,
                toolName: value.toolName,
                state: "result",
                result: value.result as string,
              }
              Bus.publish(Message.Event.PartUpdated, {
                part: match,
              })
            }
            break
          }
          default:
            l.info("unhandled", {
              type: value.type,
            })
        }
        await updateMessage(next)
      },
      async onFinish(input) {
        // Replace the per-step running total with the cost of the total usage.
        const assistant = next.metadata!.assistant!
        const usage = getUsage(input.totalUsage, model.info)
        assistant.cost = usage.cost
        await updateMessage(next)
      },
      onError(err) {
        log.error("error", err)
        switch (true) {
          case LoadAPIKeyError.isInstance(err.error):
            next.metadata.error = new Provider.AuthError(
              {
                providerID: input.providerID,
                message: err.error.message,
              },
              { cause: err.error },
            ).toObject()
            break
          case err.error instanceof Error:
            next.metadata.error = new NamedError.Unknown(
              { message: err.error.toString() },
              { cause: err.error },
            ).toObject()
            break
          default:
            // Converted with toObject() like the branches above, so the stored
            // and published error always has the same shape.
            next.metadata.error = new NamedError.Unknown(
              { message: JSON.stringify(err.error) },
              { cause: err.error },
            ).toObject()
        }
        Bus.publish(Event.Error, {
          error: next.metadata.error,
        })
      },
      async prepareStep(step) {
        next.parts.push({
          type: "step-start",
        })
        await updateMessage(next)
        return step
      },
      toolCallStreaming: false,
      abortSignal: abort.signal,
      stopWhen: stepCountIs(1000),
      messages: convertToModelMessages(msgs),
      temperature: model.info.id === "codex-mini-latest" ? undefined : 0,
      tools: {
        ...(await MCP.tools()),
        ...tools,
      },
      model: model.language,
    })
    await result.consumeStream({
      onError: (err) => {
        log.error("stream error", {
          err,
        })
      },
    })
    next.metadata!.time.completed = Date.now()

    // Any tool invocation still without a result at this point is closed out
    // with a fixed "request was aborted" result.
    for (const part of next.parts) {
      if (
        part.type === "tool-invocation" &&
        part.toolInvocation.state !== "result"
      ) {
        part.toolInvocation = {
          ...part.toolInvocation,
          state: "result",
          result: "request was aborted",
        }
      }
    }
    await updateMessage(next)
    return next
  }
  /**
   * Generates a summary of the conversation and stores it as an assistant
   * message flagged with `summary: true`.
   *
   * Only non-system messages from the latest existing summary onward (or all
   * non-system messages when there is none) are sent to the model. The
   * session lock is held for the duration of the call.
   *
   * @param input.sessionID Session to summarize.
   * @param input.providerID Provider used to resolve the model.
   * @param input.modelID Model that writes the summary.
   * @throws BusyError if another operation currently holds the session lock.
   */
  export async function summarize(input: {
    sessionID: string
    providerID: string
    modelID: string
  }) {
    using abort = lock(input.sessionID)

    const history = await messages(input.sessionID)
    const summaryID = history.findLast(
      (entry) => entry.metadata.assistant?.summary === true,
    )?.id
    const conversation = history.filter((entry) => {
      if (entry.role === "system") return false
      return !summaryID || entry.id >= summaryID
    })

    const model = await Provider.getModel(input.providerID, input.modelID)

    // Stored empty first, then filled in once the model has answered.
    const next: Message.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      metadata: {
        tool: {},
        sessionID: input.sessionID,
        assistant: {
          summary: true,
          cost: 0,
          modelID: input.modelID,
          providerID: input.providerID,
          tokens: { input: 0, output: 0, reasoning: 0 },
        },
        time: { created: Date.now() },
      },
    }
    await updateMessage(next)

    const generated = await generateText({
      abortSignal: abort.signal,
      model: model.language,
      messages: convertToModelMessages([
        {
          role: "system",
          parts: [{ type: "text", text: PROMPT_SUMMARIZE }],
        },
        ...conversation,
        {
          role: "user",
          parts: [
            {
              type: "text",
              text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
            },
          ],
        },
      ]),
    })

    next.parts.push({ type: "text", text: generated.text })

    const { cost, tokens } = getUsage(generated.usage, model.info)
    const assistant = next.metadata!.assistant!
    assistant.cost = cost
    assistant.tokens = tokens
    await updateMessage(next)
  }
  /** Abort controllers of in-flight operations, keyed by session id. */
  const pending = new Map<string, AbortController>()

  /**
   * Takes the per-session lock and returns a disposable handle exposing the
   * abort signal of the operation. Meant to be used with a `using`
   * declaration so the lock is released when the scope exits.
   *
   * Disposing only removes the map entry when it still belongs to this lock.
   * `abort()` deletes the entry while the holder is still running, so a later
   * holder may already have registered its own controller under the same
   * session id; that newer entry must not be removed by the older holder.
   *
   * @param sessionID Session to lock.
   * @returns An object with the abort `signal` and a `Symbol.dispose` method.
   * @throws BusyError if the session already has a pending operation.
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    if (pending.has(sessionID)) throw new BusyError(sessionID)
    const controller = new AbortController()
    pending.set(sessionID, controller)
    return {
      signal: controller.signal,
      [Symbol.dispose]() {
        log.info("unlocking", { sessionID })
        if (pending.get(sessionID) === controller) pending.delete(sessionID)
      },
    }
  }
  /**
   * Normalizes model usage into token counts and computes the cost of the
   * call. Missing counts are treated as 0. The cost is the input and output
   * token counts multiplied by the model's respective `cost` rates, each
   * divided by 1,000,000; reasoning tokens are reported but do not enter the
   * cost.
   *
   * @param usage Usage reported by the AI SDK.
   * @param model Model info providing `cost.input` and `cost.output`.
   * @returns `{ cost, tokens }` where `tokens` has `input`, `output` and
   *   `reasoning` counts.
   */
  function getUsage(usage: LanguageModelUsage, model: Provider.Model) {
    const tokens = {
      input: usage.inputTokens ?? 0,
      output: usage.outputTokens ?? 0,
      reasoning: usage.reasoningTokens ?? 0,
    }
    const inputCost = new Decimal(tokens.input)
      .mul(model.cost.input)
      .div(1_000_000)
    const outputCost = new Decimal(tokens.output)
      .mul(model.cost.output)
      .div(1_000_000)
    const cost = new Decimal(0).add(inputCost).add(outputCost).toNumber()
    return { cost, tokens }
  }
  /**
   * Error thrown when a session already has an operation in progress.
   * The affected session id is available as `sessionID`.
   */
  export class BusyError extends Error {
    public readonly sessionID: string

    constructor(sessionID: string) {
      super(`Session ${sessionID} is busy`)
      this.sessionID = sessionID
    }
  }
  /**
   * Runs the initialization prompt in the given session and then calls
   * `App.initialize()`. The literal "${path}" placeholder in the prompt text
   * is replaced with the app's root path.
   *
   * @param input.sessionID Session to run the prompt in.
   * @param input.modelID Model to use.
   * @param input.providerID Provider of the model.
   */
  export async function initialize(input: {
    sessionID: string
    modelID: string
    providerID: string
  }) {
    const { sessionID, providerID, modelID } = input
    const app = App.info()
    const prompt = PROMPT_INITIALIZE.replace("${path}", app.path.root)
    await Session.chat({
      sessionID,
      providerID,
      modelID,
      parts: [{ type: "text", text: prompt }],
    })
    await App.initialize()
  }
  688. }