Просмотр исходного кода

fix(app): terminal pty isolation

Adam 1 месяц назад
Родитель
Сommit
548608b7ad

+ 6 - 2
packages/app/src/components/terminal.tsx

@@ -10,6 +10,7 @@ import { resolveThemeVariant, useTheme, withAlpha, type HexColor } from "@openco
 import { useLanguage } from "@/context/language"
 import { showToast } from "@opencode-ai/ui/toast"
 import { disposeIfDisposable, getHoveredLinkText, setOptionIfSupported } from "@/utils/runtime-adapters"
+import { terminalWriter } from "@/utils/terminal-writer"
 
 const TOGGLE_TERMINAL_ID = "terminal.toggle"
 const DEFAULT_TOGGLE_TERMINAL_KEYBIND = "ctrl+`"
@@ -160,6 +161,7 @@ export const Terminal = (props: TerminalProps) => {
   const start =
     typeof local.pty.cursor === "number" && Number.isSafeInteger(local.pty.cursor) ? local.pty.cursor : undefined
   let cursor = start ?? 0
+  let output: ReturnType<typeof terminalWriter> | undefined
 
   const cleanup = () => {
     if (!cleanups.length) return
@@ -300,7 +302,7 @@ export const Terminal = (props: TerminalProps) => {
         fontSize: 14,
         fontFamily: monoFontFamily(settings.appearance.font()),
         allowTransparency: false,
-        convertEol: true,
+        convertEol: false,
         theme: terminalColors(),
         scrollback: 10_000,
         ghostty: g,
@@ -312,6 +314,7 @@ export const Terminal = (props: TerminalProps) => {
       }
       ghostty = g
       term = t
+      output = terminalWriter((data) => t.write(data))
 
       t.attachCustomKeyEventHandler((event) => {
         const key = event.key.toLowerCase()
@@ -416,7 +419,7 @@ export const Terminal = (props: TerminalProps) => {
 
         const data = typeof event.data === "string" ? event.data : ""
         if (!data) return
-        t.write(data)
+        output?.push(data)
         cursor += data.length
       }
       socket.addEventListener("message", handleMessage)
@@ -459,6 +462,7 @@ export const Terminal = (props: TerminalProps) => {
 
   onCleanup(() => {
     disposed = true
+    output?.flush()
     persistTerminal({ term, addon: serializeAddon, cursor, pty: local.pty, onCleanup: props.onCleanup })
     cleanup()
   })

+ 33 - 0
packages/app/src/utils/terminal-writer.test.ts

@@ -0,0 +1,33 @@
+import { describe, expect, test } from "bun:test"
+import { terminalWriter } from "./terminal-writer"
+
+describe("terminalWriter", () => {
+  test("buffers and flushes once per schedule", () => {
+    const calls: string[] = []
+    const scheduled: VoidFunction[] = []
+    const writer = terminalWriter(
+      (data) => calls.push(data),
+      (flush) => scheduled.push(flush),
+    )
+
+    writer.push("a")
+    writer.push("b")
+    writer.push("c")
+
+    expect(calls).toEqual([])
+    expect(scheduled).toHaveLength(1)
+
+    scheduled[0]?.()
+    expect(calls).toEqual(["abc"])
+  })
+
+  test("flush is a no-op when empty", () => {
+    const calls: string[] = []
+    const writer = terminalWriter(
+      (data) => calls.push(data),
+      (flush) => flush(),
+    )
+    writer.flush()
+    expect(calls).toEqual([])
+  })
+})

+ 27 - 0
packages/app/src/utils/terminal-writer.ts

@@ -0,0 +1,27 @@
+export function terminalWriter(
+  write: (data: string) => void,
+  schedule: (flush: VoidFunction) => void = queueMicrotask,
+) {
+  let chunks: string[] | undefined
+  let scheduled = false
+
+  const flush = () => {
+    scheduled = false
+    const items = chunks
+    if (!items?.length) return
+    chunks = undefined
+    write(items.join(""))
+  }
+
+  const push = (data: string) => {
+    if (!data) return
+    if (chunks) chunks.push(data)
+    else chunks = [data]
+
+    if (scheduled) return
+    scheduled = true
+    schedule(flush)
+  }
+
+  return { push, flush }
+}

+ 50 - 16
packages/opencode/src/pty/index.ts

@@ -4,7 +4,6 @@ import { type IPty } from "bun-pty"
 import z from "zod"
 import { Identifier } from "../id/id"
 import { Log } from "../util/log"
-import type { WSContext } from "hono/ws"
 import { Instance } from "../project/instance"
 import { lazy } from "@opencode-ai/util/lazy"
 import { Shell } from "@/shell/shell"
@@ -17,6 +16,22 @@ export namespace Pty {
   const BUFFER_CHUNK = 64 * 1024
   const encoder = new TextEncoder()
 
+  type Socket = {
+    readyState: number
+    send: (data: string | Uint8Array<ArrayBuffer> | ArrayBuffer) => void
+    close: (code?: number, reason?: string) => void
+  }
+
+  const sockets = new WeakMap<object, number>()
+  let socketCounter = 0
+
+  const tagSocket = (ws: Socket) => {
+    if (!ws || typeof ws !== "object") return
+    const next = (socketCounter = (socketCounter + 1) % Number.MAX_SAFE_INTEGER)
+    sockets.set(ws, next)
+    return next
+  }
+
   // WebSocket control frame: 0x00 + UTF-8 JSON (currently { cursor }).
   const meta = (cursor: number) => {
     const json = JSON.stringify({ cursor })
@@ -81,7 +96,7 @@ export namespace Pty {
     buffer: string
     bufferCursor: number
     cursor: number
-    subscribers: Set<WSContext>
+    subscribers: Map<Socket, number>
   }
 
   const state = Instance.state(
@@ -91,8 +106,12 @@ export namespace Pty {
         try {
           session.process.kill()
         } catch {}
-        for (const ws of session.subscribers) {
-          ws.close()
+        for (const ws of session.subscribers.keys()) {
+          try {
+            ws.close()
+          } catch {
+            // ignore
+          }
         }
       }
       sessions.clear()
@@ -154,18 +173,26 @@ export namespace Pty {
       buffer: "",
       bufferCursor: 0,
       cursor: 0,
-      subscribers: new Set(),
+      subscribers: new Map(),
     }
     state().set(id, session)
     ptyProcess.onData((data) => {
       session.cursor += data.length
 
-      for (const ws of session.subscribers) {
+      for (const [ws, id] of session.subscribers) {
         if (ws.readyState !== 1) {
           session.subscribers.delete(ws)
           continue
         }
-        ws.send(data)
+        if (typeof ws === "object" && sockets.get(ws) !== id) {
+          session.subscribers.delete(ws)
+          continue
+        }
+        try {
+          ws.send(data)
+        } catch {
+          session.subscribers.delete(ws)
+        }
       }
 
       session.buffer += data
@@ -177,14 +204,15 @@ export namespace Pty {
     ptyProcess.onExit(({ exitCode }) => {
       log.info("session exited", { id, exitCode })
       session.info.status = "exited"
-      for (const ws of session.subscribers) {
-        ws.close()
+      for (const ws of session.subscribers.keys()) {
+        try {
+          ws.close()
+        } catch {
+          // ignore
+        }
       }
       session.subscribers.clear()
       Bus.publish(Event.Exited, { id, exitCode })
-      for (const ws of session.subscribers) {
-        ws.close()
-      }
       state().delete(id)
     })
     Bus.publish(Event.Created, { info })
@@ -211,9 +239,14 @@ export namespace Pty {
     try {
       session.process.kill()
     } catch {}
-    for (const ws of session.subscribers) {
-      ws.close()
+    for (const ws of session.subscribers.keys()) {
+      try {
+        ws.close()
+      } catch {
+        // ignore
+      }
     }
+    session.subscribers.clear()
     state().delete(id)
     Bus.publish(Event.Deleted, { id })
   }
@@ -232,7 +265,7 @@ export namespace Pty {
     }
   }
 
-  export function connect(id: string, ws: WSContext, cursor?: number) {
+  export function connect(id: string, ws: Socket, cursor?: number) {
     const session = state().get(id)
     if (!session) {
       ws.close()
@@ -272,7 +305,8 @@ export namespace Pty {
       return
     }
 
-    session.subscribers.add(ws)
+    const socketId = tagSocket(ws)
+    if (typeof socketId === "number") session.subscribers.set(ws, socketId)
     return {
       onMessage: (message: string | ArrayBuffer) => {
         session.process.write(String(message))

+ 20 - 1
packages/opencode/src/server/routes/pty.ts

@@ -160,9 +160,25 @@ export const PtyRoutes = lazy(() =>
         })()
         let handler: ReturnType<typeof Pty.connect>
         if (!Pty.get(id)) throw new Error("Session not found")
+
+        type Socket = {
+          readyState: number
+          send: (data: string | Uint8Array<ArrayBuffer> | ArrayBuffer) => void
+          close: (code?: number, reason?: string) => void
+        }
+
+        const isSocket = (value: unknown): value is Socket => {
+          if (!value || typeof value !== "object") return false
+          if (!("readyState" in value)) return false
+          if (!("send" in value) || typeof (value as { send?: unknown }).send !== "function") return false
+          if (!("close" in value) || typeof (value as { close?: unknown }).close !== "function") return false
+          return typeof (value as { readyState?: unknown }).readyState === "number"
+        }
+
         return {
           onOpen(_event, ws) {
-            handler = Pty.connect(id, ws, cursor)
+            const socket = isSocket(ws.raw) ? ws.raw : ws
+            handler = Pty.connect(id, socket, cursor)
           },
           onMessage(event) {
             handler?.onMessage(String(event.data))
@@ -170,6 +186,9 @@ export const PtyRoutes = lazy(() =>
           onClose() {
             handler?.onClose()
           },
+          onError() {
+            handler?.onClose()
+          },
         }
       }),
     ),

+ 54 - 0
packages/opencode/test/pty/pty-output-isolation.test.ts

@@ -0,0 +1,54 @@
+import { describe, expect, test } from "bun:test"
+import { Instance } from "../../src/project/instance"
+import { Pty } from "../../src/pty"
+import { tmpdir } from "../fixture/fixture"
+
+describe("pty", () => {
+  test("does not leak output when websocket objects are reused", async () => {
+    await using dir = await tmpdir({ git: true })
+
+    await Instance.provide({
+      directory: dir.path,
+      fn: async () => {
+        const a = await Pty.create({ command: "cat", title: "a" })
+        const b = await Pty.create({ command: "cat", title: "b" })
+        try {
+          const outA: string[] = []
+          const outB: string[] = []
+
+          const ws = {
+            readyState: 1,
+            send: (data: unknown) => {
+              outA.push(typeof data === "string" ? data : Buffer.from(data as Uint8Array).toString("utf8"))
+            },
+            close: () => {
+              // no-op (simulate abrupt drop)
+            },
+          }
+
+          // Connect "a" first with ws.
+          Pty.connect(a.id, ws as any)
+
+          // Now "reuse" the same ws object for another connection.
+          ws.send = (data: unknown) => {
+            outB.push(typeof data === "string" ? data : Buffer.from(data as Uint8Array).toString("utf8"))
+          }
+          Pty.connect(b.id, ws as any)
+
+          // Clear connect metadata writes.
+          outA.length = 0
+          outB.length = 0
+
+          // Output from a must never show up in b.
+          Pty.write(a.id, "AAA\n")
+          await Bun.sleep(100)
+
+          expect(outB.join("")).not.toContain("AAA")
+        } finally {
+          await Pty.remove(a.id)
+          await Pty.remove(b.id)
+        }
+      },
+    })
+  })
+})