index.ts 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. import { Decimal } from "decimal.js"
  2. import z from "zod"
  3. import { type LanguageModelUsage, type ProviderMetadata } from "ai"
  4. import { Bus } from "../bus"
  5. import { Config } from "../config/config"
  6. import { Flag } from "../flag/flag"
  7. import { Identifier } from "../id/id"
  8. import { Installation } from "../installation"
  9. import type { ModelsDev } from "../provider/models"
  10. import { Share } from "../share/share"
  11. import { Storage } from "../storage/storage"
  12. import { Log } from "../util/log"
  13. import { MessageV2 } from "./message-v2"
  14. import { Instance } from "../project/instance"
  15. import { SessionPrompt } from "./prompt"
  16. import { fn } from "@/util/fn"
  17. import { Command } from "../command"
  18. import { Snapshot } from "@/snapshot"
  export namespace Session {
    const log = Log.create({ service: "session" })

    const parentTitlePrefix = "New session - "
    const childTitlePrefix = "Child session - "

    // Built once instead of on every isDefaultTitle() call. The pattern is
    // constant and has no g/y flag, so reusing it with .test() is stateless.
    // It accepts exactly what createDefaultTitle() produces: one of the two
    // prefixes followed by a Date#toISOString() timestamp.
    const defaultTitlePattern = new RegExp(
      `^(${parentTitlePrefix}|${childTitlePrefix})\\d{4}-\\d{2}-\\d{2}T\\d{2}:\\d{2}:\\d{2}\\.\\d{3}Z$`,
    )

    /** Placeholder title for a session created without one: prefix + current ISO timestamp. */
    function createDefaultTitle(isChild = false) {
      return (isChild ? childTitlePrefix : parentTitlePrefix) + new Date().toISOString()
    }

    /**
     * Reports whether `title` has exactly the shape generated by
     * createDefaultTitle, i.e. it is still an auto-generated placeholder.
     */
    export function isDefaultTitle(title: string) {
      return defaultTitlePattern.test(title)
    }

    /** Lexicographic id comparator; returns 0 for equal ids so it is a consistent comparator. */
    function compareIDs(a: string, b: string) {
      return a < b ? -1 : a > b ? 1 : 0
    }

    /**
     * Follows `path` through nested objects starting at `value` and returns the
     * final value if it is a finite number, otherwise `undefined`. Used to read
     * untyped provider metadata without casting arbitrary JSON to `number`.
     */
    function readNumber(value: unknown, ...path: string[]): number | undefined {
      let current: unknown = value
      for (const key of path) {
        if (typeof current !== "object" || current === null) return undefined
        current = (current as Record<string, unknown>)[key]
      }
      return typeof current === "number" && Number.isFinite(current) ? current : undefined
    }

    /**
     * Persisted session record. createNext() writes it and update() edits it
     * under the storage key ["session", projectID, id].
     */
    export const Info = z
      .object({
        id: Identifier.schema("session"),
        projectID: z.string(),
        directory: z.string(),
        parentID: Identifier.schema("session").optional(),
        summary: z
          .object({
            additions: z.number(),
            deletions: z.number(),
            diffs: Snapshot.FileDiff.array().optional(),
          })
          .optional(),
        // Only the public URL lives on the session; the secret is kept
        // separately under ["share", id] (see share/unshare).
        share: z
          .object({
            url: z.string(),
          })
          .optional(),
        title: z.string(),
        version: z.string(),
        time: z.object({
          created: z.number(),
          updated: z.number(),
          compacting: z.number().optional(),
        }),
        revert: z
          .object({
            messageID: z.string(),
            partID: z.string().optional(),
            snapshot: z.string().optional(),
            diff: z.string().optional(),
          })
          .optional(),
      })
      .meta({
        ref: "Session",
      })
    export type Info = z.output<typeof Info>

    /** Share record (including the secret used by Share.remove), stored under ["share", sessionID]. */
    export const ShareInfo = z
      .object({
        secret: z.string(),
        url: z.string(),
      })
      .meta({
        ref: "SessionShare",
      })
    export type ShareInfo = z.output<typeof ShareInfo>

    /** Bus event definitions for sessions. Created/Updated/Deleted are published from this namespace. */
    export const Event = {
      Created: Bus.event(
        "session.created",
        z.object({
          info: Info,
        }),
      ),
      Updated: Bus.event(
        "session.updated",
        z.object({
          info: Info,
        }),
      ),
      Deleted: Bus.event(
        "session.deleted",
        z.object({
          info: Info,
        }),
      ),
      Error: Bus.event(
        "session.error",
        z.object({
          sessionID: z.string().optional(),
          error: MessageV2.Assistant.shape.error,
        }),
      ),
    }

    /**
     * Creates a session in the current instance directory.
     *
     * @param input optional `parentID` (makes it a child session) and `title`
     *   (a generated placeholder is used when omitted).
     * @returns the stored session record.
     */
    export const create = fn(
      z
        .object({
          parentID: Identifier.schema("session").optional(),
          title: z.string().optional(),
        })
        .optional(),
      async (input) => {
        return createNext({
          parentID: input?.parentID,
          directory: Instance.directory,
          title: input?.title,
        })
      },
    )

    /**
     * Copies the messages (and their parts) of `input.sessionID` into a new
     * top-level session in the current directory. Every copied message and part
     * gets a fresh id.
     *
     * When `input.messageID` is given, copying stops at the first message whose
     * id is >= it, so that message and all later ones (in id order, as returned
     * by messages()) are excluded.
     *
     * NOTE(review): message infos are spread verbatim apart from id/sessionID.
     * If they carry references to other message ids, those still point at the
     * source session's ids. Confirm whether they need remapping.
     *
     * @returns the newly created session record.
     */
    export const fork = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        messageID: Identifier.schema("message").optional(),
      }),
      async (input) => {
        const session = await createNext({
          directory: Instance.directory,
        })
        const msgs = await messages(input.sessionID)
        for (const msg of msgs) {
          if (input.messageID && msg.info.id >= input.messageID) break
          const cloned = await updateMessage({
            ...msg.info,
            sessionID: session.id,
            id: Identifier.ascending("message"),
          })
          for (const part of msg.parts) {
            await updatePart({
              ...part,
              id: Identifier.ascending("part"),
              messageID: cloned.id,
              sessionID: session.id,
            })
          }
        }
        return session
      },
    )

    /** Bumps `time.updated` on a session; update() publishes Event.Updated. */
    export const touch = fn(Identifier.schema("session"), async (sessionID) => {
      await update(sessionID, (draft) => {
        draft.time.updated = Date.now()
      })
    })

    /**
     * Creates, persists and announces a new session.
     *
     * Publishes Event.Created, then Event.Updated. For top-level sessions, when
     * OPENCODE_AUTO_SHARE is set or config `share` is "auto", it also starts
     * sharing in the background. Sharing failures are ignored.
     *
     * @param input.id optional seed passed to Identifier.descending.
     * @param input.title title to use; defaults to a generated placeholder.
     * @param input.parentID parent session id; makes this a child session.
     * @param input.directory working directory recorded on the session.
     * @returns the stored session record.
     */
    export async function createNext(input: {
      id?: string
      title?: string
      parentID?: string
      directory: string
    }) {
      // One timestamp so a fresh session has created === updated.
      const now = Date.now()
      const result: Info = {
        id: Identifier.descending("session", input.id),
        version: Installation.VERSION,
        projectID: Instance.project.id,
        directory: input.directory,
        parentID: input.parentID,
        title: input.title ?? createDefaultTitle(!!input.parentID),
        time: {
          created: now,
          updated: now,
        },
      }
      log.info("created", result)
      await Storage.write(["session", Instance.project.id, result.id], result)
      Bus.publish(Event.Created, {
        info: result,
      })
      const cfg = await Config.get()
      if (!result.parentID && (Flag.OPENCODE_AUTO_SHARE || cfg.share === "auto")) {
        // Fire-and-forget. share() records `{ url }` on the session itself.
        // The ShareInfo it resolves with also carries `secret`, which must not
        // be copied onto the session record that is persisted and published
        // via Event.Updated.
        share(result.id).catch(() => {
          // Silently ignore sharing errors during session creation
        })
      }
      Bus.publish(Event.Updated, {
        info: result,
      })
      return result
    }

    /** Reads a session record of the current project by id. */
    export const get = fn(Identifier.schema("session"), async (id) => {
      const read = await Storage.read<Info>(["session", Instance.project.id, id])
      return read as Info
    })

    /** Reads the stored share record (url + secret) for a session. */
    export const getShare = fn(Identifier.schema("session"), async (id) => {
      return Storage.read<ShareInfo>(["share", id])
    })

    /**
     * Shares a session and syncs its info, messages and parts via Share.sync.
     *
     * If the session is already shared, it returns the `{ url }` on the session
     * record without re-syncing. Otherwise it creates a share, records
     * `{ url }` on the session, and stores the full ShareInfo under
     * ["share", id].
     *
     * @throws Error when config `share` is "disabled".
     * @returns the ShareInfo of a new share, or the session's existing `{ url }`.
     */
    export const share = fn(Identifier.schema("session"), async (id) => {
      const cfg = await Config.get()
      if (cfg.share === "disabled") {
        throw new Error("Sharing is disabled in configuration")
      }
      const session = await get(id)
      if (session.share) return session.share
      const share = await Share.create(id)
      const updated = await update(id, (draft) => {
        draft.share = {
          url: share.url,
        }
      })
      await Storage.write(["share", id], share)
      // Sync the post-update record so the synced info includes `share.url`.
      await Share.sync("session/info/" + id, updated)
      for (const msg of await messages(id)) {
        await Share.sync("session/message/" + id + "/" + msg.info.id, msg.info)
        for (const part of msg.parts) {
          await Share.sync("session/part/" + id + "/" + msg.info.id + "/" + part.id, part)
        }
      }
      return share
    })

    /**
     * Removes a session's share: deletes the stored ShareInfo, clears `share` on
     * the session record, then calls Share.remove with the stored secret.
     * Does nothing when getShare() yields a falsy value.
     */
    export const unshare = fn(Identifier.schema("session"), async (id) => {
      const share = await getShare(id)
      if (!share) return
      await Storage.remove(["share", id])
      await update(id, (draft) => {
        draft.share = undefined
      })
      await Share.remove(id, share.secret)
    })

    /**
     * Applies `editor` to the stored session record via Storage.update, stamps
     * `time.updated`, and publishes Event.Updated with the result.
     *
     * @returns the updated session record.
     */
    export async function update(id: string, editor: (session: Info) => void) {
      const project = Instance.project
      const result = await Storage.update<Info>(["session", project.id, id], (draft) => {
        editor(draft)
        draft.time.updated = Date.now()
      })
      Bus.publish(Event.Updated, {
        info: result,
      })
      return result
    }

    /** Returns the file diffs stored for a session, or [] when none are stored. */
    export const diff = fn(Identifier.schema("session"), async (sessionID) => {
      const diffs = await Storage.read<Snapshot.FileDiff[]>(["session_diff", sessionID])
      return diffs ?? []
    })

    /** Loads every message of a session with its parts, sorted by message id. */
    export const messages = fn(Identifier.schema("session"), async (sessionID) => {
      const result = [] as MessageV2.WithParts[]
      for (const p of await Storage.list(["message", sessionID])) {
        const read = await Storage.read<MessageV2.Info>(p)
        result.push({
          info: read,
          parts: await getParts(read.id),
        })
      }
      result.sort((a, b) => compareIDs(a.info.id, b.info.id))
      return result
    })

    /** Loads a single message and its parts. */
    export const getMessage = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        messageID: Identifier.schema("message"),
      }),
      async (input) => {
        return {
          info: await Storage.read<MessageV2.Info>(["message", input.sessionID, input.messageID]),
          parts: await getParts(input.messageID),
        }
      },
    )

    /** Loads all parts stored for a message, sorted by part id. */
    export const getParts = fn(Identifier.schema("message"), async (messageID) => {
      const result = [] as MessageV2.Part[]
      for (const item of await Storage.list(["part", messageID])) {
        const read = await Storage.read<MessageV2.Part>(item)
        result.push(read)
      }
      result.sort((a, b) => compareIDs(a.id, b.id))
      return result
    })

    /** Lazily yields every session record of the current project, in storage listing order. */
    export async function* list() {
      const project = Instance.project
      for (const item of await Storage.list(["session", project.id])) {
        yield Storage.read<Info>(item)
      }
    }

    /** Returns the sessions of the current project whose `parentID` equals `parentID`. */
    export const children = fn(Identifier.schema("session"), async (parentID) => {
      const project = Instance.project
      const result = [] as Info[]
      for (const item of await Storage.list(["session", project.id])) {
        const session = await Storage.read<Info>(item)
        if (session.parentID !== parentID) continue
        result.push(session)
      }
      return result
    })

    /**
     * Deletes a session together with its child sessions (recursively), its
     * share, messages, parts and stored diff, then publishes Event.Deleted.
     * Errors are logged rather than thrown.
     */
    export const remove = fn(Identifier.schema("session"), async (sessionID) => {
      const project = Instance.project
      try {
        const session = await get(sessionID)
        for (const child of await children(sessionID)) {
          await remove(child.id)
        }
        await unshare(sessionID).catch(() => {})
        for (const msg of await Storage.list(["message", sessionID])) {
          // The listed key ends with the message id, which is also the key
          // prefix its parts are stored under (see updatePart).
          for (const part of await Storage.list(["part", msg.at(-1)!])) {
            await Storage.remove(part)
          }
          await Storage.remove(msg)
        }
        // diff() reads this per-session entry; drop it with the session.
        // Best-effort: a missing entry must not abort the deletion.
        try {
          await Storage.remove(["session_diff", sessionID])
        } catch {
          // ignored on purpose
        }
        await Storage.remove(["session", project.id, sessionID])
        Bus.publish(Event.Deleted, {
          info: session,
        })
      } catch (e) {
        log.error(e)
      }
    })

    /** Writes a message record and publishes MessageV2.Event.Updated. Returns the message. */
    export const updateMessage = fn(MessageV2.Info, async (msg) => {
      await Storage.write(["message", msg.sessionID, msg.id], msg)
      Bus.publish(MessageV2.Event.Updated, {
        info: msg,
      })
      return msg
    })

    /**
     * Deletes a message record and publishes MessageV2.Event.Removed.
     *
     * NOTE(review): only the message record is removed. Parts stored under
     * ["part", messageID] are left in place; confirm callers clean them up.
     *
     * @returns the removed message id.
     */
    export const removeMessage = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        messageID: Identifier.schema("message"),
      }),
      async (input) => {
        await Storage.remove(["message", input.sessionID, input.messageID])
        Bus.publish(MessageV2.Event.Removed, {
          sessionID: input.sessionID,
          messageID: input.messageID,
        })
        return input.messageID
      },
    )

    // Either a whole part, or a text/reasoning part plus the `delta` string
    // that is forwarded alongside it on the PartUpdated event.
    const UpdatePartInput = z.union([
      MessageV2.Part,
      z.object({
        part: MessageV2.TextPart,
        delta: z.string(),
      }),
      z.object({
        part: MessageV2.ReasoningPart,
        delta: z.string(),
      }),
    ])

    /**
     * Writes a part under ["part", messageID, partID] and publishes
     * MessageV2.Event.PartUpdated (with `delta` when one was supplied).
     *
     * @returns the stored part.
     */
    export const updatePart = fn(UpdatePartInput, async (input) => {
      const part = "delta" in input ? input.part : input
      const delta = "delta" in input ? input.delta : undefined
      await Storage.write(["part", part.messageID, part.id], part)
      Bus.publish(MessageV2.Event.PartUpdated, {
        part,
        delta,
      })
      return part
    })

    /**
     * Normalizes token usage and computes its cost.
     *
     * Missing counts become 0. The cache-write count comes from Anthropic's
     * `cacheCreationInputTokens` metadata, else Bedrock's
     * `usage.cacheWriteInputTokens`; values that are not finite numbers are
     * ignored. Each model price is applied per 1,000,000 tokens.
     *
     * @returns `{ cost, tokens }`.
     */
    export const getUsage = fn(
      z.object({
        model: z.custom<ModelsDev.Model>(),
        usage: z.custom<LanguageModelUsage>(),
        metadata: z.custom<ProviderMetadata>().optional(),
      }),
      (input) => {
        const tokens = {
          input: input.usage.inputTokens ?? 0,
          output: input.usage.outputTokens ?? 0,
          reasoning: input.usage.reasoningTokens ?? 0,
          cache: {
            write:
              readNumber(input.metadata, "anthropic", "cacheCreationInputTokens") ??
              readNumber(input.metadata, "bedrock", "usage", "cacheWriteInputTokens") ??
              0,
            read: input.usage.cachedInputTokens ?? 0,
          },
        }
        const price = input.model.cost
        // Cost of `count` tokens at a price quoted per million tokens.
        const perMillion = (count: number, perMillionPrice: number | undefined) =>
          new Decimal(count).mul(perMillionPrice ?? 0).div(1_000_000)
        return {
          cost: new Decimal(0)
            .add(perMillion(tokens.input, price?.input))
            .add(perMillion(tokens.output, price?.output))
            .add(perMillion(tokens.cache.read, price?.cache_read))
            .add(perMillion(tokens.cache.write, price?.cache_write))
            .toNumber(),
          tokens,
        }
      },
    )

    /** Error identifying a busy session (not thrown within this namespace). */
    export class BusyError extends Error {
      constructor(public readonly sessionID: string) {
        super(`Session ${sessionID} is busy`)
      }
    }

    /**
     * Runs the default INIT command in a session via SessionPrompt.command,
     * using the model "<providerID>/<modelID>" and empty arguments.
     */
    export const initialize = fn(
      z.object({
        sessionID: Identifier.schema("session"),
        modelID: z.string(),
        providerID: z.string(),
        messageID: Identifier.schema("message"),
      }),
      async (input) => {
        await SessionPrompt.command({
          sessionID: input.sessionID,
          messageID: input.messageID,
          model: input.providerID + "/" + input.modelID,
          command: Command.Default.INIT,
          arguments: "",
        })
      },
    )
  }