Просмотр исходного кода

tweak: consolidate session lock logic (#3185)

Aiden Cline 4 месяцев назад
Родитель
Сommit
b3fcc9a81d

+ 2 - 1
packages/opencode/src/server/server.ts

@@ -25,6 +25,7 @@ import { Global } from "../global"
 import { ProjectRoute } from "./project"
 import { ProjectRoute } from "./project"
 import { ToolRegistry } from "../tool/registry"
 import { ToolRegistry } from "../tool/registry"
 import { zodToJsonSchema } from "zod-to-json-schema"
 import { zodToJsonSchema } from "zod-to-json-schema"
+import { SessionLock } from "../session/lock"
 import { SessionPrompt } from "../session/prompt"
 import { SessionPrompt } from "../session/prompt"
 import { SessionCompaction } from "../session/compaction"
 import { SessionCompaction } from "../session/compaction"
 import { SessionRevert } from "../session/revert"
 import { SessionRevert } from "../session/revert"
@@ -549,7 +550,7 @@ export namespace Server {
           }),
           }),
         ),
         ),
         async (c) => {
         async (c) => {
-          return c.json(SessionPrompt.abort(c.req.valid("param").id))
+          return c.json(SessionLock.abort(c.req.valid("param").id))
         },
         },
       )
       )
       .post(
       .post(

+ 94 - 0
packages/opencode/src/session/lock.ts

@@ -0,0 +1,94 @@
+import z from "zod/v4"
+import { Instance } from "../project/instance"
+import { Log } from "../util/log"
+import { NamedError } from "../util/error"
+
+export namespace SessionLock {
+  const log = Log.create({ service: "session.lock" })
+
+  export const LockedError = NamedError.create(
+    "SessionLockedError",
+    z.object({
+      sessionID: z.string(),
+      message: z.string(),
+    }),
+  )
+
+  type LockState = {
+    controller: AbortController
+    created: number
+  }
+
+  const state = Instance.state(
+    () => {
+      const locks = new Map<string, LockState>()
+      return {
+        locks,
+      }
+    },
+    async (current) => {
+      for (const [sessionID, lock] of current.locks) {
+        log.info("force abort", { sessionID })
+        lock.controller.abort()
+      }
+      current.locks.clear()
+    },
+  )
+
+  function get(sessionID: string) {
+    return state().locks.get(sessionID)
+  }
+
+  function unset(input: { sessionID: string; controller: AbortController }) {
+    const lock = get(input.sessionID)
+    if (!lock) return false
+    if (lock.controller !== input.controller) return false
+    state().locks.delete(input.sessionID)
+    return true
+  }
+
+  export function acquire(input: { sessionID: string }) {
+    const lock = get(input.sessionID)
+    if (lock) {
+      throw new LockedError({ sessionID: input.sessionID, message: `Session ${input.sessionID} is locked` })
+    }
+    const controller = new AbortController()
+    state().locks.set(input.sessionID, {
+      controller,
+      created: Date.now(),
+    })
+    log.info("locked", { sessionID: input.sessionID })
+    return {
+      signal: controller.signal,
+      abort() {
+        controller.abort()
+        unset({ sessionID: input.sessionID, controller })
+      },
+      async [Symbol.dispose]() {
+        const removed = unset({ sessionID: input.sessionID, controller })
+        if (removed) {
+          log.info("unlocked", { sessionID: input.sessionID })
+        }
+      },
+    }
+  }
+
+  export function abort(sessionID: string) {
+    const lock = get(sessionID)
+    if (!lock) return false
+    log.info("abort", { sessionID })
+    lock.controller.abort()
+    state().locks.delete(sessionID)
+    return true
+  }
+
+  export function isLocked(sessionID: string) {
+    return get(sessionID) !== undefined
+  }
+
+  export function assertUnlocked(sessionID: string) {
+    const lock = get(sessionID)
+    if (!lock) return
+    throw new LockedError({ sessionID, message: `Session ${sessionID} is locked` })
+  }
+}

+ 9 - 22
packages/opencode/src/session/prompt.ts

@@ -22,6 +22,7 @@ import {
   jsonSchema,
   jsonSchema,
 } from "ai"
 } from "ai"
 import { SessionCompaction } from "./compaction"
 import { SessionCompaction } from "./compaction"
+import { SessionLock } from "./lock"
 import { Instance } from "../project/instance"
 import { Instance } from "../project/instance"
 import { Bus } from "../bus"
 import { Bus } from "../bus"
 import { ProviderTransform } from "../provider/transform"
 import { ProviderTransform } from "../provider/transform"
@@ -65,7 +66,6 @@ export namespace SessionPrompt {
 
 
   const state = Instance.state(
   const state = Instance.state(
     () => {
     () => {
-      const pending = new Map<string, AbortController>()
       const queued = new Map<
       const queued = new Map<
         string,
         string,
         {
         {
@@ -75,14 +75,11 @@ export namespace SessionPrompt {
       >()
       >()
 
 
       return {
       return {
-        pending,
         queued,
         queued,
       }
       }
     },
     },
-    async (state) => {
-      for (const [_, controller] of state.pending) {
-        controller.abort()
-      }
+    async (current) => {
+      current.queued.clear()
     },
     },
   )
   )
 
 
@@ -1179,30 +1176,20 @@ export namespace SessionPrompt {
   }
   }
 
 
   function isBusy(sessionID: string) {
   function isBusy(sessionID: string) {
-    return state().pending.has(sessionID)
+    return SessionLock.isLocked(sessionID)
   }
   }
 
 
-  export function abort(sessionID: string) {
-    const controller = state().pending.get(sessionID)
-    if (!controller) return false
-    log.info("aborting", {
+  function lock(sessionID: string) {
+    const handle = SessionLock.acquire({
       sessionID,
       sessionID,
     })
     })
-    controller.abort()
-    state().pending.delete(sessionID)
-    return true
-  }
-
-  function lock(sessionID: string) {
     log.info("locking", { sessionID })
     log.info("locking", { sessionID })
-    if (state().pending.has(sessionID)) throw new Error("TODO")
-    const controller = new AbortController()
-    state().pending.set(sessionID, controller)
     return {
     return {
-      signal: controller.signal,
+      signal: handle.signal,
+      abort: handle.abort,
       async [Symbol.dispose]() {
       async [Symbol.dispose]() {
+        handle[Symbol.dispose]()
         log.info("unlocking", { sessionID })
         log.info("unlocking", { sessionID })
-        state().pending.delete(sessionID)
 
 
         const session = await Session.get(sessionID)
         const session = await Session.get(sessionID)
         if (session.parentID) return
         if (session.parentID) return

+ 10 - 0
packages/opencode/src/session/revert.ts

@@ -7,6 +7,7 @@ import { Log } from "../util/log"
 import { splitWhen } from "remeda"
 import { splitWhen } from "remeda"
 import { Storage } from "../storage/storage"
 import { Storage } from "../storage/storage"
 import { Bus } from "../bus"
 import { Bus } from "../bus"
+import { SessionLock } from "./lock"
 
 
 export namespace SessionRevert {
 export namespace SessionRevert {
   const log = Log.create({ service: "session.revert" })
   const log = Log.create({ service: "session.revert" })
@@ -19,6 +20,11 @@ export namespace SessionRevert {
   export type RevertInput = z.infer<typeof RevertInput>
   export type RevertInput = z.infer<typeof RevertInput>
 
 
   export async function revert(input: RevertInput) {
   export async function revert(input: RevertInput) {
+    SessionLock.assertUnlocked(input.sessionID)
+    using _ = SessionLock.acquire({
+      sessionID: input.sessionID,
+    })
+
     const all = await Session.messages(input.sessionID)
     const all = await Session.messages(input.sessionID)
     let lastUser: MessageV2.User | undefined
     let lastUser: MessageV2.User | undefined
     const session = await Session.get(input.sessionID)
     const session = await Session.get(input.sessionID)
@@ -64,6 +70,10 @@ export namespace SessionRevert {
 
 
   export async function unrevert(input: { sessionID: string }) {
   export async function unrevert(input: { sessionID: string }) {
     log.info("unreverting", input)
     log.info("unreverting", input)
+    SessionLock.assertUnlocked(input.sessionID)
+    using _ = SessionLock.acquire({
+      sessionID: input.sessionID,
+    })
     const session = await Session.get(input.sessionID)
     const session = await Session.get(input.sessionID)
     if (!session.revert) return session
     if (!session.revert) return session
     if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)
     if (session.revert.snapshot) await Snapshot.restore(session.revert.snapshot)

+ 2 - 1
packages/opencode/src/tool/task.ts

@@ -6,6 +6,7 @@ import { Bus } from "../bus"
 import { MessageV2 } from "../session/message-v2"
 import { MessageV2 } from "../session/message-v2"
 import { Identifier } from "../id/id"
 import { Identifier } from "../id/id"
 import { Agent } from "../agent/agent"
 import { Agent } from "../agent/agent"
+import { SessionLock } from "../session/lock"
 import { SessionPrompt } from "../session/prompt"
 import { SessionPrompt } from "../session/prompt"
 
 
 export const TaskTool = Tool.define("task", async () => {
 export const TaskTool = Tool.define("task", async () => {
@@ -53,7 +54,7 @@ export const TaskTool = Tool.define("task", async () => {
       }
       }
 
 
       ctx.abort.addEventListener("abort", () => {
       ctx.abort.addEventListener("abort", () => {
-        SessionPrompt.abort(session.id)
+        SessionLock.abort(session.id)
       })
       })
       const result = await SessionPrompt.prompt({
       const result = await SessionPrompt.prompt({
         messageID,
         messageID,

+ 0 - 3
packages/opencode/src/util/error.ts

@@ -1,7 +1,4 @@
 import z from "zod/v4"
 import z from "zod/v4"
-// import { Log } from "./log"
-
-// const log = Log.create()
 
 
 export abstract class NamedError extends Error {
 export abstract class NamedError extends Error {
   abstract schema(): z.core.$ZodType
   abstract schema(): z.core.$ZodType