Browse Source

Thread-safe implementation of PuTTY callbacks

See also the note in 8ce20a988 commit message.

Source commit: 9bed745884500e15a478e612b553fbaff6989c52
Martin Prikryl 6 years ago
parent
commit
a1c81d248e

+ 17 - 6
source/core/PuttyIntf.cpp

@@ -74,10 +74,8 @@ void __fastcall DontSaveRandomSeed()
   SaveRandomSeed = false;
 }
 //---------------------------------------------------------------------------
-extern "C" char * do_select(Plug plug, SOCKET skt, int startup)
+TSecureShell * GetSecureShell(Plug plug, bool & pfwd)
 {
-  void * frontend;
-
   if (!is_ssh(plug) && !is_pfwd(plug))
   {
     // If it is not SSH/PFwd plug, then it must be Proxy plug.
@@ -86,16 +84,29 @@ extern "C" char * do_select(Plug plug, SOCKET skt, int startup)
     plug = AProxySocket->plug;
   }
 
-  bool pfwd = is_pfwd(plug);
+  pfwd = is_pfwd(plug);
   if (pfwd)
   {
     plug = (Plug)get_pfwd_backend(plug);
   }
 
-  frontend = get_ssh_frontend(plug);
+  void * frontend = get_ssh_frontend(plug);
   DebugAssert(frontend);
 
-  TSecureShell * SecureShell = reinterpret_cast<TSecureShell*>(frontend);
+  return reinterpret_cast<TSecureShell*>(frontend);
+}
+//---------------------------------------------------------------------------
+struct callback_set * get_callback_set(Plug plug)
+{
+  bool pfwd;
+  TSecureShell * SecureShell = GetSecureShell(plug, pfwd);
+  return SecureShell->GetCallbackSet();
+}
+//---------------------------------------------------------------------------
+extern "C" char * do_select(Plug plug, SOCKET skt, int startup)
+{
+  bool pfwd;
+  TSecureShell * SecureShell = GetSecureShell(plug, pfwd);
   if (!pfwd)
   {
     SecureShell->UpdateSocket(skt, startup);

+ 9 - 2
source/core/SecureShell.cpp

@@ -53,6 +53,8 @@ __fastcall TSecureShell::TSecureShell(TSessionUI* UI,
   FSimple = false;
   FCollectPrivateKeyUsage = false;
   FWaitingForData = 0;
+  FCallbackSet.reset(new callback_set());
+  memset(FCallbackSet.get(), 0, sizeof(callback_set));
 }
 //---------------------------------------------------------------------------
 __fastcall TSecureShell::~TSecureShell()
@@ -602,6 +604,11 @@ void __fastcall TSecureShell::Init()
   }
 }
 //---------------------------------------------------------------------------
