index.ts 19 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712
  1. import path from "path"
  2. import { App } from "../app/app"
  3. import { Identifier } from "../id/id"
  4. import { Storage } from "../storage/storage"
  5. import { Log } from "../util/log"
  6. import {
  7. convertToModelMessages,
  8. generateText,
  9. LoadAPIKeyError,
  10. stepCountIs,
  11. streamText,
  12. tool,
  13. type Tool as AITool,
  14. type LanguageModelUsage,
  15. type UIMessage,
  16. } from "ai"
  17. import { z, ZodSchema } from "zod"
  18. import { Decimal } from "decimal.js"
  19. import PROMPT_INITIALIZE from "../session/prompt/initialize.txt"
  20. import { Share } from "../share/share"
  21. import { Message } from "./message"
  22. import { Bus } from "../bus"
  23. import { Provider } from "../provider/provider"
  24. import { MCP } from "../mcp"
  25. import { NamedError } from "../util/error"
  26. import type { Tool } from "../tool/tool"
  27. import { SystemPrompt } from "./system"
  28. import { Flag } from "../flag/flag"
  29. export namespace Session {
  /** Logger used by every function in this namespace, tagged with the "session" service. */
  const log = Log.create({ service: "session" })

  // Nested shapes are named separately so the top-level record reads at a glance.
  const shareSchema = z.object({
    secret: z.string(),
    url: z.string(),
  })
  const timeSchema = z.object({
    created: z.number(),
    updated: z.number(),
  })

  /**
   * Zod schema for a session record.
   *
   * - `id` / `parentID`: session identifiers; `parentID` is only set for child sessions.
   * - `share`: present once the session has been shared (secret + url).
   * - `title`: display title.
   * - `time`: numeric creation / last-update stamps (filled from `Date.now()` in this file).
   *
   * Registered for OpenAPI output under the ref "session.info".
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      parentID: Identifier.schema("session").optional(),
      share: shareSchema.optional(),
      title: z.string(),
      time: timeSchema,
    })
    .openapi({ ref: "session.info" })

  /** Parsed (output) type of the `Info` schema. */
  export type Info = z.output<typeof Info>
  // Payload of "session.updated": the full session record.
  const updatedPayload = z.object({ info: Info })
  // Payload of "session.error": reuses the error schema declared on message metadata.
  const errorPayload = z.object({
    error: Message.Info.shape.metadata.shape.error,
  })

  /** Bus events published by this namespace. */
  export const Event = {
    Updated: Bus.event("session.updated", updatedPayload),
    Error: Bus.event("session.error", errorPayload),
  }
  /**
   * State accessor registered with the App under the "session" key.
   *
   * `sessions` caches session records by id. `messages` is allocated here but
   * is not read or written anywhere else in the visible part of this file.
   */
  const state = App.state("session", () => ({
    sessions: new Map<string, Info>(),
    messages: new Map<string, Message.Info[]>(),
  }))
  /**
   * Creates a new session, caches it, persists it and announces it on the bus.
   *
   * @param parentID Optional id of the parent session; when given, the new
   *   session is titled as a child session and is never auto-shared.
   * @returns The freshly created session record.
   */
  export async function create(parentID?: string) {
    const result: Info = {
      id: Identifier.descending("session"),
      parentID,
      title:
        (parentID ? "Child session - " : "New Session - ") +
        new Date().toISOString(),
      time: {
        created: Date.now(),
        updated: Date.now(),
      },
    }
    log.info("created", result)
    state().sessions.set(result.id, result)
    await Storage.writeJSON("session/info/" + result.id, result)

    // Auto-sharing is deliberately not awaited so session creation is not
    // blocked on it. The chain is terminated with a catch so that a failed
    // share is logged instead of surfacing as an unhandled rejection.
    if (!result.parentID && Flag.OPENCODE_AUTO_SHARE)
      share(result.id)
        .then((share) =>
          update(result.id, (draft) => {
            draft.share = share
          }),
        )
        .catch((err) => {
          log.error("auto share failed", { sessionID: result.id, err })
        })

    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /**
   * Looks up a session by id, preferring the in-memory cache and falling back
   * to storage (the loaded record is cached before being returned).
   *
   * @param id Session id.
   * @returns The session record.
   */
  export async function get(id: string) {
    const cached = state().sessions.get(id)
    if (cached) return cached

    const loaded = await Storage.readJSON<Info>(`session/info/${id}`)
    state().sessions.set(id, loaded)
    return loaded as Info
  }
  /**
   * Shares a session. If it is already shared the existing share info is
   * returned; otherwise a share is created, stored on the session, and every
   * existing message of the session is synced to it one after another.
   *
   * @param id Session id.
   * @returns The share info (`secret` and `url`).
   */
  export async function share(id: string) {
    const existing = (await get(id)).share
    if (existing) return existing

    const created = await Share.create(id)
    await update(id, (draft) => {
      draft.share = created
    })

    const history = await messages(id)
    for (const message of history) {
      await Share.sync(`session/message/${id}/${message.id}`, message)
    }
    return created
  }
  /**
   * Applies an in-place edit to a session, bumps its `time.updated` stamp,
   * re-caches and persists it, and publishes `Event.Updated`.
   *
   * @param id Session id.
   * @param editor Callback that mutates the session record.
   * @returns The updated session, or `undefined` when `get` yields nothing.
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const { sessions } = state()
    const current = await get(id)
    if (!current) return

    editor(current)
    current.time.updated = Date.now()
    sessions.set(id, current)

    await Storage.writeJSON(`session/info/${id}`, current)
    Bus.publish(Event.Updated, { info: current })
    return current
  }
  /**
   * Reads every stored message of a session and returns them sorted by
   * ascending id.
   *
   * @param sessionID Session whose messages are listed.
   * @returns The session's messages, ordered by id.
   */
  export async function messages(sessionID: string) {
    const collected: Message.Info[] = []
    for await (const key of Storage.list("session/message/" + sessionID)) {
      collected.push(await Storage.readJSON<Message.Info>(key))
    }
    collected.sort((left, right) => (left.id > right.id ? 1 : -1))
    return collected
  }
  /**
   * Reads a single message from storage.
   *
   * @param sessionID Session the message belongs to.
   * @param messageID Id of the message.
   * @returns The stored message.
   */
  export async function getMessage(sessionID: string, messageID: string) {
    const key = `session/message/${sessionID}/${messageID}`
    return Storage.readJSON<Message.Info>(key)
  }
  /**
   * Iterates over all stored sessions. Each storage entry under
   * "session/info" is mapped to a session id by stripping ".json" from its
   * base name, then resolved through `get`.
   */
  export async function* list() {
    for await (const entry of Storage.list("session/info")) {
      const id = path.basename(entry, ".json")
      const session = await get(id)
      yield session
    }
  }
  /**
   * Aborts the in-flight request of a session, if there is one, and releases
   * its entry in the pending map.
   *
   * @param sessionID Session whose request should be cancelled.
   * @returns `true` when a pending request was aborted, `false` otherwise.
   */
  export function abort(sessionID: string) {
    const inFlight = pending.get(sessionID)
    if (inFlight) {
      inFlight.abort()
      pending.delete(sessionID)
      return true
    }
    return false
  }
  /**
   * Persists a message under its session and publishes
   * `Message.Event.Updated` for it.
   *
   * @param msg Message to store; its `metadata.sessionID` selects the session.
   */
  async function updateMessage(msg: Message.Info) {
    const key = `session/message/${msg.metadata.sessionID}/${msg.id}`
    await Storage.writeJSON(key, msg)
    Bus.publish(Message.Event.Updated, { info: msg })
  }
  /**
   * Sends a user message to the model and streams the assistant reply into a
   * new assistant message, persisting and publishing progress as it arrives.
   *
   * Flow:
   *  1. If the previous assistant message used more than 90% of the usable
   *     context (context limit minus output limit), the session is summarized
   *     first and `chat` is restarted.
   *  2. The session is locked for the duration of the call (`lock` throws
   *     `BusyError` when the session already has a request in flight).
   *  3. History is trimmed to start at the most recent summary message.
   *  4. For the first message of a root session a title is generated in the
   *     background (best-effort).
   *  5. The user message and an empty assistant message are persisted, tools
   *     are wrapped so their timing/metadata is recorded on the assistant
   *     message, and the model response is streamed.
   *  6. Tool invocations that never produced a result are closed out with
   *     "request was aborted".
   *
   * @param input.sessionID Session to chat in.
   * @param input.providerID Provider used to resolve the model and its tools.
   * @param input.modelID Model to use.
   * @param input.parts Parts of the user message.
   * @param input.system Optional system prompt lines replacing the provider
   *   default. The array is copied and never modified.
   * @param input.tools Accepted for interface compatibility; not read by this
   *   function.
   * @returns The completed assistant message.
   */
  export async function chat(input: {
    sessionID: string
    providerID: string
    modelID: string
    parts: Message.Part[]
    system?: string[]
    tools?: Tool.Info[]
  }) {
    const l = log.clone().tag("session", input.sessionID)
    l.info("chatting")
    const model = await Provider.getModel(input.providerID, input.modelID)
    let msgs = await messages(input.sessionID)
    const previous = msgs.at(-1)

    // auto summarize if too long
    if (previous?.metadata.assistant) {
      const tokens =
        previous.metadata.assistant.tokens.input +
        previous.metadata.assistant.tokens.output
      if (
        tokens >
        (model.info.limit.context - (model.info.limit.output ?? 0)) * 0.9
      ) {
        await summarize({
          sessionID: input.sessionID,
          providerID: input.providerID,
          modelID: input.modelID,
        })
        return chat(input)
      }
    }

    // Held until this function exits; disposal removes the pending entry.
    using abort = lock(input.sessionID)

    // Everything before the latest summary is already condensed into it.
    const lastSummary = msgs.findLast(
      (msg) => msg.metadata.assistant?.summary === true,
    )
    if (lastSummary) msgs = msgs.filter((msg) => msg.id >= lastSummary.id)

    const app = App.info()
    const session = await get(input.sessionID)
    if (msgs.length === 0 && !session.parentID) {
      // Title generation runs in the background and must never fail the chat.
      generateText({
        maxOutputTokens: 20,
        messages: convertToModelMessages([
          ...SystemPrompt.title(input.providerID).map(
            (x): UIMessage => ({
              id: Identifier.ascending("message"),
              role: "system",
              parts: [
                {
                  type: "text",
                  text: x,
                },
              ],
            }),
          ),
          {
            role: "user",
            parts: input.parts,
          },
        ]),
        temperature: 0,
        model: model.language,
      })
        .then((result) => {
          return Session.update(input.sessionID, (draft) => {
            draft.title = result.text
          })
        })
        .catch((err) => {
          // Still best-effort, but leave a trace instead of swallowing silently.
          log.error("title generation failed", { err })
        })
    }

    const msg: Message.Info = {
      role: "user",
      id: Identifier.ascending("message"),
      parts: input.parts,
      metadata: {
        time: {
          created: Date.now(),
        },
        sessionID: input.sessionID,
        tool: {},
      },
    }
    await updateMessage(msg)
    msgs.push(msg)

    // Copy before appending so neither the caller's `input.system` array nor
    // the array returned by SystemPrompt.provider is mutated.
    const system = [
      ...(input.system ?? SystemPrompt.provider(input.providerID)),
    ]
    system.push(...(await SystemPrompt.environment(input.sessionID)))
    system.push(...(await SystemPrompt.custom()))

    const next: Message.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      metadata: {
        assistant: {
          system,
          path: {
            cwd: app.path.cwd,
            root: app.path.root,
          },
          cost: 0,
          tokens: {
            input: 0,
            output: 0,
            reasoning: 0,
          },
          modelID: input.modelID,
          providerID: input.providerID,
        },
        time: {
          created: Date.now(),
        },
        sessionID: input.sessionID,
        tool: {},
      },
    }
    await updateMessage(next)

    // Provider tools: wrapped so each call records its metadata and timing on
    // the assistant message. Errors are returned to the model as text rather
    // than thrown. Dots in tool ids are replaced with underscores for the key.
    const tools: Record<string, AITool> = {}
    for (const item of await Provider.tools(input.providerID)) {
      tools[item.id.replaceAll(".", "_")] = tool({
        id: item.id as any,
        description: item.description,
        parameters: item.parameters as ZodSchema,
        async execute(args, opts) {
          const start = Date.now()
          try {
            const result = await item.execute(args, {
              sessionID: input.sessionID,
              abort: abort.signal,
              messageID: next.id,
            })
            next.metadata!.tool![opts.toolCallId] = {
              ...result.metadata,
              time: {
                start,
                end: Date.now(),
              },
            }
            await updateMessage(next)
            return result.output
          } catch (e: any) {
            next.metadata!.tool![opts.toolCallId] = {
              error: true,
              message: e.toString(),
              title: e.toString(),
              time: {
                start,
                end: Date.now(),
              },
            }
            await updateMessage(next)
            return e.toString()
          }
        },
      })
    }

    // MCP tools: same bookkeeping; only the text parts of the result content
    // are returned, joined by blank lines. Tools without an `execute` handler
    // are skipped here.
    for (const [key, item] of Object.entries(await MCP.tools())) {
      const execute = item.execute
      if (!execute) continue
      item.execute = async (args, opts) => {
        const start = Date.now()
        try {
          const result = await execute(args, opts)
          next.metadata!.tool![opts.toolCallId] = {
            ...result.metadata,
            time: {
              start,
              end: Date.now(),
            },
          }
          await updateMessage(next)
          return result.content
            .filter((x: any) => x.type === "text")
            .map((x: any) => x.text)
            .join("\n\n")
        } catch (e: any) {
          next.metadata!.tool![opts.toolCallId] = {
            error: true,
            message: e.toString(),
            title: "mcp",
            time: {
              start,
              end: Date.now(),
            },
          }
          await updateMessage(next)
          return e.toString()
        }
      }
      tools[key] = item
    }

    // Text part currently being streamed; reset at the end of every step.
    let text: Message.TextPart | undefined
    const result = streamText({
      onStepFinish: async (step) => {
        log.info("step finish", {
          finishReason: step.finishReason,
        })
        const assistant = next.metadata!.assistant!
        const usage = getUsage(step.usage, model.info)
        assistant.cost += usage.cost
        assistant.tokens = usage.tokens
        await updateMessage(next)
        if (text) {
          Bus.publish(Message.Event.PartUpdated, {
            part: text,
            messageID: next.id,
            sessionID: next.metadata.sessionID,
          })
        }
        text = undefined
      },
      async onChunk(input) {
        const value = input.chunk
        l.info("part", {
          type: value.type,
        })
        switch (value.type) {
          case "text":
            // The first chunk of a step becomes the part; later chunks append to it.
            if (!text) {
              text = value
              next.parts.push(value)
            } else {
              text.text += value.text
            }
            break

          case "tool-call": {
            const [match] = next.parts.flatMap((p) =>
              p.type === "tool-invocation" &&
              p.toolInvocation.toolCallId === value.toolCallId
                ? [p]
                : [],
            )
            if (!match) break
            match.toolInvocation.args = value.args
            match.toolInvocation.state = "call"
            Bus.publish(Message.Event.PartUpdated, {
              part: match,
              messageID: next.id,
              sessionID: next.metadata.sessionID,
            })
            break
          }

          case "tool-call-streaming-start":
            next.parts.push({
              type: "tool-invocation",
              toolInvocation: {
                state: "partial-call",
                toolName: value.toolName,
                toolCallId: value.toolCallId,
                args: {},
              },
            })
            Bus.publish(Message.Event.PartUpdated, {
              part: next.parts[next.parts.length - 1],
              messageID: next.id,
              sessionID: next.metadata.sessionID,
            })
            break

          case "tool-call-delta":
            break

          // Braces give `match` its own scope instead of leaking it into the
          // whole switch body.
          case "tool-result": {
            const match = next.parts.find(
              (p) =>
                p.type === "tool-invocation" &&
                p.toolInvocation.toolCallId === value.toolCallId,
            )
            if (match && match.type === "tool-invocation") {
              match.toolInvocation = {
                args: value.args,
                toolCallId: value.toolCallId,
                toolName: value.toolName,
                state: "result",
                result: value.result as string,
              }
              Bus.publish(Message.Event.PartUpdated, {
                part: match,
                messageID: next.id,
                sessionID: next.metadata.sessionID,
              })
            }
            break
          }

          default:
            l.info("unhandled", {
              type: value.type,
            })
        }
        await updateMessage(next)
      },
      async onFinish(input) {
        // The total usage replaces the per-step running cost.
        const assistant = next.metadata!.assistant!
        const usage = getUsage(input.totalUsage, model.info)
        assistant.cost = usage.cost
        await updateMessage(next)
      },
      onError(err) {
        log.error("error", err)
        switch (true) {
          case LoadAPIKeyError.isInstance(err.error):
            next.metadata.error = new Provider.AuthError(
              {
                providerID: input.providerID,
                message: err.error.message,
              },
              { cause: err.error },
            ).toObject()
            break
          case err.error instanceof Error:
            next.metadata.error = new NamedError.Unknown(
              { message: err.error.toString() },
              { cause: err.error },
            ).toObject()
            break
          default:
            // Serialized like the branches above so the stored and published
            // error is a plain object rather than an Error instance.
            next.metadata.error = new NamedError.Unknown(
              { message: JSON.stringify(err.error) },
              { cause: err.error },
            ).toObject()
        }
        Bus.publish(Event.Error, {
          error: next.metadata.error,
        })
      },
      async prepareStep(step) {
        next.parts.push({
          type: "step-start",
        })
        await updateMessage(next)
        return step
      },
      toolCallStreaming: true,
      abortSignal: abort.signal,
      stopWhen: stepCountIs(1000),
      messages: convertToModelMessages([
        ...system.map(
          (x): UIMessage => ({
            id: Identifier.ascending("message"),
            role: "system",
            parts: [
              {
                type: "text",
                text: x,
              },
            ],
          }),
        ),
        ...msgs,
      ]),
      temperature: model.info.id === "codex-mini-latest" ? undefined : 0,
      tools: {
        // MCP tools without an `execute` handler only arrive through this
        // spread; the wrapped ones are overridden by the entries in `tools`.
        ...(await MCP.tools()),
        ...tools,
      },
      model: model.language,
    })
    await result.consumeStream({
      onError: (err) => {
        log.error("stream error", {
          err,
        })
      },
    })
    next.metadata!.time.completed = Date.now()

    // Any invocation still without a result at this point never finished.
    for (const part of next.parts) {
      if (
        part.type === "tool-invocation" &&
        part.toolInvocation.state !== "result"
      ) {
        part.toolInvocation = {
          ...part.toolInvocation,
          state: "result",
          result: "request was aborted",
        }
      }
    }
    await updateMessage(next)
    return next
  }
  /**
   * Asks the model for a summary of the conversation since the last summary
   * (or of the whole conversation when there is none) and stores it as an
   * assistant message flagged `summary: true`.
   *
   * The session is locked for the duration of the call, so `lock` throws
   * `BusyError` when the session already has a request in flight.
   *
   * @param input.sessionID Session to summarize.
   * @param input.providerID Provider used to resolve the model.
   * @param input.modelID Model to use.
   */
  export async function summarize(input: {
    sessionID: string
    providerID: string
    modelID: string
  }) {
    using abort = lock(input.sessionID)

    const history = await messages(input.sessionID)
    const cutoff = history.findLast(
      (msg) => msg.metadata.assistant?.summary === true,
    )?.id
    // With no earlier summary the whole history is sent.
    const relevant = cutoff
      ? history.filter((msg) => msg.id >= cutoff)
      : history

    const model = await Provider.getModel(input.providerID, input.modelID)
    const { cwd, root } = App.info().path
    const system = SystemPrompt.summarize(input.providerID)

    // The placeholder is persisted first and filled in once the model replies.
    const next: Message.Info = {
      id: Identifier.ascending("message"),
      role: "assistant",
      parts: [],
      metadata: {
        tool: {},
        sessionID: input.sessionID,
        assistant: {
          system,
          path: { cwd, root },
          summary: true,
          cost: 0,
          modelID: input.modelID,
          providerID: input.providerID,
          tokens: { input: 0, output: 0, reasoning: 0 },
        },
        time: { created: Date.now() },
      },
    }
    await updateMessage(next)

    const systemMessages = system.map(
      (line): UIMessage => ({
        id: Identifier.ascending("message"),
        role: "system",
        parts: [{ type: "text", text: line }],
      }),
    )
    const result = await generateText({
      abortSignal: abort.signal,
      model: model.language,
      messages: convertToModelMessages([
        ...systemMessages,
        ...relevant,
        {
          role: "user",
          parts: [
            {
              type: "text",
              text: "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next.",
            },
          ],
        },
      ]),
    })

    next.parts.push({ type: "text", text: result.text })

    const { cost, tokens } = getUsage(result.usage, model.info)
    const assistant = next.metadata!.assistant!
    assistant.cost = cost
    assistant.tokens = tokens
    await updateMessage(next)
  }
  /** Abort controllers of the requests currently in flight, keyed by session id. */
  const pending = new Map<string, AbortController>()

  /**
   * Marks a session as busy and returns a disposable handle for it.
   *
   * @param sessionID Session to lock.
   * @returns An object exposing the abort `signal` of the new request and a
   *   `Symbol.dispose` method that removes the session from `pending`, so it
   *   can be bound with a `using` declaration.
   * @throws BusyError when the session already has an entry in `pending`.
   */
  function lock(sessionID: string) {
    log.info("locking", { sessionID })
    if (pending.has(sessionID)) throw new BusyError(sessionID)

    const controller = new AbortController()
    pending.set(sessionID, controller)

    const release = () => {
      log.info("unlocking", { sessionID })
      pending.delete(sessionID)
    }
    const { signal } = controller
    return { signal, [Symbol.dispose]: release }
  }
  /**
   * Converts the usage reported by the model into token counts and a cost.
   *
   * Missing counts default to 0. The cost is input tokens times
   * `model.cost.input` plus output tokens times `model.cost.output`, each
   * divided by 1,000,000. Reasoning tokens are reported but do not contribute
   * to the cost.
   *
   * @param usage Usage reported for a step or a whole request.
   * @param model Model info supplying the `cost.input` / `cost.output` rates.
   * @returns `{ cost, tokens }` where `tokens` has `input`, `output` and `reasoning`.
   */
  function getUsage(usage: LanguageModelUsage, model: Provider.Model) {
    const input = usage.inputTokens ?? 0
    const output = usage.outputTokens ?? 0
    const reasoning = usage.reasoningTokens ?? 0

    // Decimal arithmetic avoids binary floating point error in the per-token math.
    const inputCost = new Decimal(input).mul(model.cost.input).div(1_000_000)
    const outputCost = new Decimal(output).mul(model.cost.output).div(1_000_000)
    const cost = new Decimal(0).add(inputCost).add(outputCost).toNumber()

    return {
      cost,
      tokens: { input, output, reasoning },
    }
  }
  /**
   * Thrown by `lock` when a session already has a request in flight.
   * `sessionID` identifies the busy session.
   */
  export class BusyError extends Error {
    public readonly sessionID: string

    constructor(sessionID: string) {
      super(`Session ${sessionID} is busy`)
      this.sessionID = sessionID
    }
  }
  /**
   * Runs the initialization prompt in the given session and then marks the
   * app as initialized via `App.initialize()`.
   *
   * @param input.sessionID Session in which the initialization chat runs.
   * @param input.modelID Model to use.
   * @param input.providerID Provider of the model.
   */
  export async function initialize(input: {
    sessionID: string
    modelID: string
    providerID: string
  }) {
    const app = App.info()
    await Session.chat({
      sessionID: input.sessionID,
      providerID: input.providerID,
      modelID: input.modelID,
      parts: [
        {
          type: "text",
          // A replacer function is used so the path is inserted verbatim. With
          // a string replacement, sequences such as "$&" or "$'" inside the
          // path would be expanded as replacement patterns.
          text: PROMPT_INITIALIZE.replace("${path}", () => app.path.root),
        },
      ],
    })
    await App.initialize()
  }
  675. }