Преглед изворни кода

fix(session): restore busy route handling and add regression coverage (#20125)

Kit Langton пре 2 недеља
родитељ
комит
2ed756c72c

+ 4 - 0
packages/opencode/src/server/middleware.ts

@@ -1,6 +1,7 @@
 import { Provider } from "../provider/provider"
 import { NamedError } from "@opencode-ai/util/error"
 import { NotFoundError } from "../storage/db"
+import { Session } from "../session"
 import type { ContentfulStatusCode } from "hono/utils/http-status"
 import type { ErrorHandler } from "hono"
 import { HTTPException } from "hono/http-exception"
@@ -20,6 +21,9 @@ export function errorHandler(log: Log.Logger): ErrorHandler {
       else status = 500
       return c.json(err.toObject(), { status })
     }
+    if (err instanceof Session.BusyError) {
+      return c.json(new NamedError.Unknown({ message: err.message }).toObject(), { status: 400 })
+    }
     if (err instanceof HTTPException) return err.getResponse()
     const message = err instanceof Error && err.stack ? err.stack : err.toString()
     return c.json(new NamedError.Unknown({ message }).toObject(), {

+ 4 - 2
packages/opencode/src/session/index.ts

@@ -849,7 +849,8 @@ export namespace Session {
   export const children = fn(SessionID.zod, (id) => runPromise((svc) => svc.children(id)))
   export const remove = fn(SessionID.zod, (id) => runPromise((svc) => svc.remove(id)))
   export async function updateMessage<T extends MessageV2.Info>(msg: T): Promise<T> {
-    return runPromise((svc) => svc.updateMessage(MessageV2.Info.parse(msg) as T))
+    MessageV2.Info.parse(msg)
+    return runPromise((svc) => svc.updateMessage(msg))
   }
 
   export const removeMessage = fn(z.object({ sessionID: SessionID.zod, messageID: MessageID.zod }), (input) =>
@@ -862,7 +863,8 @@ export namespace Session {
   )
 
   export async function updatePart<T extends MessageV2.Part>(part: T): Promise<T> {
-    return runPromise((svc) => svc.updatePart(MessageV2.Part.parse(part) as T))
+    MessageV2.Part.parse(part)
+    return runPromise((svc) => svc.updatePart(part))
   }
 
   export const updatePartDelta = fn(

+ 0 - 3
packages/opencode/src/session/revert.ts

@@ -92,12 +92,10 @@ export namespace SessionRevert {
     const sessionID = session.id
     const msgs = await Session.messages({ sessionID })
     const messageID = session.revert.messageID
-    const preserve = [] as MessageV2.WithParts[]
     const remove = [] as MessageV2.WithParts[]
     let target: MessageV2.WithParts | undefined
     for (const msg of msgs) {
       if (msg.info.id < messageID) {
-        preserve.push(msg)
         continue
       }
       if (msg.info.id > messageID) {
@@ -105,7 +103,6 @@ export namespace SessionRevert {
         continue
       }
       if (session.revert.partID) {
-        preserve.push(msg)
         target = msg
         continue
       }

+ 83 - 0
packages/opencode/test/server/session-actions.test.ts

@@ -0,0 +1,83 @@
+import { afterEach, describe, expect, mock, spyOn, test } from "bun:test"
+import { Instance } from "../../src/project/instance"
+import { Server } from "../../src/server/server"
+import { Session } from "../../src/session"
+import { ModelID, ProviderID } from "../../src/provider/schema"
+import { MessageID, PartID, type SessionID } from "../../src/session/schema"
+import { SessionPrompt } from "../../src/session/prompt"
+import { Log } from "../../src/util/log"
+import { tmpdir } from "../fixture/fixture"
+
+Log.init({ print: false })
+
+afterEach(async () => {
+  mock.restore()
+  await Instance.disposeAll()
+})
+
+async function user(sessionID: SessionID, text: string) {
+  const msg = await Session.updateMessage({
+    id: MessageID.ascending(),
+    role: "user",
+    sessionID,
+    agent: "build",
+    model: { providerID: ProviderID.make("test"), modelID: ModelID.make("test") },
+    time: { created: Date.now() },
+  })
+  await Session.updatePart({
+    id: PartID.ascending(),
+    sessionID,
+    messageID: msg.id,
+    type: "text",
+    text,
+  })
+  return msg
+}
+
+describe("session action routes", () => {
+  test("abort route calls SessionPrompt.cancel", async () => {
+    await using tmp = await tmpdir({ git: true })
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const session = await Session.create({})
+        const cancel = spyOn(SessionPrompt, "cancel").mockResolvedValue()
+        const app = Server.Default()
+
+        const res = await app.request(`/session/${session.id}/abort`, {
+          method: "POST",
+        })
+
+        expect(res.status).toBe(200)
+        expect(await res.json()).toBe(true)
+        expect(cancel).toHaveBeenCalledWith(session.id)
+
+        await Session.remove(session.id)
+      },
+    })
+  })
+
+  test("delete message route returns 400 when session is busy", async () => {
+    await using tmp = await tmpdir({ git: true })
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const session = await Session.create({})
+        const msg = await user(session.id, "hello")
+        const busy = spyOn(SessionPrompt, "assertNotBusy").mockRejectedValue(new Session.BusyError(session.id))
+        const remove = spyOn(Session, "removeMessage").mockResolvedValue(msg.id)
+        const app = Server.Default()
+
+        const res = await app.request(`/session/${session.id}/message/${msg.id}`, {
+          method: "DELETE",
+        })
+
+        expect(res.status).toBe(400)
+        expect(busy).toHaveBeenCalledWith(session.id)
+        expect(remove).not.toHaveBeenCalled()
+
+        await Session.remove(session.id)
+      },
+    })
+  })
+})

+ 30 - 0
packages/opencode/test/session/compaction.test.ts

@@ -509,6 +509,36 @@ describe("session.compaction.prune", () => {
 })
 
 describe("session.compaction.process", () => {
+  test("throws when parent is not a user message", async () => {
+    await using tmp = await tmpdir()
+    await Instance.provide({
+      directory: tmp.path,
+      fn: async () => {
+        const session = await Session.create({})
+        const msg = await user(session.id, "hello")
+        const reply = await assistant(session.id, msg.id, tmp.path)
+        const rt = runtime("continue")
+        try {
+          const msgs = await Session.messages({ sessionID: session.id })
+          await expect(
+            rt.runPromise(
+              SessionCompaction.Service.use((svc) =>
+                svc.process({
+                  parentID: reply.id,
+                  messages: msgs,
+                  sessionID: session.id,
+                  auto: false,
+                }),
+              ),
+            ),
+          ).rejects.toThrow(`Compaction parent must be a user message: ${reply.id}`)
+        } finally {
+          await rt.dispose()
+        }
+      },
+    })
+  })
+
   test("publishes compacted event on continue", async () => {
     await using tmp = await tmpdir()
     await Instance.provide({

+ 66 - 1
packages/opencode/test/session/prompt-effect.test.ts

@@ -1,7 +1,8 @@
 import { NodeFileSystem } from "@effect/platform-node"
-import { expect } from "bun:test"
+import { expect, spyOn } from "bun:test"
 import { Cause, Effect, Exit, Fiber, Layer, ServiceMap } from "effect"
 import * as Stream from "effect/Stream"
+import z from "zod"
 import type { Agent } from "../../src/agent/agent"
 import { Agent as AgentSvc } from "../../src/agent/agent"
 import { Bus } from "../../src/bus"
@@ -25,6 +26,7 @@ import { MessageID, PartID, SessionID } from "../../src/session/schema"
 import { SessionStatus } from "../../src/session/status"
 import { Shell } from "../../src/shell/shell"
 import { Snapshot } from "../../src/snapshot"
+import { TaskTool } from "../../src/tool/task"
 import { ToolRegistry } from "../../src/tool/registry"
 import { Truncate } from "../../src/tool/truncate"
 import { Log } from "../../src/util/log"
@@ -630,6 +632,69 @@ it.effect(
   30_000,
 )
 
+it.effect(
+  "cancel finalizes subtask tool state",
+  () =>
+    provideTmpdirInstance(
+      (dir) =>
+        Effect.gen(function* () {
+          const ready = defer<void>()
+          const aborted = defer<void>()
+          const init = spyOn(TaskTool, "init").mockImplementation(async () => ({
+            description: "task",
+            parameters: z.object({
+              description: z.string(),
+              prompt: z.string(),
+              subagent_type: z.string(),
+              task_id: z.string().optional(),
+              command: z.string().optional(),
+            }),
+            execute: async (_args, ctx) => {
+              ready.resolve()
+              ctx.abort.addEventListener("abort", () => aborted.resolve(), { once: true })
+              await new Promise<void>(() => {})
+              return {
+                title: "",
+                metadata: {
+                  sessionId: SessionID.make("task"),
+                  model: ref,
+                },
+                output: "",
+              }
+            },
+          }))
+          yield* Effect.addFinalizer(() => Effect.sync(() => init.mockRestore()))
+
+          const { prompt, chat } = yield* boot()
+          const msg = yield* user(chat.id, "hello")
+          yield* addSubtask(chat.id, msg.id)
+
+          const fiber = yield* prompt.loop({ sessionID: chat.id }).pipe(Effect.forkChild)
+          yield* Effect.promise(() => ready.promise)
+          yield* prompt.cancel(chat.id)
+          yield* Effect.promise(() => aborted.promise)
+
+          const exit = yield* Fiber.await(fiber)
+          expect(Exit.isSuccess(exit)).toBe(true)
+
+          const msgs = yield* Effect.promise(() => MessageV2.filterCompacted(MessageV2.stream(chat.id)))
+          const taskMsg = msgs.find((item) => item.info.role === "assistant" && item.info.agent === "general")
+          expect(taskMsg?.info.role).toBe("assistant")
+          if (!taskMsg || taskMsg.info.role !== "assistant") return
+
+          const tool = toolPart(taskMsg.parts)
+          expect(tool?.type).toBe("tool")
+          if (!tool) return
+
+          expect(tool.state.status).not.toBe("running")
+          expect(taskMsg.info.time.completed).toBeDefined()
+          expect(taskMsg.info.finish).toBeDefined()
+        }),
+      { git: true, config: cfg },
+    ),
+  30_000,
+)
+
 it.effect(
   "cancel with queued callers resolves all cleanly",
   () =>