ソースを参照

Bug 1706: Tool for clearing broken COM registrations

https://winscp.net/tracker/1706

Source commit: 4657333708eef2015581028778c5bd271b716088
Martin Prikryl 7 年 前
コミット
1d4e4569a0

+ 1 - 0
dotnet/properties/AssemblyInfo.cs

@@ -18,6 +18,7 @@ using System.Runtime.InteropServices;
 [assembly: ComVisible(false)]
 
 // The following GUID is for the ID of the typelib if this project is exposed to COM
+// Duplicated in ConsoleRunner.cpp
 [assembly: Guid("a0b93468-d98a-4845-a234-8076229ad93f")]
 
 [assembly: AssemblyVersion(WinSCP.AssemblyConstants.Version)]

+ 8 - 0
source/core/Common.cpp

@@ -1432,6 +1432,14 @@ bool __fastcall IsRealFile(const UnicodeString & FileName)
   return (FileName != THISDIRECTORY) && (FileName != PARENTDIRECTORY);
 }
 //---------------------------------------------------------------------------
+UnicodeString GetEnvironmentInfo()
+{
+  UnicodeString OS = WindowsVersionLong();
+  AddToList(OS, WindowsProductName(), L" - ");
+  UnicodeString Result = FORMAT(L"WinSCP %s (OS %s)", (Configuration->VersionStr, OS));
+  return Result;
+}
+//---------------------------------------------------------------------------
 void __fastcall ProcessLocalDirectory(UnicodeString DirName,
   TProcessLocalFileEvent CallBackFunc, void * Param,
   int FindAttrs)

+ 1 - 0
source/core/Common.h

@@ -175,6 +175,7 @@ void __fastcall LoadScriptFromFile(UnicodeString FileName, TStrings * Lines);
 UnicodeString __fastcall StripEllipsis(const UnicodeString & S);
 UnicodeString __fastcall GetFileMimeType(const UnicodeString & FileName);
 bool __fastcall IsRealFile(const UnicodeString & FileName);
