Sfoglia il codice sorgente

Add `dev list` command

NextTurn 5 anni fa
parent
commit
144cff7f19

+ 18 - 10
src/WinSW.Core/Native/Security.cs

@@ -37,20 +37,28 @@ namespace WinSW.Native
             _ = LookupAccountName(null, accountName, IntPtr.Zero, ref sidSize, null, ref domainNameLength, out _);
 
             IntPtr sid = Marshal.AllocHGlobal(sidSize);
-            string? domainName = domainNameLength == 0 ? null : new string('\0', domainNameLength - 1);
-
-            if (!LookupAccountName(null, accountName, sid, ref sidSize, domainName, ref domainNameLength, out _))
+            try
             {
-                Throw.Command.Win32Exception("Failed to find the account.");
-            }
+                string? domainName = domainNameLength == 0 ? null : new string('\0', domainNameLength - 1);
+
+                if (!LookupAccountName(null, accountName, sid, ref sidSize, domainName, ref domainNameLength, out _))
+                {
+                    Throw.Command.Win32Exception("Failed to find the account.");
+                }
 
-            // intentionally undocumented
-            if (!accountName.Contains("\\") && !accountName.Contains("@"))
+                // intentionally undocumented
+                if (!accountName.Contains("\\") && !accountName.Contains("@"))
+                {
+                    accountName = domainName + '\\' + accountName;
+                }
+
+                return sid;
+            }
+            catch
             {
-                accountName = domainName + '\\' + accountName;
+                Marshal.FreeHGlobal(sid);
+                throw;
             }
-
-            return sid;
         }
 
         /// <exception cref="Win32Exception" />

+ 90 - 6
src/WinSW.Core/Native/Service.cs

@@ -1,5 +1,5 @@
 using System;
-using System.Diagnostics;
+using System.Runtime.InteropServices;
 using System.Security.AccessControl;
 using System.ServiceProcess;
 using System.Text;
@@ -56,9 +56,9 @@ namespace WinSW.Native
         private ServiceManager(IntPtr handle) => this.handle = handle;
 
         /// <exception cref="CommandException" />
