Преглед изворни кода

desktop: use process-wrap instead of manual job object (#13431)

Brendan Allan пре 1 месец
родитељ
комит
920255e8c6

+ 84 - 7
packages/desktop/src-tauri/Cargo.lock

@@ -2343,9 +2343,9 @@ dependencies = [
 
 [[package]]
 name = "libc"
-version = "0.2.177"
+version = "0.2.180"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976"
+checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"
 
 [[package]]
 name = "libloading"
@@ -2663,6 +2663,18 @@ dependencies = [
  "memoffset",
 ]
 
+[[package]]
+name = "nix"
+version = "0.31.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "225e7cfe711e0ba79a68baeddb2982723e4235247aefce1482f2f16c27865b66"
+dependencies = [
+ "bitflags 2.10.0",
+ "cfg-if",
+ "cfg_aliases",
+ "libc",
+]
+
 [[package]]
 name = "nodrop"
 version = "0.1.14"
@@ -3093,6 +3105,7 @@ dependencies = [
  "listeners",
  "objc2 0.6.3",
  "objc2-web-kit",
+ "process-wrap",
  "reqwest 0.12.24",
  "semver",
  "serde",
@@ -3123,7 +3136,6 @@ dependencies = [
  "tracing-subscriber",
  "uuid",
  "webkit2gtk",
- "windows 0.61.3",
 ]
 
 [[package]]
@@ -3638,6 +3650,20 @@ dependencies = [
  "unicode-ident",
 ]
 
+[[package]]
+name = "process-wrap"
+version = "9.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ccd9713fe2c91c3c85ac388b31b89de339365d2c995146e630b5e0da9d06526a"
+dependencies = [
+ "futures",
+ "indexmap 2.12.1",
+ "nix 0.31.1",
+ "tokio",
+ "tracing",
+ "windows 0.62.2",
+]
+
 [[package]]
 name = "psl-types"
 version = "2.0.11"
@@ -6460,11 +6486,23 @@ version = "0.61.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893"
 dependencies = [
- "windows-collections",
+ "windows-collections 0.2.0",
  "windows-core 0.61.2",
- "windows-future",
+ "windows-future 0.2.1",
  "windows-link 0.1.3",
- "windows-numerics",
+ "windows-numerics 0.2.0",
+]
+
+[[package]]
+name = "windows"
+version = "0.62.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580"
+dependencies = [
+ "windows-collections 0.3.2",
+ "windows-core 0.62.2",
+ "windows-future 0.3.2",
+ "windows-numerics 0.3.1",
 ]
 
 [[package]]
@@ -6476,6 +6514,15 @@ dependencies = [
  "windows-core 0.61.2",
 ]
 
+[[package]]
+name = "windows-collections"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610"
+dependencies = [
+ "windows-core 0.62.2",
+]
+
 [[package]]
 name = "windows-core"
 version = "0.51.1"
@@ -6519,7 +6566,18 @@ checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e"
 dependencies = [
  "windows-core 0.61.2",
  "windows-link 0.1.3",
- "windows-threading",
+ "windows-threading 0.1.0",
+]
+
+[[package]]
+name = "windows-future"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb"
+dependencies = [
+ "windows-core 0.62.2",
+ "windows-link 0.2.1",
+ "windows-threading 0.2.1",
 ]
 
 [[package]]
@@ -6566,6 +6624,16 @@ dependencies = [
  "windows-link 0.1.3",
 ]
 
+[[package]]
+name = "windows-numerics"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26"
+dependencies = [
+ "windows-core 0.62.2",
+ "windows-link 0.2.1",
+]
+
 [[package]]
 name = "windows-registry"
 version = "0.5.3"
@@ -6741,6 +6809,15 @@ dependencies = [
  "windows-link 0.1.3",
 ]
 
+[[package]]
+name = "windows-threading"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37"
+dependencies = [
+ "windows-link 0.2.1",
+]
+
 [[package]]
 name = "windows-version"
 version = "0.1.7"

+ 2 - 9
packages/desktop/src-tauri/Cargo.toml

@@ -34,7 +34,7 @@ tauri-plugin-single-instance = { version = "2", features = ["deep-link"] }
 
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
-tokio = "1.48.0"
+tokio = { version = "1.48.0", features = ["process"] }
 listeners = "0.3"
 tauri-plugin-os = "2"
 futures = "0.3.31"
@@ -52,6 +52,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
 tracing-appender = "0.2"
 chrono = "0.4"
 tokio-stream = { version = "0.1.18", features = ["sync"] }
+process-wrap = { version = "9.0.3", features = ["tokio1"] }
 
 [target.'cfg(target_os = "linux")'.dependencies]
 gtk = "0.18.2"
@@ -62,14 +63,6 @@ objc2 = "0.6"
 objc2-web-kit = "0.3"
 
 
-[target.'cfg(windows)'.dependencies]
-windows = { version = "0.61", features = [
-    "Win32_Foundation",
-    "Win32_System_JobObjects",
-    "Win32_System_Threading",
-    "Win32_Security"
-] }
-
 [patch.crates-io]
 specta = { git = "https://github.com/specta-rs/specta", rev = "591a5f3ddc78348abf4cbb541d599d65306d92b9" }
 specta-typescript = { git = "https://github.com/specta-rs/specta", rev = "591a5f3ddc78348abf4cbb541d599d65306d92b9" }

+ 139 - 23
packages/desktop/src-tauri/src/cli.rs

@@ -1,12 +1,19 @@
 use futures::{FutureExt, Stream, StreamExt, future};
+use process_wrap::tokio::CommandWrap;
+#[cfg(unix)]
+use process_wrap::tokio::ProcessGroup;
+#[cfg(windows)]
+use process_wrap::tokio::{JobObject, KillOnDrop};
+#[cfg(unix)]
+use std::os::unix::process::ExitStatusExt;
+use std::{process::Stdio, time::Duration};
 use tauri::{AppHandle, Manager, path::BaseDirectory};
-use tauri_plugin_shell::{
-    ShellExt,
-    process::{CommandChild, CommandEvent, TerminatedPayload},
-};
 use tauri_plugin_store::StoreExt;
 use tauri_specta::Event;
-use tokio::sync::oneshot;
+use tokio::io::{AsyncBufReadExt, BufReader};
+use tokio::process::Command;
+use tokio::sync::{mpsc, oneshot};
+use tokio_stream::wrappers::ReceiverStream;
 use tracing::Instrument;
 
 use crate::constants::{SETTINGS_STORE, WSL_ENABLED_KEY};
@@ -25,6 +32,33 @@ pub struct Config {
     pub server: Option<ServerConfig>,
 }
 
+#[derive(Clone, Debug)]
+pub enum CommandEvent {
+    Stdout(Vec<u8>),
+    Stderr(Vec<u8>),
+    Error(String),
+    Terminated(TerminatedPayload),
+}
+
+#[derive(Clone, Copy, Debug)]
+pub struct TerminatedPayload {
+    pub code: Option<i32>,
+    pub signal: Option<i32>,
+}
+
+#[derive(Clone, Debug)]
+pub struct CommandChild {
+    kill: mpsc::Sender<()>,
+}
+
+impl CommandChild {
+    pub fn kill(&self) -> std::io::Result<()> {
+        self.kill
+            .try_send(())
+            .map_err(|e| std::io::Error::other(e.to_string()))
+    }
+}
+
 pub async fn get_config(app: &AppHandle) -> Option<Config> {
     let (events, _) = spawn_command(app, "debug config", &[]).ok()?;
 
@@ -190,7 +224,7 @@ pub fn spawn_command(
     app: &tauri::AppHandle,
     args: &str,
     extra_env: &[(&str, String)],
-) -> Result<(impl Stream<Item = CommandEvent> + 'static, CommandChild), tauri_plugin_shell::Error> {
+) -> Result<(impl Stream<Item = CommandEvent> + 'static, CommandChild), std::io::Error> {
     let state_dir = app
         .path()
         .resolve("", BaseDirectory::AppLocalData)
@@ -217,7 +251,7 @@ pub fn spawn_command(
             .map(|(key, value)| (key.to_string(), value.clone())),
     );
 
-    let cmd = if cfg!(windows) {
+    let mut cmd = if cfg!(windows) {
         if is_wsl_enabled(app) {
             tracing::info!("WSL is enabled, spawning CLI server in WSL");
             let version = app.package_info().version.to_string();
@@ -249,18 +283,16 @@ pub fn spawn_command(
 
             script.push(format!("{} exec \"$BIN\" {}", env_prefix.join(" "), args));
 
-            app.shell()
-                .command("wsl")
-                .args(["-e", "bash", "-lc", &script.join("\n")])
+            let mut cmd = Command::new("wsl");
+            cmd.args(["-e", "bash", "-lc", &script.join("\n")]);
+            cmd
         } else {
-            let mut cmd = app
-                .shell()
-                .sidecar("opencode-cli")
-                .unwrap()
-                .args(args.split_whitespace());
+            let sidecar = get_sidecar_path(app);
+            let mut cmd = Command::new(sidecar);
+            cmd.args(args.split_whitespace());
 
             for (key, value) in envs {
-                cmd = cmd.env(key, value);
+                cmd.env(key, value);
             }
 
             cmd
@@ -269,26 +301,111 @@ pub fn spawn_command(
         let sidecar = get_sidecar_path(app);
         let shell = get_user_shell();
 
-        let cmd = if shell.ends_with("/nu") {
+        let line = if shell.ends_with("/nu") {
             format!("^\"{}\" {}", sidecar.display(), args)
         } else {
             format!("\"{}\" {}", sidecar.display(), args)
         };
 
-        let mut cmd = app.shell().command(&shell).args(["-il", "-c", &cmd]);
+        let mut cmd = Command::new(shell);
+        cmd.args(["-il", "-c", &line]);
 
         for (key, value) in envs {
-            cmd = cmd.env(key, value);
+            cmd.env(key, value);
         }
 
         cmd
     };
 
-    let (rx, child) = cmd.spawn()?;
-    let event_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
+    cmd.stdin(Stdio::null())
+        .stdout(Stdio::piped())
+        .stderr(Stdio::piped());
+
+    let mut wrap = CommandWrap::from(cmd);
+
+    #[cfg(unix)]
+    {
+        wrap.wrap(ProcessGroup::leader());
+    }
+
+    #[cfg(windows)]
+    {
+        wrap.wrap(JobObject).wrap(KillOnDrop);
+    }
+
+    let mut child = wrap.spawn()?;
+    let stdout = child.stdout().take();
+    let stderr = child.stderr().take();
+    let (tx, rx) = mpsc::channel(256);
+    let (kill_tx, mut kill_rx) = mpsc::channel(1);
+
+    if let Some(stdout) = stdout {
+        let tx = tx.clone();
+        tokio::spawn(async move {
+            let mut lines = BufReader::new(stdout).lines();
+            while let Ok(Some(line)) = lines.next_line().await {
+                let _ = tx.send(CommandEvent::Stdout(line.into_bytes())).await;
+            }
+        });
+    }
+
+    if let Some(stderr) = stderr {
+        let tx = tx.clone();
+        tokio::spawn(async move {
+            let mut lines = BufReader::new(stderr).lines();
+            while let Ok(Some(line)) = lines.next_line().await {
+                let _ = tx.send(CommandEvent::Stderr(line.into_bytes())).await;
+            }
+        });
+    }
+
+    tokio::spawn(async move {
+        let status = loop {
+            match child.try_wait() {
+                Ok(Some(status)) => break Ok(status),
+                Ok(None) => {}
+                Err(err) => break Err(err),
+            }
+
+            tokio::select! {
+                _ = kill_rx.recv() => {
+                    let _ = child.start_kill();
+                }
+                _ = tokio::time::sleep(Duration::from_millis(100)) => {}
+            }
+        };
+
+        match status {
+            Ok(status) => {
+                let payload = TerminatedPayload {
+                    code: status.code(),
+                    signal: signal_from_status(status),
+                };
+                let _ = tx.send(CommandEvent::Terminated(payload)).await;
+            }
+            Err(err) => {
+                let _ = tx.send(CommandEvent::Error(err.to_string())).await;
+            }
+        }
+    });
+
+    let event_stream = ReceiverStream::new(rx);
     let event_stream = sqlite_migration::logs_middleware(app.clone(), event_stream);
 
-    Ok((event_stream, child))
+    Ok((event_stream, CommandChild { kill: kill_tx }))
+}
+
+fn signal_from_status(status: std::process::ExitStatus) -> Option<i32> {
+    #[cfg(unix)]
+    {
+        return status.signal();
+    }
+
+    #[cfg(not(unix))]
+    {
+        let _ = status;
+        None
+    }
 }
 
 pub fn serve(
@@ -340,7 +457,6 @@ pub fn serve(
                             let _ = tx.send(payload);
                         }
                     }
-                    _ => {}
                 }
 
                 future::ready(())

+ 0 - 145
packages/desktop/src-tauri/src/job_object.rs

@@ -1,145 +0,0 @@
-//! Windows Job Object for reliable child process cleanup.
-//!
-//! This module provides a wrapper around Windows Job Objects with the
-//! `JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE` flag set. When the job object handle
-//! is closed (including when the parent process exits or crashes), Windows
-//! automatically terminates all processes assigned to the job.
-//!
-//! This is more reliable than manual cleanup because it works even if:
-//! - The parent process crashes
-//! - The parent is killed via Task Manager
-//! - The RunEvent::Exit handler fails to run
-
-use std::io::{Error, Result};
-#[cfg(windows)]
-use std::sync::Mutex;
-use windows::Win32::Foundation::{CloseHandle, HANDLE};
-use windows::Win32::System::JobObjects::{
-    AssignProcessToJobObject, CreateJobObjectW, JobObjectExtendedLimitInformation,
-    SetInformationJobObject, JOBOBJECT_EXTENDED_LIMIT_INFORMATION,
-    JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE,
-};
-use windows::Win32::System::Threading::{OpenProcess, PROCESS_SET_QUOTA, PROCESS_TERMINATE};
-
-/// A Windows Job Object configured to kill all assigned processes when closed.
-///
-/// When this struct is dropped or when the owning process exits (even abnormally),
-/// Windows will automatically terminate all processes that have been assigned to it.
-pub struct JobObject(HANDLE);
-
-// SAFETY: HANDLE is just a pointer-sized value, and Windows job objects
-// can be safely accessed from multiple threads.
-unsafe impl Send for JobObject {}
-unsafe impl Sync for JobObject {}
-
-impl JobObject {
-    /// Creates a new anonymous job object with `JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE` set.
-    ///
-    /// When the last handle to this job is closed (including on process exit),
-    /// Windows will terminate all processes assigned to the job.
-    pub fn new() -> Result<Self> {
-        unsafe {
-            // Create an anonymous job object
-            let job = CreateJobObjectW(None, None).map_err(|e| Error::other(e.message()))?;
-
-            // Configure the job to kill all processes when the handle is closed
-            let mut info = JOBOBJECT_EXTENDED_LIMIT_INFORMATION::default();
-            info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE;
-
-            SetInformationJobObject(
-                job,
-                JobObjectExtendedLimitInformation,
-                &info as *const _ as *const std::ffi::c_void,
-                std::mem::size_of::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() as u32,
-            )
-            .map_err(|e| Error::other(e.message()))?;
-
-            Ok(Self(job))
-        }
-    }
-
-    /// Assigns a process to this job object by its process ID.
-    ///
-    /// Once assigned, the process will be terminated when this job object is dropped
-    /// or when the owning process exits.
-    ///
-    /// # Arguments
-    /// * `pid` - The process ID of the process to assign
-    pub fn assign_pid(&self, pid: u32) -> Result<()> {
-        unsafe {
-            // Open a handle to the process with the minimum required permissions
-            // PROCESS_SET_QUOTA and PROCESS_TERMINATE are required by AssignProcessToJobObject
-            let process = OpenProcess(PROCESS_SET_QUOTA | PROCESS_TERMINATE, false, pid)
-                .map_err(|e| Error::other(e.message()))?;
-
-            // Assign the process to the job
-            let result = AssignProcessToJobObject(self.0, process);
-
-            // Close our handle to the process - the job object maintains its own reference
-            let _ = CloseHandle(process);
-
-            result.map_err(|e| Error::other(e.message()))
-        }
-    }
-}
-
-impl Drop for JobObject {
-    fn drop(&mut self) {
-        unsafe {
-            // When this handle is closed and it's the last handle to the job,
-            // Windows will terminate all processes in the job due to KILL_ON_JOB_CLOSE
-            let _ = CloseHandle(self.0);
-        }
-    }
-}
-
-/// Holds the Windows Job Object that ensures child processes are killed when the app exits.
-/// On Windows, when the job object handle is closed (including on crash), all assigned
-/// processes are automatically terminated by the OS.
-#[cfg(windows)]
-pub struct JobObjectState {
-    job: Mutex<Option<JobObject>>,
-    error: Mutex<Option<String>>,
-}
-
-#[cfg(windows)]
-impl JobObjectState {
-    pub fn new() -> Self {
-        match JobObject::new() {
-            Ok(job) => Self {
-                job: Mutex::new(Some(job)),
-                error: Mutex::new(None),
-            },
-            Err(e) => {
-                tracing::error!("Failed to create job object: {e}");
-                Self {
-                    job: Mutex::new(None),
-                    error: Mutex::new(Some(format!("Failed to create job object: {e}"))),
-                }
-            }
-        }
-    }
-
-    pub fn assign_pid(&self, pid: u32) {
-        if let Some(job) = self.job.lock().unwrap().as_ref() {
-            if let Err(e) = job.assign_pid(pid) {
-                tracing::error!(pid, "Failed to assign process to job object: {e}");
-                *self.error.lock().unwrap() =
-                    Some(format!("Failed to assign process to job object: {e}"));
-            } else {
-                tracing::info!(pid, "Assigned process to job object for automatic cleanup");
-            }
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_job_object_creation() {
-        let job = JobObject::new();
-        assert!(job.is_ok(), "Failed to create job object: {:?}", job.err());
-    }
-}

+ 1 - 14
packages/desktop/src-tauri/src/lib.rs

@@ -1,7 +1,5 @@
 mod cli;
 mod constants;
-#[cfg(windows)]
-mod job_object;
 #[cfg(target_os = "linux")]
 pub mod linux_display;
 mod logging;
@@ -10,12 +8,11 @@ mod server;
 mod window_customizer;
 mod windows;
 
+use crate::cli::CommandChild;
 use futures::{
     FutureExt, TryFutureExt,
     future::{self, Shared},
 };
-#[cfg(windows)]
-use job_object::*;
 use std::{
     env,
     net::TcpListener,
@@ -27,7 +24,6 @@ use std::{
 use tauri::{AppHandle, Listener, Manager, RunEvent, State, ipc::Channel};
 #[cfg(any(target_os = "linux", all(debug_assertions, windows)))]
 use tauri_plugin_deep_link::DeepLinkExt;
-use tauri_plugin_shell::process::CommandChild;
 use tauri_specta::Event;
 use tokio::{
     sync::{oneshot, watch},
@@ -631,12 +627,6 @@ async fn initialize(app: AppHandle) {
 
                             tracing::info!("CLI health check OK");
 
-                            #[cfg(windows)]
-                            {
-                                let job_state = app.state::<JobObjectState>();
-                                job_state.assign_pid(child.pid());
-                            }
-
                             app.state::<ServerState>().set_child(Some(child));
 
                             Ok(ServerReadyData { url, password })
@@ -710,9 +700,6 @@ fn setup_app(app: &tauri::AppHandle, init_rx: watch::Receiver<InitStep>) {
     #[cfg(any(target_os = "linux", all(debug_assertions, windows)))]
     app.deep_link().register_all().ok();
 
-    #[cfg(windows)]
-    app.manage(JobObjectState::new());
-
     app.manage(InitState { current: init_rx });
 }
 

+ 2 - 9
packages/desktop/src-tauri/src/logging.rs

@@ -36,11 +36,7 @@ pub fn init(log_dir: &Path) -> WorkerGuard {
     tracing_subscriber::registry()
         .with(filter)
         .with(fmt::layer().with_writer(std::io::stderr))
-        .with(
-            fmt::layer()
-                .with_writer(non_blocking)
-                .with_ansi(false),
-        )
+        .with(fmt::layer().with_writer(non_blocking).with_ansi(false))
         .init();
 
     guard
@@ -55,10 +51,7 @@ pub fn tail() -> String {
         return String::new();
     };
 
-    let lines: Vec<String> = BufReader::new(file)
-        .lines()
-        .map_while(Result::ok)
-        .collect();
+    let lines: Vec<String> = BufReader::new(file).lines().map_while(Result::ok).collect();
 
     let start = lines.len().saturating_sub(TAIL_LINES);
     lines[start..].join("\n")

+ 1 - 1
packages/desktop/src-tauri/src/server.rs

@@ -2,12 +2,12 @@ use std::time::{Duration, Instant};
 
 use tauri::AppHandle;
 use tauri_plugin_dialog::{DialogExt, MessageDialogButtons, MessageDialogResult};
-use tauri_plugin_shell::process::CommandChild;
 use tauri_plugin_store::StoreExt;
 use tokio::task::JoinHandle;
 
 use crate::{
     cli,
+    cli::CommandChild,
     constants::{DEFAULT_SERVER_URL_KEY, SETTINGS_STORE, WSL_ENABLED_KEY},
 };