index.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. import { BusEvent } from "@/bus/bus-event"
  2. import { Bus } from "@/bus"
  3. import { Decimal } from "decimal.js"
  4. import z from "zod"
  5. import { type LanguageModelUsage, type ProviderMetadata } from "ai"
  6. import { Config } from "../config/config"
  7. import { Flag } from "../flag/flag"
  8. import { Identifier } from "../id/id"
  9. import { Installation } from "../installation"
  10. import { Storage } from "../storage/storage"
  11. import { Log } from "../util/log"
  12. import { MessageV2 } from "./message-v2"
  13. import { Instance } from "../project/instance"
  14. import { SessionPrompt } from "./prompt"
  15. import { fn } from "@/util/fn"
  16. import { Command } from "../command"
  17. import { Snapshot } from "@/snapshot"
  18. import type { Provider } from "@/provider/provider"
  19. import { PermissionNext } from "@/permission/next"
  20. export namespace Session {
  // Logger for this namespace, tagged with the "session" service name.
  const log = Log.create({ service: "session" })
  // Prefixes used for auto-generated titles of top-level and child sessions.
  const parentTitlePrefix = "New session - "
  const childTitlePrefix = "Child session - "

  // Matches exactly what createDefaultTitle() produces: one of the two prefixes
  // followed by an ISO-8601 UTC timestamp with millisecond precision.
  // Compiled once here instead of on every isDefaultTitle() call; the pattern
  // has no `g`/`y` flag, so reusing it across calls carries no state.
  const defaultTitlePattern = new RegExp(
    `^(${parentTitlePrefix}|${childTitlePrefix})\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}\\.\\d{3}Z$`,
  )

  /**
   * Builds the placeholder title for a new session.
   *
   * @param isChild - when true the child-session prefix is used, otherwise the parent prefix.
   * @returns the prefix followed by the current time as an ISO string.
   */
  function createDefaultTitle(isChild = false) {
    return (isChild ? childTitlePrefix : parentTitlePrefix) + new Date().toISOString()
  }

  /**
   * Tells whether a title has the shape of an auto-generated default title.
   *
   * @param title - the session title to inspect.
   * @returns true when the whole string is a default prefix plus an ISO timestamp.
   */
  export function isDefaultTitle(title: string) {
    return defaultTitlePattern.test(title)
  }
  // Aggregate change counts for a session, with optional per-file diffs.
  const SummaryShape = z.object({
    additions: z.number(),
    deletions: z.number(),
    files: z.number(),
    diffs: Snapshot.FileDiff.array().optional(),
  })

  // Share link information kept on the session itself (URL only).
  const ShareShape = z.object({
    url: z.string(),
  })

  // Numeric timestamps attached to a session; `compacting` and `archived` are optional.
  const TimeShape = z.object({
    created: z.number(),
    updated: z.number(),
    compacting: z.number().optional(),
    archived: z.number().optional(),
  })

  // Revert marker: the message it refers to plus optional part, snapshot and diff.
  const RevertShape = z.object({
    messageID: z.string(),
    partID: z.string().optional(),
    snapshot: z.string().optional(),
    diff: z.string().optional(),
  })

  /**
   * Zod schema for a session record, registered under the ref name "Session".
   * The field order is kept stable on purpose.
   */
  export const Info = z
    .object({
      id: Identifier.schema("session"),
      projectID: z.string(),
      directory: z.string(),
      parentID: Identifier.schema("session").optional(),
      summary: SummaryShape.optional(),
      share: ShareShape.optional(),
      title: z.string(),
      version: z.string(),
      time: TimeShape,
      permission: PermissionNext.Ruleset.optional(),
      revert: RevertShape.optional(),
    })
    .meta({
      ref: "Session",
    })
  /** Parsed (output) type of the {@link Info} schema. */
  export type Info = z.output<typeof Info>
  /**
   * Zod schema for a share record: a secret plus the share URL.
   * Registered under the ref name "SessionShare".
   */
  export const ShareInfo = z.object({ secret: z.string(), url: z.string() }).meta({ ref: "SessionShare" })
  /** Parsed (output) type of the {@link ShareInfo} schema. */
  export type ShareInfo = z.output<typeof ShareInfo>
  // Payload carried by the created / updated / deleted events: the session record.
  // A fresh schema object is built per event, mirroring three separate literals.
  const sessionInfoPayload = () => z.object({ info: Info })

  /** Bus event definitions published by this namespace, keyed by short name. */
  export const Event = {
    Created: BusEvent.define("session.created", sessionInfoPayload()),
    Updated: BusEvent.define("session.updated", sessionInfoPayload()),
    Deleted: BusEvent.define("session.deleted", sessionInfoPayload()),
    // File diffs associated with a session.
    Diff: BusEvent.define(
      "session.diff",
      z.object({
        sessionID: z.string(),
        diff: Snapshot.FileDiff.array(),
      }),
    ),
    // An error, optionally tied to a session; reuses the assistant message error schema.
    Error: BusEvent.define(
      "session.error",
      z.object({
        sessionID: z.string().optional(),
        error: MessageV2.Assistant.shape.error,
      }),
    ),
  }
  /**
   * Creates a session in the current instance directory.
   *
   * The whole input object is optional; when given, `parentID`, `title` and
   * `permission` are forwarded to {@link createNext}.
   *
   * @returns the session record produced by `createNext`.
   */
  export const create = fn(
    z
      .object({
        parentID: Identifier.schema("session").optional(),
        title: z.string().optional(),
        permission: Info.shape.permission,
      })
      .optional(),
    async (input) =>
      createNext({
        directory: Instance.directory,
        parentID: input?.parentID,
        title: input?.title,
        permission: input?.permission,
      }),
  )
  /**
   * Copies the messages of an existing session into a newly created session.
   *
   * Messages are walked in the order returned by `messages()`. When `messageID`
   * is given, copying stops at the first message whose id is greater than or
   * equal to it. Every copied message and part receives a new ascending id;
   * assistant `parentID` references are rewritten to the ids of the copies.
   *
   * @returns the newly created session record.
   */
  export const fork = fn(
    z.object({
      sessionID: Identifier.schema("session"),
      messageID: Identifier.schema("message").optional(),
    }),
    async (input) => {
      const forked = await createNext({ directory: Instance.directory })
      const history = await messages({ sessionID: input.sessionID })
      // original message id -> id of its copy in the forked session
      const remappedIDs = new Map<string, string>()

      for (const source of history) {
        const info = source.info
        if (input.messageID && info.id >= input.messageID) break

        const freshID = Identifier.ascending("message")
        remappedIDs.set(info.id, freshID)

        // Only assistant messages carry a parentID that needs translating.
        const parentID = info.role === "assistant" && info.parentID ? remappedIDs.get(info.parentID) : undefined

        const copy = await updateMessage({
          ...info,
          sessionID: forked.id,
          id: freshID,
          ...(parentID && { parentID }),
        })

        for (const sourcePart of source.parts) {
          await updatePart({
            ...sourcePart,
            id: Identifier.ascending("part"),
            messageID: copy.id,
            sessionID: forked.id,
          })
        }
      }

      return forked
    },
  )
  /** Sets the session's `time.updated` to the current time via {@link update}. */
  export const touch = fn(Identifier.schema("session"), async (sessionID) => {
    const bumpUpdated = (draft: Info) => {
      draft.time.updated = Date.now()
    }
    await update(sessionID, bumpUpdated)
  })
  /**
   * Builds a new session record, writes it to storage and announces it on the bus.
   *
   * Steps, in order:
   * 1. assemble the record (descending session id, installation version, current
   *    project id, default title when none is supplied);
   * 2. write it under `["session", <project id>, <session id>]`;
   * 3. publish `Event.Created`;
   * 4. for sessions without a parent, start sharing in the background when the
   *    auto-share flag is set or the config's `share` is `"auto"`;
   * 5. publish `Event.Updated` with the same record.
   *
   * @param input.id - optional value handed to `Identifier.descending`.
   * @param input.title - explicit title; a timestamped default is generated otherwise.
   * @param input.parentID - parent session id; its presence selects the child title prefix.
   * @param input.directory - directory stored on the session.
   * @param input.permission - optional permission ruleset stored on the session.
   * @returns the record as it was written (background sharing is not awaited).
   */
  export async function createNext(input: {
    id?: string
    title?: string
    parentID?: string
    directory: string
    permission?: PermissionNext.Ruleset
  }) {
    // One clock reading so `created` and `updated` are guaranteed to agree.
    const now = Date.now()
    const result: Info = {
      id: Identifier.descending("session", input.id),
      version: Installation.VERSION,
      projectID: Instance.project.id,
      directory: input.directory,
      parentID: input.parentID,
      title: input.title ?? createDefaultTitle(!!input.parentID),
      permission: input.permission,
      time: {
        created: now,
        updated: now,
      },
    }
    log.info("created", result)
    await Storage.write(["session", Instance.project.id, result.id], result)
    Bus.publish(Event.Created, {
      info: result,
    })
    const cfg = await Config.get()
    if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
      // Fire-and-forget: session creation must not wait for, or fail on, sharing.
      void share(result.id)
        .then((shared) =>
          // Returning the promise keeps a failing update() inside this chain, so the
          // catch below handles it instead of it becoming an unhandled rejection.
          update(result.id, (draft) => {
            // Store only the URL, matching Info.share and what share() itself writes.
            draft.share = { url: shared.url }
          }),
        )
        .catch(() => {
          // Silently ignore sharing errors during session creation
        })
    }
    Bus.publish(Event.Updated, {
      info: result,
    })
    return result
  }
  /** Reads the session stored under `["session", <current project id>, id]`. */
  export const get = fn(Identifier.schema("session"), async (id) => {
    const key = ["session", Instance.project.id, id]
    return (await Storage.read<Info>(key)) as Info
  })
  /** Reads the share record stored under `["share", id]`. */
  export const getShare = fn(Identifier.schema("session"), async (id) => Storage.read<ShareInfo>(["share", id]))
  /**
   * Creates a share for the session and records its URL on the session.
   *
   * @throws Error when the configuration's `share` value is `"disabled"`.
   * @returns the value produced by `ShareNext.create`.
   */
  export const share = fn(Identifier.schema("session"), async (id) => {
    const { share: shareMode } = await Config.get()
    if (shareMode === "disabled") throw new Error("Sharing is disabled in configuration")

    // Loaded on demand rather than through a top-level import.
    const { ShareNext } = await import("@/share/share-next")
    const created = await ShareNext.create(id)

    await update(id, (draft) => {
      draft.share = { url: created.url }
    })
    return created
  })
  /** Removes the session's share through ShareNext, then clears `share` on the session. */
  export const unshare = fn(Identifier.schema("session"), async (id) => {
    // ShareNext is the counterpart of what share() uses to create the share.
    const shareModule = await import("@/share/share-next")
    await shareModule.ShareNext.remove(id)

    const clearShare = (draft: Info) => {
      draft.share = undefined
    }
    await update(id, clearShare)
  })
  /**
   * Applies `editor` to the stored session, refreshes `time.updated`, and
   * publishes `Event.Updated` with the result.
   *
   * @param id - session id within the current project.
   * @param editor - callback that mutates the draft in place; it runs before the
   *   timestamp is set, so any `time.updated` it writes is overwritten.
   * @returns the value returned by `Storage.update`.
   */
  export async function update(id: string, editor: (session: Info) => void) {
    const key = ["session", Instance.project.id, id]
    const updated = await Storage.update<Info>(key, (draft) => {
      editor(draft)
      draft.time.updated = Date.now()
    })
    Bus.publish(Event.Updated, { info: updated })
    return updated
  }
  /** Reads the diffs stored under `["session_diff", sessionID]`, or an empty array when the read is nullish. */
  export const diff = fn(Identifier.schema("session"), async (sessionID) => {
    const stored = await Storage.read<Snapshot.FileDiff[]>(["session_diff", sessionID])
    return stored ?? []
  })
  /**
   * Collects messages from `MessageV2.stream` for a session.
   *
   * Items are gathered in stream order until `limit` items have been collected
   * (a missing or zero limit means no cap), and the collected list is then
   * reversed before being returned.
   */
  export const messages = fn(
    z.object({
      sessionID: Identifier.schema("session"),
      limit: z.number().optional(),
    }),
    async (input) => {
      const collected: MessageV2.WithParts[] = []
      for await (const entry of MessageV2.stream(input.sessionID)) {
        const capReached = input.limit && collected.length >= input.limit
        if (capReached) break
        collected.push(entry)
      }
      return collected.reverse()
    },
  )
  /** Async generator yielding each session stored under the current project, one read at a time. */
  export async function* list() {
    const keys = await Storage.list(["session", Instance.project.id])
    for (const key of keys) {
      yield Storage.read<Info>(key)
    }
  }
  /**
   * Finds the direct children of a session.
   *
   * Reads every session of the current project sequentially and keeps those
   * whose `parentID` equals the given id.
   */
  export const children = fn(Identifier.schema("session"), async (parentID) => {
    const keys = await Storage.list(["session", Instance.project.id])
    const matches: Session.Info[] = []
    for (const key of keys) {
      const candidate = await Storage.read<Info>(key)
      if (candidate.parentID === parentID) matches.push(candidate)
    }
    return matches
  })
  /**
   * Deletes a session together with its child sessions, messages and parts.
   *
   * Order: load the session, remove each child recursively, attempt to unshare
   * (failures ignored), delete every part and message, delete the session
   * record, then publish `Event.Deleted`. Any error raised along the way is
   * logged and not rethrown.
   */
  export const remove = fn(Identifier.schema("session"), async (sessionID) => {
    const project = Instance.project
    try {
      const session = await get(sessionID)

      const descendants = await children(sessionID)
      for (const child of descendants) await remove(child.id)

      // Best effort: a failure to unshare does not stop the deletion.
      await unshare(sessionID).catch(() => {})

      const messageKeys = await Storage.list(["message", sessionID])
      for (const messageKey of messageKeys) {
        // The last segment of the message key is used to look up its parts.
        const partKeys = await Storage.list(["part", messageKey.at(-1)!])
        for (const partKey of partKeys) await Storage.remove(partKey)
        await Storage.remove(messageKey)
      }

      await Storage.remove(["session", project.id, sessionID])
      Bus.publish(Event.Deleted, { info: session })
    } catch (e) {
      log.error(e)
    }
  })
  /**
   * Writes a message under `["message", sessionID, id]` and publishes
   * `MessageV2.Event.Updated`.
   *
   * @returns the message that was passed in.
   */
  export const updateMessage = fn(MessageV2.Info, async (msg) => {
    const key = ["message", msg.sessionID, msg.id]
    await Storage.write(key, msg)
    Bus.publish(MessageV2.Event.Updated, { info: msg })
    return msg
  })
  /**
   * Deletes the message stored under `["message", sessionID, messageID]` and
   * publishes `MessageV2.Event.Removed`.
   *
   * @returns the id of the removed message.
   */
  export const removeMessage = fn(
    z.object({
      sessionID: Identifier.schema("session"),
      messageID: Identifier.schema("message"),
    }),
    async (input) => {
      const { sessionID, messageID } = input
      await Storage.remove(["message", sessionID, messageID])
      Bus.publish(MessageV2.Event.Removed, { sessionID, messageID })
      return messageID
    },
  )
  /**
   * Deletes the part stored under `["part", messageID, partID]` and publishes
   * `MessageV2.Event.PartRemoved`.
   *
   * @returns the id of the removed part.
   */
  export const removePart = fn(
    z.object({
      sessionID: Identifier.schema("session"),
      messageID: Identifier.schema("message"),
      partID: Identifier.schema("part"),
    }),
    async (input) => {
      const { sessionID, messageID, partID } = input
      await Storage.remove(["part", messageID, partID])
      Bus.publish(MessageV2.Event.PartRemoved, { sessionID, messageID, partID })
      return partID
    },
  )
  // Accepted inputs for updatePart: a bare part, or a text / reasoning part
  // wrapped together with a string `delta`.
  const UpdatePartInput = z.union([
    MessageV2.Part,
    z.object({ part: MessageV2.TextPart, delta: z.string() }),
    z.object({ part: MessageV2.ReasoningPart, delta: z.string() }),
  ])

  /**
   * Writes a part under `["part", messageID, id]` and publishes
   * `MessageV2.Event.PartUpdated`, forwarding `delta` when the input carried one.
   *
   * @returns the part that was written.
   */
  export const updatePart = fn(UpdatePartInput, async (input) => {
    // Normalize both input shapes to { part, delta }.
    const { part, delta } = "delta" in input ? input : { part: input, delta: undefined }
    await Storage.write(["part", part.messageID, part.id], part)
    Bus.publish(MessageV2.Event.PartUpdated, { part, delta })
    return part
  })
  /**
   * Derives token counts and a cost figure from a model's reported usage.
   *
   * - Cached input tokens are subtracted from the input count, unless the
   *   metadata has an "anthropic" or "bedrock" entry, in which case the reported
   *   input count is used as is.
   * - Cache-write tokens come from the anthropic or bedrock metadata, else 0.
   * - The over-200K price table is chosen when it exists and input plus
   *   cache-read tokens exceed 200,000.
   * - Prices are applied per one million tokens; non-finite numbers become 0.
   *
   * @returns `{ cost, tokens }` where `tokens` has `input`, `output`,
   *   `reasoning` and `cache: { write, read }`.
   */
  export const getUsage = fn(
    z.object({
      model: z.custom<Provider.Model>(),
      usage: z.custom<LanguageModelUsage>(),
      metadata: z.custom<ProviderMetadata>().optional(),
    }),
    (input) => {
      // Replace NaN / ±Infinity with 0 so they never reach the totals.
      const finiteOrZero = (value: number) => (Number.isFinite(value) ? value : 0)

      const cacheRead = input.usage.cachedInputTokens ?? 0
      const reportedInput = input.usage.inputTokens ?? 0
      const cacheAlreadyExcluded = !!(input.metadata?.["anthropic"] || input.metadata?.["bedrock"])
      const billableInput = cacheAlreadyExcluded ? reportedInput : reportedInput - cacheRead

      const cacheWrite = (input.metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
        // @ts-expect-error
        input.metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
        0) as number

      const tokens = {
        input: finiteOrZero(billableInput),
        output: finiteOrZero(input.usage.outputTokens ?? 0),
        reasoning: finiteOrZero(input.usage?.reasoningTokens ?? 0),
        cache: {
          write: finiteOrZero(cacheWrite),
          read: finiteOrZero(cacheRead),
        },
      }

      const overThreshold = tokens.input + tokens.cache.read > 200_000
      const pricing =
        input.model.cost?.experimentalOver200K && overThreshold
          ? input.model.cost.experimentalOver200K
          : input.model.cost

      // Price of `count` tokens at a per-million-token `rate` (missing rate = 0).
      const charge = (count: number, rate: Decimal.Value | null | undefined) =>
        new Decimal(count).mul(rate ?? 0).div(1_000_000)

      const total = [
        charge(tokens.input, pricing?.input),
        charge(tokens.output, pricing?.output),
        charge(tokens.cache.read, pricing?.cache?.read),
        charge(tokens.cache.write, pricing?.cache?.write),
        // TODO: update models.dev to have better pricing model, for now:
        // charge reasoning tokens at the same rate as output tokens
        charge(tokens.reasoning, pricing?.output),
      ].reduce((sum, part) => sum.add(part), new Decimal(0))

      return {
        cost: finiteOrZero(total.toNumber()),
        tokens,
      }
    },
  )
  /** Error whose message reports that the given session is busy; the id is kept on `sessionID`. */
  export class BusyError extends Error {
    /** Id of the session the error refers to. */
    public readonly sessionID: string

    constructor(sessionID: string) {
      super(`Session ${sessionID} is busy`)
      this.sessionID = sessionID
    }
  }
  /**
   * Runs the default INIT command for a session through `SessionPrompt.command`,
   * with the model given as "<providerID>/<modelID>" and empty arguments.
   */
  export const initialize = fn(
    z.object({
      sessionID: Identifier.schema("session"),
      modelID: z.string(),
      providerID: z.string(),
      messageID: Identifier.schema("message"),
    }),
    async (input) => {
      const { sessionID, messageID, providerID, modelID } = input
      await SessionPrompt.command({
        sessionID,
        messageID,
        model: `${providerID}/${modelID}`,
        command: Command.Default.INIT,
        arguments: "",
      })
    },
  )
  440. }