-        internal static ServiceManager Open()
+        internal static ServiceManager Open(ServiceManagerAccess access = ServiceManagerAccess.All)
         {
-            IntPtr handle = OpenSCManager(null, null, ServiceManagerAccess.ALL_ACCESS);
+            IntPtr handle = OpenSCManager(null, null, access);
             if (handle == IntPtr.Zero)
             {
                 Throw.Command.Win32Exception("Failed to open the service control manager database.");
@@ -81,7 +81,7 @@ namespace WinSW.Native
                 this.handle,
                 serviceName,
                 displayName,
-                ServiceAccess.ALL_ACCESS,
+                ServiceAccess.All,
                 ServiceType.Win32OwnProcess,
                 startMode,
                 ServiceErrorControl.Normal,
@@ -100,7 +100,58 @@ namespace WinSW.Native
         }
 
         /// <exception cref="CommandException" />
-        internal Service OpenService(string serviceName, ServiceAccess access = ServiceAccess.ALL_ACCESS)
+        internal unsafe (IntPtr Services, int Count) EnumerateServices()
+        {
+            int resume = 0;
+            _ = EnumServicesStatus(
+                this.handle,
+                ServiceType.Win32OwnProcess,
+                ServiceState.All,
+                IntPtr.Zero,
+                0,
+                out int bytesNeeded,
+                out _,
+                ref resume);
+
+            IntPtr services = Marshal.AllocHGlobal(bytesNeeded);
+            try
+            {
+                if (!EnumServicesStatus(
+                    this.handle,
+                    ServiceType.Win32OwnProcess,
+                    ServiceState.All,
+                    services,
+                    bytesNeeded,
+                    out _,
+                    out int count,
+                    ref resume))
+                {
+                    Throw.Command.Win32Exception("Failed to enumerate services.");
+                }
+
+                return (services, count);
+            }
+            catch
+            {
+                Marshal.FreeHGlobal(services);
+                throw;
+            }
+        }
+
+        /// <exception cref="CommandException" />
+        internal unsafe Service OpenService(char* serviceName, ServiceAccess access = ServiceAccess.All)
+        {
+            IntPtr serviceHandle = ServiceApis.OpenService(this.handle, serviceName, access);
+            if (serviceHandle == IntPtr.Zero)
+            {
+                Throw.Command.Win32Exception("Failed to open the service.");
+            }
+
+            return new Service(serviceHandle);
+        }
+
+        /// <exception cref="CommandException" />
+        internal Service OpenService(string serviceName, ServiceAccess access = ServiceAccess.All)
         {
             IntPtr serviceHandle = ServiceApis.OpenService(this.handle, serviceName, access);
             if (serviceHandle == IntPtr.Zero)
@@ -113,7 +164,7 @@ namespace WinSW.Native
 
         internal bool ServiceExists(string serviceName)
         {
-            IntPtr serviceHandle = ServiceApis.OpenService(this.handle, serviceName, ServiceAccess.ALL_ACCESS);
+            IntPtr serviceHandle = ServiceApis.OpenService(this.handle, serviceName, ServiceAccess.All);
             if (serviceHandle == IntPtr.Zero)
             {
                 return false;
@@ -140,6 +191,39 @@ namespace WinSW.Native
 
         internal Service(IntPtr handle) => this.handle = handle;
 
+        /// <exception cref="CommandException" />
+        internal unsafe string ExecutablePath
+        {
+            get
+            {
+                _ = QueryServiceConfig(
+                    this.handle,
+                    IntPtr.Zero,
+                    0,
+                    out int bytesNeeded);
+
+                IntPtr config = Marshal.AllocHGlobal(bytesNeeded);
+                try
+                {
+                    if (!QueryServiceConfig(
+                        this.handle,
+                        config,
+                        bytesNeeded,
+                        out _))
+                    {
+                        Throw.Command.Win32Exception("Failed to query service config.");
+                    }
+
+                    return Marshal.PtrToStringUni((IntPtr)((QUERY_SERVICE_CONFIG*)config)->BinaryPathName)!;
+                }
+                finally
+                {
+                    Marshal.FreeHGlobal(config);
+                }
+            }
+        }
+
+        /// <exception cref="CommandException" />
         internal unsafe int ProcessId
         {
             get

+ 91 - 34
src/WinSW.Core/Native/ServiceApis.cs

@@ -60,12 +60,29 @@ namespace WinSW.Native
         [DllImport(Libraries.Advapi32, SetLastError = true)]
         internal static extern bool DeleteService(IntPtr serviceHandle);
 
+        [DllImport(Libraries.Advapi32, SetLastError = true, CharSet = CharSet.Unicode, EntryPoint = "EnumServicesStatusW")]
+        internal static extern unsafe bool EnumServicesStatus(
+            IntPtr databaseHandle,
+            ServiceType serviceType,
+            ServiceState serviceState,
+            IntPtr services,
+            int bufferSize,
+            out int bytesNeeded,
+            out int servicesReturned,
+            ref int resumeHandle);
+
         [DllImport(Libraries.Advapi32, SetLastError = true, CharSet = CharSet.Unicode, EntryPoint = "OpenSCManagerW")]
         internal static extern IntPtr OpenSCManager(string? machineName, string? databaseName, ServiceManagerAccess desiredAccess);
 
+        [DllImport(Libraries.Advapi32, SetLastError = true, CharSet = CharSet.Unicode, EntryPoint = "OpenServiceW")]
+        internal static unsafe extern IntPtr OpenService(IntPtr databaseHandle, char* serviceName, ServiceAccess desiredAccess);
+
         [DllImport(Libraries.Advapi32, SetLastError = true, CharSet = CharSet.Unicode, EntryPoint = "OpenServiceW")]
         internal static extern IntPtr OpenService(IntPtr databaseHandle, string serviceName, ServiceAccess desiredAccess);
 
+        [DllImport(Libraries.Advapi32, SetLastError = true, CharSet = CharSet.Unicode, EntryPoint = "QueryServiceConfigW")]
+        internal static extern bool QueryServiceConfig(IntPtr serviceHandle, IntPtr serviceConfig, int bufferSize, out int bytesNeeded);
+
         [DllImport(Libraries.Advapi32, SetLastError = true)]
         internal static extern bool QueryServiceStatus(IntPtr serviceHandle, out SERVICE_STATUS serviceStatus);
 
@@ -88,27 +105,27 @@ namespace WinSW.Native
         [Flags]
         internal enum ServiceAccess : uint
         {
-            QUERY_CONFIG = 0x0001,
-            CHANGE_CONFIG = 0x0002,
-            QUERY_STATUS = 0x0004,
-            ENUMERATE_DEPENDENTS = 0x0008,
-            START = 0x0010,
-            STOP = 0x0020,
-            PAUSE_CONTINUE = 0x0040,
-            INTERROGATE = 0x0080,
-            USER_DEFINED_CONTROL = 0x0100,
-
-            ALL_ACCESS =
+            QueryConfig = 0x0001,
+            ChangeConfig = 0x0002,
+            QueryStatus = 0x0004,
+            EnumerateDependents = 0x0008,
+            Start = 0x0010,
+            Stop = 0x0020,
+            PauseContinue = 0x0040,
+            Interrogate = 0x0080,
+            UserDefinedControl = 0x0100,
+
+            All =
                 SecurityApis.StandardAccess.REQUIRED |
-                QUERY_CONFIG |
-                CHANGE_CONFIG |
-                QUERY_STATUS |
-                ENUMERATE_DEPENDENTS |
-                START |
-                STOP |
-                PAUSE_CONTINUE |
-                INTERROGATE |
-                USER_DEFINED_CONTROL,
+                QueryConfig |
+                ChangeConfig |
+                QueryStatus |
+                EnumerateDependents |
+                Start |
+                Stop |
+                PauseContinue |
+                Interrogate |
+                UserDefinedControl,
         }
 
         // SERVICE_CONFIG_
@@ -140,21 +157,29 @@ namespace WinSW.Native
         [Flags]
         internal enum ServiceManagerAccess : uint
         {
-            CONNECT = 0x0001,
-            CREATE_SERVICE = 0x0002,
-            ENUMERATE_SERVICE = 0x0004,
-            LOCK = 0x0008,
-            QUERY_LOCK_STATUS = 0x0010,
-            MODIFY_BOOT_CONFIG = 0x0020,
-
-            ALL_ACCESS =
+            Connect = 0x0001,
+            CreateService = 0x0002,
+            EnumerateService = 0x0004,
+            Lock = 0x0008,
+            QueryLockStatus = 0x0010,
+            ModifyBootConfig = 0x0020,
+
+            All =
                 SecurityApis.StandardAccess.REQUIRED |
-                CONNECT |
-                CREATE_SERVICE |
-                ENUMERATE_SERVICE |
-                LOCK |
-                QUERY_LOCK_STATUS |
-                MODIFY_BOOT_CONFIG,
+                Connect |
+                CreateService |
+                EnumerateService |
+                Lock |
+                QueryLockStatus |
+                ModifyBootConfig,
+        }
+
+        // SERVICE_
+        internal enum ServiceState : uint
+        {
+            Active = 0x00000001,
+            Inactive = 0x00000002,
+            All = 0x00000003,
         }
 
         // SC_STATUS_
@@ -163,6 +188,38 @@ namespace WinSW.Native
             ProcessInfo = 0,
         }
 
+        internal readonly unsafe struct ENUM_SERVICE_STATUS
+        {
+            public readonly char* ServiceName;
+            public readonly char* DisplayName;
+            public readonly SERVICE_STATUS ServiceStatus;
+
+            public override string ToString()
+            {
+                var serviceName = new ReadOnlySpan<char>(this.ServiceName, new ReadOnlySpan<char>(this.ServiceName, 256).IndexOf('\0'));
+                var displayName = new ReadOnlySpan<char>(this.DisplayName, new ReadOnlySpan<char>(this.DisplayName, 256).IndexOf('\0'));
+
+#if NETCOREAPP
+                return string.Concat(displayName, " (", serviceName, ")");
+#else
+                return string.Concat(displayName.ToString(), " (", serviceName.ToString(), ")");
+#endif
+            }
+        }
+
+        internal unsafe struct QUERY_SERVICE_CONFIG
+        {
+            public ServiceType ServiceType;
+            public ServiceStartMode StartType;
+            public ServiceErrorControl ErrorControl;
+            public char* BinaryPathName;
+            public char* LoadOrderGroup;
+            public uint TagId;
+            public char* Dependencies;
+            public char* ServiceStartName;
+            public char* DisplayName;
+        }
+
         internal struct SERVICE_DELAYED_AUTO_START_INFO
         {
             public bool DelayedAutostart;

+ 3 - 1
src/WinSW.Core/WinSW.Core.csproj

@@ -9,7 +9,7 @@
 
   <ItemGroup>
     <PackageReference Include="log4net" Version="2.0.8" />
-    <PackageReference Include="stylecop.analyzers" Version="1.2.0-beta.*">
+    <PackageReference Include="StyleCop.Analyzers" Version="1.2.0-beta.*">
       <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
       <PrivateAssets>all</PrivateAssets>
     </PackageReference>
@@ -32,6 +32,8 @@
   </ItemGroup>
 
   <ItemGroup Condition="'$(TargetFramework)' != 'net5.0'">
+    <PackageReference Include="System.Memory" Version="4.5.4" />
+    <PackageReference Include="System.ValueTuple" Version="4.5.0" />
     <Reference Include="System.ServiceProcess" />
   </ItemGroup>
 

+ 2 - 2
src/WinSW.Tests/Util/InterProcessCodeCoverageSession.cs

@@ -29,8 +29,8 @@ namespace WinSW.Tests.Util
             var hitsField = this.hitsField = trackerType.GetField("HitsArray", BindingFlags.Public | BindingFlags.Static);
             Assert.NotNull(hitsField);
 
-            using var scm = ServiceManager.Open();
-            using var sc = scm.OpenService(serviceName, ServiceApis.ServiceAccess.QUERY_STATUS);
+            using var scm = ServiceManager.Open(ServiceApis.ServiceManagerAccess.Connect);
+            using var sc = scm.OpenService(serviceName, ServiceApis.ServiceAccess.QueryStatus);
 
             int processId = sc.ProcessId;
             Assert.True(processId >= 0);

+ 57 - 25
src/WinSW/Program.cs

@@ -9,6 +9,7 @@ using System.Diagnostics;
 using System.IO;
 using System.Linq;
 using System.Reflection;
+using System.Runtime.InteropServices;
 using System.Security.AccessControl;
 using System.Security.Principal;
 using System.ServiceProcess;
@@ -22,6 +23,7 @@ using log4net.Layout;
 using WinSW.Logging;
 using WinSW.Native;
 using WinSW.Util;
+using static WinSW.Native.ServiceApis;
 using Process = System.Diagnostics.Process;
 using TimeoutException = System.ServiceProcess.TimeoutException;
 
@@ -271,27 +273,41 @@ namespace WinSW
             }
 
             {
-                var dev = new Command("dev", "Experimental commands.")
-                {
-                    config,
-                    noElevate,
-                };
+                var dev = new Command("dev", "Experimental commands.");
 
                 root.Add(dev);
 
-                var ps = new Command("ps", "Draws the process tree associated with the service.")
                 {
-                    Handler = CommandHandler.Create<string?, bool>(DevPs),
-                };
+                    var ps = new Command("ps", "Draws the process tree associated with the service.")
+                    {
+                        Handler = CommandHandler.Create<string?>(DevPs),
+                    };
 
-                dev.Add(ps);
+                    ps.Add(config);
+
+                    dev.Add(ps);
+                }
 
-                var kill = new Command("kill", "Terminates the service if it has stopped responding.")
                 {
-                    Handler = CommandHandler.Create<string?, bool>(DevKill),
-                };
+                    var kill = new Command("kill", "Terminates the service if it has stopped responding.")
+                    {
+                        Handler = CommandHandler.Create<string?, bool>(DevKill),
+                    };
+
+                    kill.Add(config);
+                    kill.Add(noElevate);
 
-                dev.Add(kill);
+                    dev.Add(kill);
+                }
+
+                {
+                    var list = new Command("list", "Lists services managed by the current executable.")
+                    {
+                        Handler = CommandHandler.Create(DevList),
+                    };
+
+                    dev.Add(list);
+                }
             }
 
             return new CommandLineBuilder(root)
@@ -374,7 +390,7 @@ namespace WinSW
 
                 Log.Info($"Installing service '{config.Format()}'...");
 
-                using ServiceManager scm = ServiceManager.Open();
+                using ServiceManager scm = ServiceManager.Open(ServiceManagerAccess.CreateService);
 
                 if (scm.ServiceExists(config.Name))
                 {
@@ -499,7 +515,7 @@ namespace WinSW
 
                 Log.Info($"Uninstalling service '{config.Format()}'...");
 
-                using ServiceManager scm = ServiceManager.Open();
+                using ServiceManager scm = ServiceManager.Open(ServiceManagerAccess.Connect);
                 try
                 {
                     using Service sc = scm.OpenService(config.Name);
@@ -822,7 +838,7 @@ namespace WinSW
                     return;
                 }
 
-                using ServiceManager scm = ServiceManager.Open();
+                using ServiceManager scm = ServiceManager.Open(ServiceManagerAccess.Connect);
                 try
                 {
                     using Service sc = scm.OpenService(config.Name);
@@ -864,18 +880,12 @@ namespace WinSW
                 Log.Info($"Service '{config.Format()}' was refreshed successfully.");
             }
 
-            void DevPs(string? pathToConfig, bool noElevate)
+            static void DevPs(string? pathToConfig)
             {
                 XmlServiceConfig config = CreateConfig(pathToConfig);
 
-                if (!elevated)
-                {
-                    Elevate(noElevate);
-                    return;
-                }
-
-                using ServiceManager scm = ServiceManager.Open();
-                using Service sc = scm.OpenService(config.Name);
+                using ServiceManager scm = ServiceManager.Open(ServiceManagerAccess.Connect);
+                using Service sc = scm.OpenService(config.Name, ServiceAccess.QueryStatus);
 
                 int processId = sc.ProcessId;
                 if (processId >= 0)
@@ -938,6 +948,28 @@ namespace WinSW
                 }
             }
 
+            static unsafe void DevList()
+            {
+                using var scm = ServiceManager.Open(ServiceManagerAccess.EnumerateService);
+                (IntPtr services, int count) = scm.EnumerateServices();
+                try
+                {
+                    for (int i = 0; i < count; i++)
+                    {
+                        var status = (ServiceApis.ENUM_SERVICE_STATUS*)services + i;
+                        using var sc = scm.OpenService(status->ServiceName, ServiceAccess.QueryConfig);
+                        if (sc.ExecutablePath.StartsWith($"\"{ExecutablePath}\""))
+                        {
+                            Console.WriteLine(status->ToString());
+                        }
+                    }
+                }
+                finally
+                {
+                    Marshal.FreeHGlobal(services);
+                }
+            }
+
             static void Customize(string output, string manufacturer)
             {
                 if (Resources.UpdateCompanyName(ExecutablePath, output, manufacturer))

+ 1 - 1
src/WinSW/WrapperService.cs

@@ -502,7 +502,7 @@ namespace WinSW
         private void SignalStopped()
         {
             using ServiceManager scm = ServiceManager.Open();
-            using Service sc = scm.OpenService(this.ServiceName, ServiceApis.ServiceAccess.QUERY_STATUS);
+            using Service sc = scm.OpenService(this.ServiceName, ServiceApis.ServiceAccess.QueryStatus);
 
             sc.SetStatus(this.ServiceHandle, ServiceControllerStatus.Stopped);
         }