+UnicodeString GetEnvironmentInfo();
 //---------------------------------------------------------------------------
 struct TSearchRecSmart : public TSearchRec
 {

+ 6 - 4
source/core/HierarchicalStorage.cpp

@@ -420,12 +420,14 @@ bool __fastcall THierarchicalStorage::GetTemporary()
 __fastcall TRegistryStorage::TRegistryStorage(const UnicodeString AStorage):
   THierarchicalStorage(IncludeTrailingBackslash(AStorage))
 {
+  FWowMode = 0;
   Init();
 };
 //---------------------------------------------------------------------------
-__fastcall TRegistryStorage::TRegistryStorage(const UnicodeString AStorage, HKEY ARootKey):
+__fastcall TRegistryStorage::TRegistryStorage(const UnicodeString AStorage, HKEY ARootKey, REGSAM WowMode):
   THierarchicalStorage(IncludeTrailingBackslash(AStorage))
 {
+  FWowMode = WowMode;
   Init();
   FRegistry->RootKey = ARootKey;
 }
@@ -434,7 +436,7 @@ void __fastcall TRegistryStorage::Init()
 {
   FFailed = 0;
   FRegistry = new TRegistry();
-  FRegistry->Access = KEY_READ;
+  FRegistry->Access = KEY_READ | FWowMode;
 }
 //---------------------------------------------------------------------------
 __fastcall TRegistryStorage::~TRegistryStorage()
@@ -498,12 +500,12 @@ void __fastcall TRegistryStorage::SetAccessMode(TStorageAccessMode value)
   {
     switch (AccessMode) {
       case smRead:
-        FRegistry->Access = KEY_READ;
+        FRegistry->Access = KEY_READ | FWowMode;
         break;
 
       case smReadWrite:
       default:
-        FRegistry->Access = KEY_READ | KEY_WRITE;
+        FRegistry->Access = KEY_READ | KEY_WRITE | FWowMode;
         break;
     }
   }

+ 2 - 1
source/core/HierarchicalStorage.h

@@ -95,7 +95,7 @@ protected:
 class TRegistryStorage : public THierarchicalStorage
 {
 public:
-  __fastcall TRegistryStorage(const UnicodeString AStorage, HKEY ARootKey);
+  __fastcall TRegistryStorage(const UnicodeString AStorage, HKEY ARootKey, REGSAM WowMode = 0);
   __fastcall TRegistryStorage(const UnicodeString AStorage);
   virtual __fastcall ~TRegistryStorage();
 
@@ -139,6 +139,7 @@ protected:
 private:
   TRegistry * FRegistry;
   int FFailed;
+  REGSAM FWowMode;
 
   void __fastcall Init();
 };

+ 1 - 3
source/core/SessionInfo.cpp

@@ -1124,9 +1124,7 @@ void __fastcall TSessionLog::DoAddStartupInfo(TSessionData * Data)
   if (Data == NULL)
   {
     AddSeparator();
-    UnicodeString OS = WindowsVersionLong();
-    AddToList(OS, WindowsProductName(), L" - ");
-    ADF(L"WinSCP %s (OS %s)", (FConfiguration->VersionStr, OS));
+    ADSTR(GetEnvironmentInfo());
     THierarchicalStorage * Storage = FConfiguration->CreateConfigStorage();
     try
     {

+ 21 - 30
source/windows/ConsoleRunner.cpp

@@ -10,6 +10,7 @@
 #include <PuttyTools.h>
 #include <Queue.h>
 #include <HierarchicalStorage.h>
+#include <Tools.h>
 
 #include <Consts.hpp>
 #include <StrUtils.hpp>
@@ -23,7 +24,7 @@
 #include "SynchronizeController.h"
 #include "GUITools.h"
 #include "VCLCommon.h"
-enum { RESULT_SUCCESS = 0, RESULT_ANY_ERROR = 1 };
+#include "Setup.h"
 //---------------------------------------------------------------------------
 #define WM_INTERUPT_IDLE (WM_WINSCP_USER + 3)
 #define BATCH_INPUT_TIMEOUT 10000
@@ -39,28 +40,6 @@ void TrimNewLine(UnicodeString & Str)
   }
 }
 //---------------------------------------------------------------------------
-class TConsole
-{
-public:
-  virtual __fastcall ~TConsole() {};
-  virtual void __fastcall Print(UnicodeString Str, bool FromBeginning = false, bool Error = false) = 0;
-  void __fastcall PrintLine(const UnicodeString & Str = UnicodeString(), bool Error = false);
-  virtual bool __fastcall Input(UnicodeString & Str, bool Echo, unsigned int Timer) = 0;
-  virtual int __fastcall Choice(
-    UnicodeString Options, int Cancel, int Break, int Continue, int Timeouted, bool Timeouting, unsigned int Timer,
-    UnicodeString Message) = 0;
-  virtual bool __fastcall PendingAbort() = 0;
-  virtual void __fastcall SetTitle(UnicodeString Title) = 0;
-  virtual bool __fastcall LimitedOutput() = 0;
-  virtual bool __fastcall LiveOutput() = 0;
-  virtual bool __fastcall NoInteractiveInput() = 0;
-  virtual void __fastcall WaitBeforeExit() = 0;
-  virtual bool __fastcall CommandLineOnly() = 0;
-  virtual bool __fastcall WantsProgress() = 0;
-  virtual void __fastcall Progress(TScriptProgress & Progress) = 0;
-  virtual UnicodeString __fastcall FinalLogMessage() = 0;
-};
-//---------------------------------------------------------------------------
 void __fastcall TConsole::PrintLine(const UnicodeString & Str, bool Error)
 {
   Print(Str + L"\n", false, Error);
@@ -2360,7 +2339,7 @@ void __fastcall BatchSettings(TConsole * Console, TProgramParams * Params)
   }
 }
 //---------------------------------------------------------------------------