+struct callback_set * TSecureShell::GetCallbackSet()
+{
+  return FCallbackSet.get();
+}
+//---------------------------------------------------------------------------
 UnicodeString __fastcall TSecureShell::ConvertFromPutty(const char * Str, int Length)
 {
   int BomLength = strlen(MPEXT_BOM);
@@ -1904,7 +1911,7 @@ bool __fastcall TSecureShell::EventSelectLoop(unsigned int MSec, bool ReadEventR
       do
       {
         unsigned int TimeoutStep = std::min(GUIUpdateInterval, Timeout);
-        if (toplevel_callback_pending())
+        if (toplevel_callback_pending(GetCallbackSet()))
         {
           TimeoutStep = 0;
         }
@@ -1916,7 +1923,7 @@ bool __fastcall TSecureShell::EventSelectLoop(unsigned int MSec, bool ReadEventR
         int PrevDataLen = (-static_cast<int>(OutLen) + static_cast<int>(PendLen));
         // 2) Changes in session state - wait criteria in Init()
         bool PrevSessionState = get_ssh_state_session(FBackendHandle);
-        if (run_toplevel_callbacks() &&
+        if (run_toplevel_callbacks(GetCallbackSet()) &&
             (((-static_cast<int>(OutLen) + static_cast<int>(PendLen)) > PrevDataLen) ||
              (PrevSessionState != get_ssh_state_session(FBackendHandle))))
         {

+ 3 - 0
source/core/SecureShell.h

@@ -17,6 +17,7 @@ typedef struct _WSANETWORKEVENTS WSANETWORKEVENTS;
 typedef UINT_PTR SOCKET;
 typedef std::set<SOCKET> TSockets;
 struct TPuttyTranslation;
+struct callback_set;
 enum TSshImplementation { sshiUnknown, sshiOpenSSH, sshiProFTPD, sshiBitvise, sshiTitan, sshiOpenVMS, sshiCerberus };
 //---------------------------------------------------------------------------
 class TSecureShell
@@ -69,6 +70,7 @@ private:
   bool FUtfStrings;
   DWORD FLastSendBufferUpdate;
   int FSendBuf;
+  std::auto_ptr<callback_set> FCallbackSet;
 
   static TCipher __fastcall FuncToSsh1Cipher(const void * Cipher);
   static TCipher __fastcall FuncToSsh2Cipher(const void * Cipher);
@@ -165,6 +167,7 @@ public:
   void __fastcall OldKeyfileWarning();
   void __fastcall PuttyLogEvent(const char * Str);
   UnicodeString __fastcall ConvertFromPutty(const char * Str, int Length);
+  struct callback_set * GetCallbackSet();
 
   __property bool Active = { read = FActive, write = SetActive };
   __property bool Ready = { read = GetReady };

+ 24 - 12
source/putty/callback.c

@@ -14,19 +14,31 @@ struct callback {
     void *ctx;
 };
 
+#ifdef MPEXT
+// PuTTY has one thread only, so run_toplevel_callbacks does not cater for multi threaded uses.
+// It would call callbacks registered any on thread from the thread that happens to call it.
+// We need to create separate callback queue for every SSH session.
+#define CALLBACK_SET_VAR callback_set_v
+#define CALLBACK_SET_VAR_PARAM CALLBACK_SET_VAR,
+#define cbcurr CALLBACK_SET_VAR->cbcurr
+#define cbhead CALLBACK_SET_VAR->cbhead
+#define cbtail CALLBACK_SET_VAR->cbtail
+#else
+#define CALLBACK_SET_VAR_PARAM
 struct callback *cbcurr = NULL, *cbhead = NULL, *cbtail = NULL;
+#endif
 
+#ifndef MPEXT
 toplevel_callback_notify_fn_t notify_frontend = NULL;
 void *frontend = NULL;
 
 void request_callback_notifications(toplevel_callback_notify_fn_t fn,
                                     void *fr)
 {
-    MPEXT_PUTTY_SECTION_ENTER;
     notify_frontend = fn;
     frontend = fr;
-    MPEXT_PUTTY_SECTION_LEAVE;
 }
+#endif
 
 static void run_idempotent_callback(void *ctx)
 {
@@ -35,14 +47,15 @@ static void run_idempotent_callback(void *ctx)
     ic->fn(ic->ctx);
 }
 
-void queue_idempotent_callback(struct IdempotentCallback *ic)
+void queue_idempotent_callback(CALLBACK_SET struct IdempotentCallback *ic)
 {
     if (ic->queued)
         return;
     ic->queued = TRUE;
-    queue_toplevel_callback(run_idempotent_callback, ic);
+    queue_toplevel_callback(CALLBACK_SET_VAR_PARAM run_idempotent_callback, ic);
 }
 
+#ifndef MPEXT
 void delete_callbacks_for_context(void *ctx)
 {
     struct callback *newhead, *newtail;
@@ -68,16 +81,17 @@ void delete_callbacks_for_context(void *ctx)
     cbhead = newhead;
     cbtail = newtail;
 }
+#endif
 
-void queue_toplevel_callback(toplevel_callback_fn_t fn, void *ctx)
-{
+void queue_toplevel_callback(CALLBACK_SET toplevel_callback_fn_t fn, void *ctx)
+{
     struct callback *cb;
 
-    MPEXT_PUTTY_SECTION_ENTER;
     cb = snew(struct callback);
     cb->fn = fn;
     cb->ctx = ctx;
 
+#ifndef MPEXT
     /*
      * If the front end has requested notification of pending
      * callbacks, and we didn't already have one queued, let it know
@@ -91,6 +105,7 @@ void queue_toplevel_callback(toplevel_callback_fn_t fn, void *ctx)
      */
     if (notify_frontend && !cbhead && !cbcurr)
         notify_frontend(frontend);
+#endif
 
     if (cbtail)
         cbtail->next = cb;
@@ -98,13 +113,11 @@ void queue_toplevel_callback(toplevel_callback_fn_t fn, void *ctx)
         cbhead = cb;
     cbtail = cb;
     cb->next = NULL;
-    MPEXT_PUTTY_SECTION_LEAVE;
 }
 
-int run_toplevel_callbacks(void)
+int run_toplevel_callbacks(CALLBACK_SET_ONLY)
 {
     int done_something = FALSE;
-    MPEXT_PUTTY_SECTION_ENTER;
 
     if (cbhead) {
         /*
@@ -127,11 +140,10 @@ int run_toplevel_callbacks(void)
 
         done_something = TRUE;
     }
-    MPEXT_PUTTY_SECTION_LEAVE;
     return done_something;
 }
 
-int toplevel_callback_pending(void)
+int toplevel_callback_pending(CALLBACK_SET_ONLY)
 {
     // MP does not have to be guarded
     return cbcurr != NULL || cbhead != NULL;

+ 20 - 4
source/putty/putty.h

@@ -1601,10 +1601,24 @@ unsigned long timing_last_clock(void);
  * it might have done whatever the loop's caller was waiting for.
  */
 typedef void (*toplevel_callback_fn_t)(void *ctx);
-void queue_toplevel_callback(toplevel_callback_fn_t fn, void *ctx);
-int run_toplevel_callbacks(void);
-int toplevel_callback_pending(void);
+#ifdef MPEXT
+typedef struct callback callback;
+struct callback_set {
+    struct callback *cbcurr, *cbhead, *cbtail;
+};
+#define CALLBACK_SET_ONLY struct callback_set * callback_set_v
+#define CALLBACK_SET CALLBACK_SET_ONLY,
+#else
+#define CALLBACK_SET_ONLY void
+#define CALLBACK_SET
+#endif
+void queue_toplevel_callback(CALLBACK_SET toplevel_callback_fn_t fn, void *ctx);
+int run_toplevel_callbacks(CALLBACK_SET_ONLY);
+int toplevel_callback_pending(CALLBACK_SET_ONLY);
+struct callback_set * get_callback_set(Plug plug);
+#ifndef MPEXT
 void delete_callbacks_for_context(void *ctx);
+#endif
 
 /*
  * Another facility in callback.c deals with 'idempotent' callbacks,
@@ -1620,11 +1634,13 @@ struct IdempotentCallback {
     void *ctx;
     int queued;
 };
-void queue_idempotent_callback(struct IdempotentCallback *ic);
+void queue_idempotent_callback(CALLBACK_SET struct IdempotentCallback *ic);
 
+#ifndef MPEXT
 typedef void (*toplevel_callback_notify_fn_t)(void *frontend);
 void request_callback_notifications(toplevel_callback_notify_fn_t notify,
                                     void *frontend);
+#endif
 
 /*
  * Define no-op macros for the jump list functions, on platforms that

+ 4 - 0
source/putty/ssh.c

@@ -27,6 +27,10 @@
 #define GSS_CTXT_MAYFAIL (1<<3)	/* Context may expire during handshake */
 #endif
 
+#ifdef MPEXT
+#define queue_idempotent_callback(IC) queue_idempotent_callback(get_callback_set(&ssh->plugvt), IC)
+#endif
+
 static const char *const ssh2_disconnect_reasons[] = {
     NULL,
     "host not allowed to connect",

+ 1 - 0
source/putty/windows/winhsock.c

@@ -14,6 +14,7 @@
 
 #ifdef MPEXT
 extern char *do_select(Plug plug, SOCKET skt, int startup);
+#define queue_toplevel_callback(FN, CTX) queue_toplevel_callback(get_callback_set(CTX->plug), FN, CTX)
 #endif
 
 typedef struct HandleSocket {

+ 1 - 1
source/putty/windows/winnet.c

@@ -1629,7 +1629,7 @@ void try_send(NetSocket *s)
 		 * moment.
 		 */
 		s->pending_error = err;
-                queue_toplevel_callback(socket_error_callback, s);
+                queue_toplevel_callback(get_callback_set(s->plug), socket_error_callback, s);
 		return;
 	    }
 	} else {