index.ts 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514
  1. import { Slug } from "@opencode-ai/util/slug"
  2. import path from "path"
  3. import { BusEvent } from "@/bus/bus-event"
  4. import { Bus } from "@/bus"
  5. import { Decimal } from "decimal.js"
  6. import z from "zod"
  7. import { type LanguageModelUsage, type ProviderMetadata } from "ai"
  8. import { Config } from "../config/config"
  9. import { Flag } from "../flag/flag"
  10. import { Identifier } from "../id/id"
  11. import { Installation } from "../installation"
  12. import { Storage } from "../storage/storage"
  13. import { Log } from "../util/log"
  14. import { MessageV2 } from "./message-v2"
  15. import { Instance } from "../project/instance"
  16. import { SessionPrompt } from "./prompt"
  17. import { fn } from "@/util/fn"
  18. import { Command } from "../command"
  19. import { Snapshot } from "@/snapshot"
  20. import type { Provider } from "@/provider/provider"
  21. import { PermissionNext } from "@/permission/next"
  22. import { Global } from "@/global"
  export namespace Session {
    const log = Log.create({ service: "session" })

    const parentTitlePrefix = "New session - "
    const childTitlePrefix = "Child session - "

    // Compiled once instead of on every isDefaultTitle() call. The pattern has no `g`/`y`
    // flag, so RegExp#test keeps no lastIndex state and is safe to reuse.
    const defaultTitlePattern = new RegExp(
      `^(${parentTitlePrefix}|${childTitlePrefix})\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}\\.\\d{3}Z$`,
    )

    /** Builds the placeholder title: the parent/child prefix followed by the current ISO-8601 timestamp. */
    function createDefaultTitle(isChild = false) {
      return (isChild ? childTitlePrefix : parentTitlePrefix) + new Date().toISOString()
    }

    /**
     * Reports whether `title` is exactly a placeholder produced by `createDefaultTitle`
     * (either prefix followed by a millisecond-precision UTC ISO timestamp).
     */
    export function isDefaultTitle(title: string) {
      return defaultTitlePattern.test(title)
    }

    /**
     * Derives the title of a fork: `"X"` becomes `"X (fork #1)"`, and `"X (fork #N)"`
     * becomes `"X (fork #N+1)"`.
     */
    function getForkedTitle(title: string): string {
      const match = title.match(/^(.+) \(fork #(\d+)\)$/)
      if (match) {
        const base = match[1]
        const num = parseInt(match[2], 10)
        return `${base} (fork #${num + 1})`
      }
      return `${title} (fork #1)`
    }

    /**
     * Schema of a persisted session record, stored under `["session", projectID, id]`.
     */
    export const Info = z
      .object({
        id: Identifier.schema("session"),
        // Combined with time.created to name the plan file (see `plan`).
        slug: z.string(),
        projectID: z.string(),
        directory: z.string(),
        // Set for child sessions; `children` and `remove` use it to walk the hierarchy.
        parentID: Identifier.schema("session").optional(),
        summary: z
          .object({
            additions: z.number(),
            deletions: z.number(),
            files: z.number(),
            diffs: Snapshot.FileDiff.array().optional(),
          })
          .optional(),
        // Written by `share` / cleared by `unshare`.
        share: z
          .object({
            url: z.string(),
          })
          .optional(),
        title: z.string(),
        version: z.string(),
        time: z.object({
          created: z.number(),
          // Bumped by `update` (unless touch: false) and by `touch`.
          updated: z.number(),
          compacting: z.number().optional(),
          archived: z.number().optional(),
        }),
        permission: PermissionNext.Ruleset.optional(),
        revert: z
          .object({
            messageID: z.string(),
            partID: z.string().optional(),
            snapshot: z.string().optional(),
            diff: z.string().optional(),
          })
          .optional(),
      })
      .meta({
        ref: "Session",
      })
    export type Info = z.output<typeof Info>

    /** Schema of a share record, read from storage under `["share", sessionID]` by `getShare`. */
    export const ShareInfo = z
      .object({
        secret: z.string(),
        url: z.string(),
      })
      .meta({
        ref: "SessionShare",
      })
    export type ShareInfo = z.output<typeof ShareInfo>

    /** Bus events published by this namespace (and by other session code). */
    export const Event = {
      Created: BusEvent.define(
        "session.created",
        z.object({
          info: Info,
        }),
      ),
      Updated: BusEvent.define(
        "session.updated",
        z.object({
          info: Info,
        }),
      ),
      Deleted: BusEvent.define(
        "session.deleted",
        z.object({
          info: Info,
        }),
      ),
      Diff: BusEvent.define(
        "session.diff",
        z.object({
          sessionID: z.string(),
          diff: Snapshot.FileDiff.array(),
        }),
      ),
      // The error payload reuses the assistant-message error schema.
      Error: BusEvent.define(
        "session.error",
        z.object({
          sessionID: z.string().optional(),
          error: MessageV2.Assistant.shape.error,
        }),
      ),
    }

    /**
     * Creates a new session rooted at the current instance directory.
     * All input fields are optional; see `createNext` for defaults.
     */
    export const create = fn(
      z
        .object({
          parentID: Identifier.schema("session").optional(),
          title: z.string().optional(),
          permission: Info.shape.permission,
        })
        .optional(),
      async (input) => {
        return createNext({
          parentID: input?.parentID,
          directory: Instance.directory,
          title: input?.title,
          permission: input?.permission,
        })
      },
    )

    /**
     * Creates a new session that copies the history of `sessionID`.
     *
     * When `messageID` is given, only messages whose ID sorts strictly before it are copied;
     * otherwise every message is copied. Each copied message and part receives a fresh ID, and
     * assistant `parentID`s are remapped to the new message IDs.
     *
     * NOTE(review): the early `break` relies on `messages()` yielding messages in ascending ID
     * order (it reverses `MessageV2.stream` output) — confirm against MessageV2.stream.
     *
     * @throws Error("session not found") when the source session cannot be read.
     */
    export const fork = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        messageID: Identifier.schema("message").optional(),
      }),
      async (input) => {
        const original = await get(input.sessionID)
        if (!original) throw new Error("session not found")
        const session = await createNext({
          directory: Instance.directory,
          title: getForkedTitle(original.title),
        })
        const history = await messages({ sessionID: input.sessionID })
        // Old message ID -> new message ID, used to rewrite assistant parentIDs.
        const idMap = new Map<string, string>()
        for (const msg of history) {
          if (input.messageID && msg.info.id >= input.messageID) break
          const newID = Identifier.ascending("message")
          idMap.set(msg.info.id, newID)
          const parentID = msg.info.role === "assistant" && msg.info.parentID ? idMap.get(msg.info.parentID) : undefined
          const cloned = await updateMessage({
            ...msg.info,
            sessionID: session.id,
            id: newID,
            // Only override when a mapping exists; otherwise the spread keeps the original value.
            ...(parentID && { parentID }),
          })
          for (const part of msg.parts) {
            await updatePart({
              ...part,
              id: Identifier.ascending("part"),
              messageID: cloned.id,
              sessionID: session.id,
            })
          }
        }
        return session
      },
    )

    /** Bumps `time.updated` of the session to now (and publishes `Event.Updated` via `update`). */
    export const touch = fn(Identifier.schema("session"), async (sessionID) => {
      await update(sessionID, (draft) => {
        draft.time.updated = Date.now()
      })
    })

    /**
     * Persists a brand-new session, publishes `Event.Created` then `Event.Updated`, and
     * returns it.
     *
     * Defaults:
     * - `title` falls back to a timestamped placeholder (child prefix when `parentID` is set).
     * - `id` is passed to `Identifier.descending`.
     *
     * Top-level sessions are auto-shared in the background when `Flag.OPENCODE_AUTO_SHARE` is
     * set or the config has `share: "auto"`. Sharing failures are ignored.
     */
    export async function createNext(input: {
      id?: string
      title?: string
      parentID?: string
      directory: string
      permission?: PermissionNext.Ruleset
    }) {
      // One timestamp so a new session always has created === updated.
      const now = Date.now()
      const result: Info = {
        id: Identifier.descending("session", input.id),
        slug: Slug.create(),
        version: Installation.VERSION,
        projectID: Instance.project.id,
        directory: input.directory,
        parentID: input.parentID,
        title: input.title ?? createDefaultTitle(!!input.parentID),
        permission: input.permission,
        time: {
          created: now,
          updated: now,
        },
      }
      log.info("created", result)
      await Storage.write(["session", Instance.project.id, result.id], result)
      Bus.publish(Event.Created, {
        info: result,
      })
      const cfg = await Config.get()
      if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
        // Fire-and-forget. The update promise is returned from `.then` so that a failure there
        // is also caught below instead of surfacing as an unhandled rejection. Only `{ url }` is
        // stored, matching Info.share and what `share()` itself writes.
        share(result.id)
          .then((shared) =>
            update(result.id, (draft) => {
              draft.share = { url: shared.url }
            }),
          )
          .catch(() => {
            // Silently ignore sharing errors during session creation
          })
      }
      Bus.publish(Event.Updated, {
        info: result,
      })
      return result
    }

    /**
     * Path of the markdown plan file for a session: `<time.created>-<slug>.md`.
     * The file lives under `<worktree>/.opencode/plans` when the project has a VCS, otherwise
     * under `<global data dir>/plans`.
     */
    export function plan(input: { slug: string; time: { created: number } }) {
      const base = Instance.project.vcs
        ? path.join(Instance.worktree, ".opencode", "plans")
        : path.join(Global.Path.data, "plans")
      return path.join(base, [input.time.created, input.slug].join("-") + ".md")
    }

    /** Reads a session record of the current project from storage. */
    export const get = fn(Identifier.schema("session"), async (id) => {
      const read = await Storage.read<Info>(["session", Instance.project.id, id])
      return read as Info
    })

    /** Reads the stored share record (`["share", id]`) of a session. */
    export const getShare = fn(Identifier.schema("session"), async (id) => {
      return Storage.read<ShareInfo>(["share", id])
    })

    /**
     * Shares a session via ShareNext and records its URL on the session without bumping
     * `time.updated`. Returns the value produced by `ShareNext.create`.
     *
     * @throws Error when the configuration has `share: "disabled"`.
     */
    export const share = fn(Identifier.schema("session"), async (id) => {
      const cfg = await Config.get()
      if (cfg.share === "disabled") {
        throw new Error("Sharing is disabled in configuration")
      }
      const { ShareNext } = await import("@/share/share-next")
      const created = await ShareNext.create(id)
      await update(
        id,
        (draft) => {
          draft.share = {
            url: created.url,
          }
        },
        { touch: false },
      )
      return created
    })

    /** Removes the ShareNext share of a session and clears `share` without bumping `time.updated`. */
    export const unshare = fn(Identifier.schema("session"), async (id) => {
      // Use ShareNext to remove the share (same as share function uses ShareNext to create)
      const { ShareNext } = await import("@/share/share-next")
      await ShareNext.remove(id)
      await update(
        id,
        (draft) => {
          draft.share = undefined
        },
        { touch: false },
      )
    })

    /**
     * Applies `editor` to the stored session through `Storage.update`, publishes
     * `Event.Updated` with the result, and returns it.
     *
     * @param options.touch pass `false` to leave `time.updated` unchanged (default: bump to now).
     */
    export async function update(id: string, editor: (session: Info) => void, options?: { touch?: boolean }) {
      const project = Instance.project
      const result = await Storage.update<Info>(["session", project.id, id], (draft) => {
        editor(draft)
        if (options?.touch !== false) {
          draft.time.updated = Date.now()
        }
      })
      Bus.publish(Event.Updated, {
        info: result,
      })
      return result
    }

    /** Returns the stored file diffs of a session, or `[]` when none are stored. */
    export const diff = fn(Identifier.schema("session"), async (sessionID) => {
      const diffs = await Storage.read<Snapshot.FileDiff[]>(["session_diff", sessionID])
      return diffs ?? []
    })

    /**
     * Collects messages (with parts) of a session.
     *
     * Takes at most `limit` items from `MessageV2.stream` (a missing or 0 limit means no
     * limit), then reverses the collected array before returning it.
     */
    export const messages = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        limit: z.number().optional(),
      }),
      async (input) => {
        const result = [] as MessageV2.WithParts[]
        for await (const msg of MessageV2.stream(input.sessionID)) {
          if (input.limit && result.length >= input.limit) break
          result.push(msg)
        }
        result.reverse()
        return result
      },
    )

    /**
     * Async generator over every stored session of the current project.
     * `yield` inside an async generator awaits the read promise, so consumers receive `Info` values.
     */
    export async function* list() {
      const project = Instance.project
      for (const item of await Storage.list(["session", project.id])) {
        yield Storage.read<Info>(item)
      }
    }

    /** Returns the sessions of the current project whose `parentID` equals `parentID`. */
    export const children = fn(Identifier.schema("session"), async (parentID) => {
      const project = Instance.project
      const result = [] as Info[]
      for (const item of await Storage.list(["session", project.id])) {
        const session = await Storage.read<Info>(item)
        if (session.parentID !== parentID) continue
        result.push(session)
      }
      return result
    })

    /**
     * Deletes a session and everything under it, then publishes `Event.Deleted`.
     *
     * Order of work:
     * 1. Recursively removes child sessions.
     * 2. Attempts to unshare, ignoring errors.
     * 3. Deletes every message's parts, then the message itself.
     * 4. Deletes the session record.
     *
     * Any error is logged and swallowed rather than rethrown.
     */
    export const remove = fn(Identifier.schema("session"), async (sessionID) => {
      const project = Instance.project
      try {
        const session = await get(sessionID)
        for (const child of await children(sessionID)) {
          await remove(child.id)
        }
        await unshare(sessionID).catch(() => {})
        for (const msg of await Storage.list(["message", sessionID])) {
          // The last key segment of a message entry is its message ID, which keys its parts.
          for (const part of await Storage.list(["part", msg.at(-1)!])) {
            await Storage.remove(part)
          }
          await Storage.remove(msg)
        }
        await Storage.remove(["session", project.id, sessionID])
        Bus.publish(Event.Deleted, {
          info: session,
        })
      } catch (e) {
        log.error(e)
      }
    })

    /** Writes a message record and publishes `MessageV2.Event.Updated`; returns the message. */
    export const updateMessage = fn(MessageV2.Info, async (msg) => {
      await Storage.write(["message", msg.sessionID, msg.id], msg)
      Bus.publish(MessageV2.Event.Updated, {
        info: msg,
      })
      return msg
    })

    /**
     * Deletes a message record and publishes `MessageV2.Event.Removed`; returns the message ID.
     * Only the message record is removed here; entries under `["part", messageID]` are untouched.
     */
    export const removeMessage = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        messageID: Identifier.schema("message"),
      }),
      async (input) => {
        await Storage.remove(["message", input.sessionID, input.messageID])
        Bus.publish(MessageV2.Event.Removed, {
          sessionID: input.sessionID,
          messageID: input.messageID,
        })
        return input.messageID
      },
    )

    /** Deletes a single part record and publishes `MessageV2.Event.PartRemoved`; returns the part ID. */
    export const removePart = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        messageID: Identifier.schema("message"),
        partID: Identifier.schema("part"),
      }),
      async (input) => {
        await Storage.remove(["part", input.messageID, input.partID])
        Bus.publish(MessageV2.Event.PartRemoved, {
          sessionID: input.sessionID,
          messageID: input.messageID,
          partID: input.partID,
        })
        return input.partID
      },
    )

    // Either a bare part, or a text/reasoning part paired with a `delta` string that is forwarded
    // alongside it in the PartUpdated event.
    const UpdatePartInput = z.union([
      MessageV2.Part,
      z.object({
        part: MessageV2.TextPart,
        delta: z.string(),
      }),
      z.object({
        part: MessageV2.ReasoningPart,
        delta: z.string(),
      }),
    ])

    /** Writes a part record, publishes `MessageV2.Event.PartUpdated` (with `delta` when given), returns the part. */
    export const updatePart = fn(UpdatePartInput, async (input) => {
      const part = "delta" in input ? input.part : input
      const delta = "delta" in input ? input.delta : undefined
      await Storage.write(["part", part.messageID, part.id], part)
      Bus.publish(MessageV2.Event.PartUpdated, {
        part,
        delta,
      })
      return part
    })

    /**
     * Normalizes provider token usage and computes the request cost.
     *
     * Costs are per 1M tokens from `model.cost`. The `experimentalOver200K` pricing applies when
     * the total prompt size exceeds 200K tokens: uncached input + cache reads + cache writes.
     * Reasoning tokens are charged at the output rate. Non-finite values become 0.
     *
     * @returns `{ cost, tokens: { input, output, reasoning, cache: { read, write } } }`
     */
    export const getUsage = fn(
      z.object({
        model: z.custom<Provider.Model>(),
        usage: z.custom<LanguageModelUsage>(),
        metadata: z.custom<ProviderMetadata>().optional(),
      }),
      (input) => {
        const cacheReadInputTokens = input.usage.cachedInputTokens ?? 0
        const cacheWriteInputTokens = (input.metadata?.["anthropic"]?.["cacheCreationInputTokens"] ??
          // @ts-expect-error
          input.metadata?.["bedrock"]?.["usage"]?.["cacheWriteInputTokens"] ??
          // @ts-expect-error
          input.metadata?.["venice"]?.["usage"]?.["cacheCreationInputTokens"] ??
          0) as number
        // Anthropic/Bedrock inputTokens are treated as already excluding cached tokens. For
        // other providers the cached tokens are subtracted to obtain the uncached portion.
        const excludesCachedTokens = !!(input.metadata?.["anthropic"] || input.metadata?.["bedrock"])
        const rawInputTokens = input.usage.inputTokens ?? 0
        const adjustedInputTokens = excludesCachedTokens
          ? rawInputTokens
          : // Clamp: inconsistent provider reports must not yield negative tokens (and negative cost).
            Math.max(0, rawInputTokens - cacheReadInputTokens - cacheWriteInputTokens)
        const safe = (value: number) => {
          if (!Number.isFinite(value)) return 0
          return value
        }
        const tokens = {
          input: safe(adjustedInputTokens),
          output: safe(input.usage.outputTokens ?? 0),
          reasoning: safe(input.usage.reasoningTokens ?? 0),
          cache: {
            write: safe(cacheWriteInputTokens),
            read: safe(cacheReadInputTokens),
          },
        }
        // The long-context tier is decided by the whole prompt, including cache reads AND writes.
        const promptTokens = tokens.input + tokens.cache.read + tokens.cache.write
        const costInfo =
          input.model.cost?.experimentalOver200K && promptTokens > 200_000
            ? input.model.cost.experimentalOver200K
            : input.model.cost
        const perMillion = (count: number, rate: number | undefined) =>
          new Decimal(count).mul(rate ?? 0).div(1_000_000)
        return {
          cost: safe(
            new Decimal(0)
              .add(perMillion(tokens.input, costInfo?.input))
              .add(perMillion(tokens.output, costInfo?.output))
              .add(perMillion(tokens.cache.read, costInfo?.cache?.read))
              .add(perMillion(tokens.cache.write, costInfo?.cache?.write))
              // TODO: update models.dev to have better pricing model, for now:
              // charge reasoning tokens at the same rate as output tokens
              .add(perMillion(tokens.reasoning, costInfo?.output))
              .toNumber(),
          ),
          tokens,
        }
      },
    )

    /** Error carrying the ID of a session that is currently busy. */
    export class BusyError extends Error {
      constructor(public readonly sessionID: string) {
        super(`Session ${sessionID} is busy`)
      }
    }

    /**
     * Runs the default INIT command in a session with the given model, passed to
     * `SessionPrompt.command` as `"<providerID>/<modelID>"` with empty arguments.
     */
    export const initialize = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        modelID: z.string(),
        providerID: z.string(),
        messageID: Identifier.schema("message"),
      }),
      async (input) => {
        await SessionPrompt.command({
          sessionID: input.sessionID,
          messageID: input.messageID,
          model: input.providerID + "/" + input.modelID,
          command: Command.Default.INIT,
          arguments: "",
        })
      },
    )
  }