-static int __fastcall HandleException(TConsole * Console, Exception & E)
+int __fastcall HandleException(TConsole * Console, Exception & E)
 {
   UnicodeString Message;
   if (ExceptionFullMessage(&E, Message))
@@ -2629,7 +2608,6 @@ int __fastcall DumpCallstack(TConsole * Console, TProgramParams * Params)
 //---------------------------------------------------------------------------
 void static PrintList(TConsole * Console, const UnicodeString & Caption, TStrings * List)
 {
-  std::unique_ptr<TStrings> Owner(List);
   Console->PrintLine(Caption);
   for (int Index = 0; Index < List->Count; Index++)
   {
@@ -2638,16 +2616,22 @@ void static PrintList(TConsole * Console, const UnicodeString & Caption, TString
   Console->PrintLine();
 }
 //---------------------------------------------------------------------------
+void static PrintListAndFree(TConsole * Console, const UnicodeString & Caption, TStrings * List)
+{
+  std::unique_ptr<TStrings> Owner(List);
+  PrintList(Console, Caption, List);
+}
+//---------------------------------------------------------------------------
 int Info(TConsole * Console)
 {
   int Result = RESULT_SUCCESS;
   try
   {
-    PrintList(Console, L"SSH encryption ciphers:", SshCipherList());
-    PrintList(Console, L"SSH key exchange algoritms:", SshKexList());
-    PrintList(Console, L"SSH host key algoritms:", SshHostKeyList());
-    PrintList(Console, L"SSH MAC algoritms:", SshMacList());
-    PrintList(Console, L"TLS/SSL cipher suites:", TlsCipherList());
+    PrintListAndFree(Console, L"SSH encryption ciphers:", SshCipherList());
+    PrintListAndFree(Console, L"SSH key exchange algoritms:", SshKexList());
+    PrintListAndFree(Console, L"SSH host key algoritms:", SshHostKeyList());
+    PrintListAndFree(Console, L"SSH MAC algoritms:", SshMacList());
+    PrintListAndFree(Console, L"TLS/SSL cipher suites:", TlsCipherList());
   }
   catch (Exception & E)
   {
@@ -2726,6 +2710,13 @@ int __fastcall Console(TConsoleMode Mode)
     {
       Result = Info(Console);
     }
+    else if (Mode == cmComRegistration)
+    {
+      if (CheckSafe(Params))
+      {
+        Result = ComRegistration(Console);
+      }
+    }
     else
     {
       Runner = new TConsoleRunner(Console);

+ 448 - 0
source/windows/Setup.cpp

@@ -2204,3 +2204,451 @@ UnicodeString __fastcall GetPowerShellVersionStr()
 
   return PowerShellVersionStr;
 }
+//---------------------------------------------------------------------------
+//---------------------------------------------------------------------------
+static void CollectCLSIDKey(
+  TConsole * Console, TStrings * Keys, int PlatformSet, TRegistryStorage * Storage, const UnicodeString & CLSID,
+  UnicodeString & CommonCodeBase, const UnicodeString & Platform, UnicodeString & Platforms)
+{
+  UnicodeString CLSIDKey = FORMAT(L"CLSID\\%s", (CLSID));
+  if (Storage->OpenSubKey(CLSIDKey, false, true))
+  {
+    int Index = Keys->IndexOf(CLSIDKey);
+    if (Index >= 0)
+    {
+      Keys->Objects[Index] = reinterpret_cast<TObject *>(PlatformSet | reinterpret_cast<int>(Keys->Objects[Index]));
+    }
+    else
+    {
+      Keys->AddObject(CLSIDKey, reinterpret_cast<TObject *>(PlatformSet));
+    }
+    UnicodeString CodeBase;
+    if (Storage->OpenSubKey(L"InprocServer32", false))
+    {
+      CodeBase = Storage->ReadString(L"CodeBase", UnicodeString());
+      UnicodeString Assembly = Storage->ReadString(L"Assembly", UnicodeString());
+      UnicodeString Version;
+      if (!Assembly.IsEmpty())
+      {
+        UnicodeString VersionPrefix = L"Version=";
+        int P = Assembly.UpperCase().Pos(VersionPrefix.UpperCase());
+        if (P > 0)
+        {
+          Assembly.Delete(1, P + VersionPrefix.Length() - 1);
+          Version = CutToChar(Assembly, L',', true);
+        }
+      }
+      if (CodeBase.IsEmpty() || Version.IsEmpty())
+      {
+        Console->PrintLine(FORMAT(L"Warning: Could not find codebase and version for %s.", (CLSID)));
+        CodeBase = UnicodeString();
+      }
+      else
+      {
+        CodeBase = FORMAT(L"%s (%s)", (CodeBase, Version));
+        if (CommonCodeBase.IsEmpty())
+        {
+          Console->PrintLine(FORMAT(L"Codebase %s, unless stated otherwise", (CodeBase)));
+          CommonCodeBase = CodeBase;
+        }
+        if (SameText(CommonCodeBase, CodeBase))
+        {
+          CodeBase = L"";
+        }
+      }
+      Storage->CloseSubKey();
+    }
+    Storage->CloseSubKey();
+
+    UnicodeString Buf = Platform;
+    AddToList(Buf, CodeBase, L" - ");
+    AddToList(Platforms, Buf, ", ");
+  }
+}
+//---------------------------------------------------------------------------
+static UnicodeString PlatformStr(int PlatformSet)
+{
+  UnicodeString Result;
+  if (PlatformSet == 0)
+  {
+    Result = L"shared";
+  }
+  else
+  {
+    if (FLAGSET(PlatformSet, 32))
+    {
+      Result = L"32-bit";
+    }
+    if (FLAGSET(PlatformSet, 64))
+    {
+      AddToList(Result, L"64-bit", L", ");
+    }
+  }
+  return Result;
+}
+//---------------------------------------------------------------------------
+static void DoCollectComRegistration(TConsole * Console, TStrings * Keys)
+{
+  UnicodeString TypeLib = L"{A0B93468-D98A-4845-A234-8076229AD93F}"; // Duplicated in AssemblyInfo.cs
+  std::unique_ptr<TRegistryStorage> Storage(new TRegistryStorage(UnicodeString(), HKEY_CLASSES_ROOT));
+  Storage->MungeStringValues = false;
+  Storage->AccessMode = smRead;
+  std::unique_ptr<TRegistryStorage> Storage64;
+  if (IsWin64())
+  {
+    Storage64.reset(new TRegistryStorage(UnicodeString(), HKEY_CLASSES_ROOT, KEY_WOW64_64KEY));
+    Storage64->MungeStringValues = false;
+    Storage64->AccessMode = smRead;
+  }
+
+  // Classes, TypeLib and Record are shared between 32-bit and 64-bit registry view.
+  // 32-bit and 64-bit version of regasm adds CLSID keys to its respective view.
+  // Both 32-bit and 64-bit version of regasm seem to add Interface keys to both 32-bit and 64-bit views.
+  // On Vista, Interface keys are reflected (so when 32-bit keys is deleted, it's also deleted from 64-bit key,
+  // and we show an error, trying to delete the 64-bit key).
+
+  if (Storage->OpenRootKey(false))
+  {
+    Console->PrintLine(FORMAT(L"Versions of type library %s:", (TypeLib)));
+    UnicodeString TypeLibKey = FORMAT(L"TypeLib\\%s", (TypeLib));
+    if (Storage->OpenSubKey(TypeLibKey, false, true))
+    {
+      Keys->Add(TypeLibKey);
+      std::unique_ptr<TStringList> KeyNames(new TStringList());
+      Storage->GetSubKeyNames(KeyNames.get());
+      if (KeyNames->Count == 0)
+      {
+        Console->PrintLine(L"Warning: The type library key exists, but no type libraries are present.");
+      }
+      else
+      {
+        for (int Index = 0; Index < KeyNames->Count; Index++)
+        {
+          UnicodeString Version = KeyNames->Strings[Index];
+          if (!Storage->OpenSubKey(FORMAT(L"%s\\0", (Version)), false, true))
+          {
+            Console->PrintLine(FORMAT(L"Warning: Subkey \"0\" for type library \"%s\" cannot be opened.", (Version)));
+          }
+          else
+          {
+            std::unique_ptr<TStringList> Platforms(new TStringList());
+            Storage->GetSubKeyNames(Platforms.get());
+            if (Platforms->Count == 0)
+            {
+              Console->PrintLine(FORMAT(L"Warning: Subkey \"0\" for type library \"%s\" exists, but platforms are present.", (Version)));
+            }
+            else
+            {
+              for (int Index2 = 0; Index2 < Platforms->Count; Index2++)
+              {
+                UnicodeString Platform = Platforms->Strings[Index2];
+                if (!Storage->OpenSubKey(Platform, false))
+                {
+                  Console->PrintLine(FORMAT(L"Warning: Platform \"%s\" for type library \"%s\" cannot be opened.", (Platform, Version)));
+                }
+                else
+                {
+                  UnicodeString TypeLibraryPath = Storage->ReadString(UnicodeString(), UnicodeString());
+                  UnicodeString Exists = FileExists(TypeLibraryPath) ? L"exists" : L"does not exist";
+                  Console->PrintLine(FORMAT(L"%s (%s): %s (%s)", (Version, Platform, TypeLibraryPath, Exists)));
+                  Storage->CloseSubKey();
+                }
+              }
+            }
+            Storage->CloseSubKey();
+          }
+        }
+      }
+      Storage->CloseSubKey();
+    }
+    else
+    {
+      Console->PrintLine(L"Type library not registered.");
+    }
+    Console->PrintLine();
+
+    std::unique_ptr<TStringList> KeyNames(new TStringList());
+    Storage->GetSubKeyNames(KeyNames.get());
+    UnicodeString NamespacePrefix = L"WinSCP.";
+    Console->PrintLine(L"Classes:");
+    UnicodeString CommonCodeBase;
+    int Found = 0;
+    for (int Index = 0; Index < KeyNames->Count; Index++)
+    {
+      UnicodeString KeyName = KeyNames->Strings[Index];
+      if (StartsText(NamespacePrefix, KeyName))
+      {
+        if (Storage->OpenSubKey(FORMAT(L"%s\\%s", (KeyName, L"CLSID")), false, true))
+        {
+          UnicodeString Class = KeyName;
+          UnicodeString CLSID = Trim(Storage->ReadString(UnicodeString(), UnicodeString()));
+          Storage->CloseSubKey();
+
+          if (!CLSID.IsEmpty())
+          {
+            Keys->Add(KeyName);
+            UnicodeString Platforms;
+            CollectCLSIDKey(Console, Keys, 32, Storage.get(), CLSID, CommonCodeBase, L"32-bit", Platforms);
+            if (Storage64.get() != NULL)
+            {
+              CollectCLSIDKey(Console, Keys, 64, Storage64.get(), CLSID, CommonCodeBase, L"64-bit", Platforms);
+            }
+
+            UnicodeString Line = FORMAT(L"%s - %s", (Class, CLSID));
+
+            if (Platforms.IsEmpty())
+            {
+              Console->PrintLine(FORMAT(L"Warning: Could not find CLSID %s for class %s.", (CLSID, Class)));
+            }
+            else
+            {
+              Line += FORMAT(L" [%s]", (Platforms));
+            }
+
+            Console->PrintLine(Line);
+            Found++;
+          }
+        }
+      }
+    }
+    if (Found == 0)
+    {
+      Console->PrintLine(L"No classes found.");
+    }
+    Console->PrintLine();
+
+    UnicodeString InterfaceKey = L"Interface";
+    if (Storage->OpenSubKey(InterfaceKey, false) &&
+        ((Storage64.get() == NULL) || Storage64->OpenSubKey(InterfaceKey, false)))
+    {
+      Console->PrintLine(L"Interfaces:");
+      std::unique_ptr<TStringList> KeyNames(new TStringList());
+      Storage->GetSubKeyNames(KeyNames.get());
+      KeyNames->Sorted = true;
+      for (int Index = 0; Index < KeyNames->Count; Index++)
+      {
+        KeyNames->Objects[Index] = reinterpret_cast<TObject *>(32);
+      }
+      if (Storage64.get() != NULL)
+      {
+        std::unique_ptr<TStringList> KeyNames64(new TStringList());
+        Storage64->GetSubKeyNames(KeyNames64.get());
+        for (int Index = 0; Index < KeyNames64->Count; Index++)
+        {
+          UnicodeString Key64 = KeyNames64->Strings[Index];
+          int Index32 = KeyNames->IndexOf(Key64);
+          if (Index32 >= 0)
+          {
+            KeyNames->Objects[Index32] = reinterpret_cast<TObject *>(32 | 64);
+          }
+          else
+          {
+            KeyNames->AddObject(Key64, reinterpret_cast<TObject *>(64));
+          }
+        }
+      }
+      int Found = 0;
+      for (int Index = 0; Index < KeyNames->Count; Index++)
+      {
+        UnicodeString KeyName = KeyNames->Strings[Index];
+        // Open sub key first, to check if we are interested in the interface, as an optimization
+        int PlatformSet = reinterpret_cast<int>(KeyNames->Objects[Index]);
+        THierarchicalStorage * KeyStorage = FLAGSET(PlatformSet, 32) ? Storage.get() : Storage64.get();
+        if (KeyStorage->OpenSubKey(FORMAT(L"%s\\TypeLib", (KeyName)), false, true))
+        {
+          UnicodeString InterfaceTypeLib = KeyStorage->ReadString(UnicodeString(), UnicodeString());
+          UnicodeString Version = KeyStorage->ReadString(L"Version", UnicodeString());
+          KeyStorage->CloseSubKey();
+          if (SameText(InterfaceTypeLib, TypeLib))
+          {
+            if (KeyStorage->OpenSubKey(KeyName, false))
+            {
+              UnicodeString Key = ExcludeTrailingBackslash(KeyStorage->CurrentSubKey);
+              UnicodeString Interface = KeyStorage->ReadString(UnicodeString(), UnicodeString());
+              KeyStorage->CloseSubKey();
+              Keys->AddObject(Key, reinterpret_cast<TObject *>(PlatformSet));
+              Console->PrintLine(FORMAT(L"%s - %s (%s) [%s]", (Interface, KeyName, Version, PlatformStr(PlatformSet))));
+              Found++;
+            }
+          }
+        }
+      }
+      if (Found == 0)
+      {
+        Console->PrintLine(L"No interfaces found.");
+      }
+      Storage->CloseSubKey();
+      Console->PrintLine();
+    }
+
+    if (Storage->OpenSubKey(L"Record", false))
+    {
+      Console->PrintLine(L"Value types:");
+      std::unique_ptr<TStringList> KeyNames(new TStringList());
+      Storage->GetSubKeyNames(KeyNames.get());
+      int Found = 0;
+      for (int Index = 0; Index < KeyNames->Count; Index++)
+      {
+        UnicodeString KeyName = KeyNames->Strings[Index];
+        if (Storage->OpenSubKey(KeyName, false))
+        {
+          std::unique_ptr<TStringList> Versions(new TStringList());
+          Storage->GetSubKeyNames(Versions.get());
+          UnicodeString VersionsStr;
+          std::unique_ptr<TStringList> Classes(CreateSortedStringList());
+          for (int Index2 = 0; Index2 < Versions->Count; Index2++)
+          {
+            UnicodeString Version = Versions->Strings[Index2];
+            if (Storage->OpenSubKey(Version, false))
+            {
+              UnicodeString Class = Storage->ReadString(L"Class", UnicodeString());
+              Classes->Add(Class);
+              if (StartsStr(NamespacePrefix, Class))
+              {
+                AddToList(VersionsStr, Version, L", ");
+              }
+              Storage->CloseSubKey();
+            }
+          }
+          if (!VersionsStr.IsEmpty())
+          {
+            if (Classes->Count != 1)
+            {
+              Console->PrintLine(FORMAT(L"Warning: Different class names for the same value type %s: %s", (KeyName, Classes->CommaText)));
+            }
+            else
+            {
+              Keys->Add(ExcludeTrailingBackslash(Storage->CurrentSubKey));
+              Console->PrintLine(FORMAT(L"%s - %s (%s)", (Classes->Strings[0], KeyName, VersionsStr)));
+              Found++;
+            }
+          }
+          Storage->CloseSubKey();
+        }
+      }
+      if (Found == 0)
+      {
+        Console->PrintLine(L"No value types found.");
+      }
+      Storage->CloseSubKey();
+      Console->PrintLine();
+    }
+  }
+}
+//---------------------------------------------------------------------------
+bool DoUnregisterChoice(TConsole * Console)
+{
+  return (Console->Choice(L"U", -1, -1, -1, 0, 0, 0, UnicodeString()) == 1);
+}
+//---------------------------------------------------------------------------
+void DoDeleteKey(TConsole * Console, TRegistry * Registry, const UnicodeString & Key, int Platform, bool & AnyDeleted, bool & AllDeleted)
+{
+  UnicodeString ParentKey = ExtractFileDir(Key);
+  UnicodeString ChildKey = ExtractFileName(Key);
+  bool Result = Registry->OpenKey(ParentKey, false);
+  if (Result)
+  {
+    Result = (RegDeleteTreeW(Registry->CurrentKey, ChildKey.c_str()) == 0);
+    Registry->CloseKey();
+  }
+
+  UnicodeString Status;
+  if (Result)
+  {
+    Status = L"removed";
+    AnyDeleted = true;
+  }
+  else
+  {
+    AllDeleted = false;
+    Status = L"NOT removed";
+  }
+  Console->PrintLine(FORMAT(L"%s [%s] - %s", (Key, PlatformStr(Platform), Status)));
+}
+//---------------------------------------------------------------------------
+int ComRegistration(TConsole * Console)
+{
+  int Result = RESULT_SUCCESS;
+  try
+  {
+    Console->PrintLine(GetEnvironmentInfo());
+    Console->PrintLine();
+
+    std::unique_ptr<TStrings> Keys(new TStringList());
+    DoCollectComRegistration(Console, Keys.get());
+
+    if (Keys->Count == 0)
+    {
+      Console->PrintLine(L"No registration found.");
+      Console->WaitBeforeExit();
+    }
+    else
+    {
+      Console->PrintLine(L"Press (U) to unregister or Esc to exit...");
+      if (DoUnregisterChoice(Console))
+      {
+        Console->PrintLine();
+        Console->PrintLine(L"The following registry keys will be removed from HKCR registry hive:");
+        for (int Index = 0; Index < Keys->Count; Index++)
+        {
+          Console->PrintLine(FORMAT(L"%s [%s]", (Keys->Strings[Index], PlatformStr(reinterpret_cast<int>(Keys->Objects[Index])))));
+        }
+        Console->PrintLine();
+
+        Console->PrintLine(L"You need Administrator privileges to remove the keys.");
+        Console->PrintLine(L"Press (U) again to proceed with unregistration or Esc to abort...");
+        if (DoUnregisterChoice(Console))
+        {
+          std::unique_ptr<TRegistry> Registry64(new TRegistry(KEY_READ | KEY_WRITE | KEY_WOW64_64KEY));
+          Registry64->RootKey = HKEY_CLASSES_ROOT;
+          std::unique_ptr<TRegistry> Registry32(new TRegistry(KEY_READ | KEY_WRITE | KEY_WOW64_32KEY));
+          Registry32->RootKey = HKEY_CLASSES_ROOT;
+
+          bool AnyDeleted = false;
+          bool AllDeleted = true;
+          for (int Index = 0; Index < Keys->Count; Index++)
+          {
+            UnicodeString Key = Keys->Strings[Index];
+            int PlatformSet = reinterpret_cast<int>(Keys->Objects[Index]);
+            if (PlatformSet == 0)
+            {
+              DoDeleteKey(Console, Registry32.get(), Key, 0, AnyDeleted, AllDeleted);
+            }
+            else
+            {
+              if (FLAGSET(PlatformSet, 32))
+              {
+                DoDeleteKey(Console, Registry32.get(), Key, 32, AnyDeleted, AllDeleted);
+              }
+              if (FLAGSET(PlatformSet, 64))
+              {
+                DoDeleteKey(Console, Registry64.get(), Key, 64, AnyDeleted, AllDeleted);
+              }
+            }
+          }
+
+          Console->PrintLine();
+          if (!AnyDeleted)
+          {
+            Console->PrintLine(L"No keys were removed. Make sure you have Administrator privileges.");
+          }
+          else if (!AllDeleted)
+          {
+            Console->PrintLine(L"Some keys were not removed. Make sure you have Administrator privileges.");
+          }
+          else
+          {
+            Console->PrintLine(L"All keys were removed. Unregistration succeeded.");
+          }
+          Console->WaitBeforeExit();
+        }
+      }
+    }
+  }
+  catch (Exception & E)
+  {
+    Result = HandleException(Console, E);
+    Console->WaitBeforeExit();
+  }
+  return Result;
+}

+ 2 - 0
source/windows/Setup.h

@@ -4,6 +4,7 @@
 //---------------------------------------------------------------------------
 #include <Interface.h>
 #include <WinConfiguration.h>
+#include <WinInterface.h>
 //---------------------------------------------------------------------------
 void __fastcall SetupInitialize();
 void __fastcall AddSearchPath(const UnicodeString Path);
@@ -34,5 +35,6 @@ void __fastcall TipsUpdateStaticUsage();
 int __fastcall GetNetVersion();
 UnicodeString __fastcall GetNetVersionStr();
 UnicodeString __fastcall GetPowerShellVersionStr();
+int ComRegistration(TConsole * Console);
 //---------------------------------------------------------------------------
 #endif

+ 29 - 1
source/windows/WinInterface.h

@@ -8,6 +8,7 @@
 #include <WinConfiguration.h>
 #include <Terminal.h>
 #include <SynchronizeController.h>
+#include <Script.h>
 
 #ifdef LOCALINTERFACE
 #include <LocalInterface.h>
@@ -46,6 +47,7 @@ const int mpAllowContinueOnError = 0x02;
 #define FINGERPRINTSCAN_SWITCH L"FingerprintScan"
 #define DUMPCALLSTACK_SWITCH L"DumpCallstack"
 #define INFO_SWITCH L"Info"
+#define COMREGISTRATION_SWITCH L"ComRegistration"
 
 #define DUMPCALLSTACK_EVENT L"WinSCPCallstack%d"
 
@@ -419,7 +421,7 @@ extern const UnicodeString OKButtonName;
 // windows\Console.cpp
 enum TConsoleMode
 {
-  cmNone, cmScripting, cmHelp, cmBatchSettings, cmKeyGen, cmFingerprintScan, cmDumpCallstack, cmInfo,
+  cmNone, cmScripting, cmHelp, cmBatchSettings, cmKeyGen, cmFingerprintScan, cmDumpCallstack, cmInfo, cmComRegistration,
 };
 int __fastcall Console(TConsoleMode Mode);
 
@@ -602,4 +604,30 @@ private:
   void __fastcall BalloonCancelled();
 };
 //---------------------------------------------------------------------------
+class TConsole
+{
+public:
+  virtual __fastcall ~TConsole() {};
+  virtual void __fastcall Print(UnicodeString Str, bool FromBeginning = false, bool Error = false) = 0;
+  void __fastcall PrintLine(const UnicodeString & Str = UnicodeString(), bool Error = false);
+  virtual bool __fastcall Input(UnicodeString & Str, bool Echo, unsigned int Timer) = 0;
+  virtual int __fastcall Choice(
+    UnicodeString Options, int Cancel, int Break, int Continue, int Timeouted, bool Timeouting, unsigned int Timer,
+    UnicodeString Message) = 0;
+  virtual bool __fastcall PendingAbort() = 0;
+  virtual void __fastcall SetTitle(UnicodeString Title) = 0;
+  virtual bool __fastcall LimitedOutput() = 0;
+  virtual bool __fastcall LiveOutput() = 0;
+  virtual bool __fastcall NoInteractiveInput() = 0;
+  virtual void __fastcall WaitBeforeExit() = 0;
+  virtual bool __fastcall CommandLineOnly() = 0;
+  virtual bool __fastcall WantsProgress() = 0;
+  virtual void __fastcall Progress(TScriptProgress & Progress) = 0;
+  virtual UnicodeString __fastcall FinalLogMessage() = 0;
+};
+//---------------------------------------------------------------------------
+int __fastcall HandleException(TConsole * Console, Exception & E);
+//---------------------------------------------------------------------------
+enum { RESULT_SUCCESS = 0, RESULT_ANY_ERROR = 1 };
+//---------------------------------------------------------------------------
 #endif // WinInterfaceH

+ 4 - 0
source/windows/WinMain.cpp

@@ -801,6 +801,10 @@ int __fastcall Execute()
   {
     Mode = cmInfo;
   }
+  else if (Params->FindSwitch(COMREGISTRATION_SWITCH))
+  {
+    Mode = cmComRegistration;
+  }
   // We have to check for /console only after the other options,
   // as the /console is always used when we are run by winscp.com
   // (ambiguous use to pass console version)