Prechádzať zdrojové kódy

Refactoring SSL globals

Source commit: ce896c4cb6db7ab2a2f28bdf11498946610b8c9a
Martin Prikryl 2 rokov pred
rodič
commit
1fdf58c748

+ 82 - 106
source/filezilla/AsyncSslSocketLayer.cpp

@@ -16,10 +16,10 @@
 
 /////////////////////////////////////////////////////////////////////////////
 // CAsyncSslSocketLayer
-CCriticalSectionWrapper CAsyncSslSocketLayer::m_sCriticalSection;
+std::unique_ptr<TCriticalSection> CAsyncSslSocketLayer::m_sCriticalSection(TraceInitPtr(new TCriticalSection()));
 
 CAsyncSslSocketLayer::t_SslLayerList* CAsyncSslSocketLayer::m_pSslLayerList = 0;
-int CAsyncSslSocketLayer::m_nSslRefCount = 0;
+bool CAsyncSslSocketLayer::m_bSslInitialized = false;
 
 CAsyncSslSocketLayer::CAsyncSslSocketLayer()
 {
@@ -62,30 +62,24 @@ CAsyncSslSocketLayer::CAsyncSslSocketLayer()
 
 CAsyncSslSocketLayer::~CAsyncSslSocketLayer()
 {
-  UnloadSSL();
+  ResetSslSession();
   delete [] m_pNetworkSendBuffer;
   delete [] m_pRetrySendBuffer;
 }
 
 int CAsyncSslSocketLayer::InitSSL()
 {
-  if (m_bSslInitialized)
-    return 0;
-
-  m_sCriticalSection.Lock();
+  TGuard Guard(m_sCriticalSection.get());
 
-  if (!m_nSslRefCount)
+  if (!m_bSslInitialized)
   {
     if (!OPENSSL_init_ssl(OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL))
     {
       return SSL_FAILURE_INITSSL;
     }
-  }
 
-  m_nSslRefCount++;
-  m_sCriticalSection.Unlock();
-
-  m_bSslInitialized = true;
+    m_bSslInitialized = true;
+  }
 
   return 0;
 }
@@ -708,68 +702,70 @@ int CAsyncSslSocketLayer::InitSSLConnection(bool clientMode,
   if (res)
     return res;
 
-  m_sCriticalSection.Lock();
-  if (!m_ssl_ctx)
   {
-    // Create new context if none given
-    if (!(m_ssl_ctx = SSL_CTX_new( SSLv23_method())))
+    // What is the point of this guard?
+    // Maybe the m_ssl_ctx was intended to be global. But as it is not, the guard is probably pointless.
+    TGuard Guard(m_sCriticalSection.get());
+
+    if (!m_ssl_ctx)
+    {
+      // Create new context if none given
+      if (!(m_ssl_ctx = SSL_CTX_new( SSLv23_method())))
+      {
+        ResetSslSession();
+        return SSL_FAILURE_INITSSL;
+      }
+
+      if (clientMode)
+      {
+        USES_CONVERSION;
+        SSL_CTX_set_verify(m_ssl_ctx, SSL_VERIFY_PEER, verify_callback);
+        SSL_CTX_set_client_cert_cb(m_ssl_ctx, ProvideClientCert);
+        // https://www.mail-archive.com/[email protected]/msg86186.html
+        SSL_CTX_set_session_cache_mode(m_ssl_ctx, SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL_STORE | SSL_SESS_CACHE_NO_AUTO_CLEAR);
+        SSL_CTX_sess_set_new_cb(m_ssl_ctx, NewSessionCallback);
+        CFileStatus Dummy;
+        if (!m_CertStorage.IsEmpty() &&
+            CFile::GetStatus((LPCTSTR)m_CertStorage, Dummy))
+        {
+          SSL_CTX_load_verify_locations(m_ssl_ctx, T2CA(m_CertStorage), 0);
+        }
+      }
+    }
+
+    //Create new SSL session
+    if (!(m_ssl = SSL_new(m_ssl_ctx)))
     {
-      m_sCriticalSection.Unlock();
       ResetSslSession();
       return SSL_FAILURE_INITSSL;
     }
 
-    if (clientMode)
+    if (clientMode && (host.GetLength() > 0))
     {
       USES_CONVERSION;
-      SSL_CTX_set_verify(m_ssl_ctx, SSL_VERIFY_PEER, verify_callback);
-      SSL_CTX_set_client_cert_cb(m_ssl_ctx, ProvideClientCert);
-      // https://www.mail-archive.com/[email protected]/msg86186.html
-      SSL_CTX_set_session_cache_mode(m_ssl_ctx, SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL_STORE | SSL_SESS_CACHE_NO_AUTO_CLEAR);
-      SSL_CTX_sess_set_new_cb(m_ssl_ctx, NewSessionCallback);
-      CFileStatus Dummy;
-      if (!m_CertStorage.IsEmpty() &&
-          CFile::GetStatus((LPCTSTR)m_CertStorage, Dummy))
-      {
-        SSL_CTX_load_verify_locations(m_ssl_ctx, T2CA(m_CertStorage), 0);
-      }
+      SSL_set_tlsext_host_name(m_ssl, T2CA(host));
     }
-  }
-
-  //Create new SSL session
-  if (!(m_ssl = SSL_new(m_ssl_ctx)))
-  {
-    m_sCriticalSection.Unlock();
-    ResetSslSession();
-    return SSL_FAILURE_INITSSL;
-  }
-
-  if (clientMode && (host.GetLength() > 0))
-  {
-    USES_CONVERSION;
-    SSL_set_tlsext_host_name(m_ssl, T2CA(host));
-  }
 
 #ifdef _DEBUG
-  if ((main == NULL) && LoggingSocketMessage(FZ_LOG_INFO))
-  {
-    USES_CONVERSION;
-    LogSocketMessageRaw(FZ_LOG_INFO, L"Supported ciphersuites:");
-    STACK_OF(SSL_CIPHER) * ciphers = SSL_get_ciphers(m_ssl);
-    for (int i = 0; i < sk_SSL_CIPHER_num(ciphers); i++)
+    if ((main == NULL) && LoggingSocketMessage(FZ_LOG_INFO))
     {
-      const SSL_CIPHER * cipher = sk_SSL_CIPHER_value(ciphers, i);
-      LogSocketMessageRaw(FZ_LOG_INFO, A2CT(SSL_CIPHER_get_name(cipher)));
+      USES_CONVERSION;
+      LogSocketMessageRaw(FZ_LOG_INFO, L"Supported ciphersuites:");
+      STACK_OF(SSL_CIPHER) * ciphers = SSL_get_ciphers(m_ssl);
+      for (int i = 0; i < sk_SSL_CIPHER_num(ciphers); i++)
+      {
+        const SSL_CIPHER * cipher = sk_SSL_CIPHER_value(ciphers, i);
+        LogSocketMessageRaw(FZ_LOG_INFO, A2CT(SSL_CIPHER_get_name(cipher)));
+      }
     }
-  }
 #endif
 
-  //Add current instance to list of active instances
-  t_SslLayerList *tmp = m_pSslLayerList;
-  m_pSslLayerList = new t_SslLayerList;
-  m_pSslLayerList->pNext = tmp;
-  m_pSslLayerList->pLayer = this;
-  m_sCriticalSection.Unlock();
+    //Add current instance to list of active instances
+    t_SslLayerList *tmp = m_pSslLayerList;
+    m_pSslLayerList = new t_SslLayerList;
+    m_pSslLayerList->pNext = tmp;
+    m_pSslLayerList->pLayer = this;
+  }
 
   SSL_set_info_callback(m_ssl, apps_ssl_info_callback);
 
@@ -913,7 +909,7 @@ void CAsyncSslSocketLayer::ResetSslSession()
     SSL_free(m_ssl);
   }
 
-  m_sCriticalSection.Lock();
+  TGuard Guard(m_sCriticalSection.get());
 
   if (m_ssl_ctx)
   {
@@ -925,7 +921,6 @@ void CAsyncSslSocketLayer::ResetSslSession()
   t_SslLayerList *cur = m_pSslLayerList;
   if (!cur)
   {
-    m_sCriticalSection.Unlock();
     return;
   }
 
@@ -943,7 +938,6 @@ void CAsyncSslSocketLayer::ResetSslSession()
         cur->pNext = cur->pNext->pNext;
         delete tmp;
 
-        m_sCriticalSection.Unlock();
         return;
       }
       cur = cur->pNext;
@@ -957,7 +951,6 @@ void CAsyncSslSocketLayer::ResetSslSession()
   m_sessionreuse = true;
   m_sessionreuse_failed = false;
 
-  m_sCriticalSection.Unlock();
 }
 
 bool CAsyncSslSocketLayer::IsUsingSSL()
@@ -1084,23 +1077,23 @@ void CAsyncSslSocketLayer::apps_ssl_info_callback(const SSL *s, int where, int r
 {
   USES_CONVERSION;
   CAsyncSslSocketLayer *pLayer = 0;
-  m_sCriticalSection.Lock();
-  t_SslLayerList *cur = m_pSslLayerList;
-  while (cur)
   {
-    if (cur->pLayer->m_ssl == s)
-      break;
-    cur = cur->pNext;
-  }
-  if (!cur)
-  {
-    m_sCriticalSection.Unlock();
-    MessageBox(0, L"Can't lookup TLS session!", L"Critical error", MB_ICONEXCLAMATION);
-    return;
+    TGuard Guard(m_sCriticalSection.get());
+    t_SslLayerList *cur = m_pSslLayerList;
+    while (cur)
+    {
+      if (cur->pLayer->m_ssl == s)
+        break;
+      cur = cur->pNext;
+    }
+    if (!cur)
+    {
+      MessageBox(0, L"Can't lookup TLS session!", L"Critical error", MB_ICONEXCLAMATION);
+      return;
+    }
+    else
+      pLayer = cur->pLayer;
   }
-  else
-    pLayer = cur->pLayer;
-  m_sCriticalSection.Unlock();
 
   // Called while unloading?
   if (!pLayer->m_bUseSSL && (where != SSL_CB_LOOP))
@@ -1216,26 +1209,6 @@ void CAsyncSslSocketLayer::apps_ssl_info_callback(const SSL *s, int where, int r
   }
 }
 
-
-void CAsyncSslSocketLayer::UnloadSSL()
-{
-  if (!m_bSslInitialized)
-    return;
-  ResetSslSession();
-
-  m_bSslInitialized = false;
-
-  m_sCriticalSection.Lock();
-  m_nSslRefCount--;
-  if (m_nSslRefCount)
-  {
-    m_sCriticalSection.Unlock();
-    return;
-  }
-
-  m_sCriticalSection.Unlock();
-}
-
 bool AsnTimeToValidTime(ASN1_TIME * AsnTime, t_SslCertData::t_validTime & ValidTime)
 {
   int i = AsnTime->length;
@@ -1714,17 +1687,20 @@ void CAsyncSslSocketLayer::OnConnect(int nErrorCode)
 CAsyncSslSocketLayer * CAsyncSslSocketLayer::LookupLayer(SSL * Ssl)
 {
   CAsyncSslSocketLayer * Result = NULL;
-  m_sCriticalSection.Lock();
-  t_SslLayerList * Cur = m_pSslLayerList;
-  while (Cur != NULL)
+  t_SslLayerList * Cur = NULL;
+
   {
-    if (Cur->pLayer->m_ssl == Ssl)
+    TGuard Guard(m_sCriticalSection.get());
+    Cur = m_pSslLayerList;
+    while (Cur != NULL)
     {
-      break;
+      if (Cur->pLayer->m_ssl == Ssl)
+      {
+        break;
+      }
+      Cur = Cur->pNext;
     }
-    Cur = Cur->pNext;
   }
-  m_sCriticalSection.Unlock();
 
   if (Cur == NULL)
   {

+ 2 - 5
source/filezilla/AsyncSslSocketLayer.h

@@ -115,7 +115,6 @@ struct t_SslCertData
   int priv_data; //Internal data, do not modify
 };
 //---------------------------------------------------------------------------
-class CCriticalSectionWrapper;
 class CFileZillaTools;
 //---------------------------------------------------------------------------
 class CAsyncSslSocketLayer : public CAsyncSocketExLayer
@@ -158,7 +157,6 @@ private:
   void PrintSessionInfo();
   BOOL ShutDownComplete();
   int InitSSL();
-  void UnloadSSL();
   void PrintLastErrorMsg();
   bool HandleSession(SSL_SESSION * Session);
   int ProcessSendBuffer();
@@ -177,11 +175,10 @@ private:
   BOOL m_bFailureSent;
 
   // Critical section for thread synchronization
-  static CCriticalSectionWrapper m_sCriticalSection;
+  static std::unique_ptr<TCriticalSection> m_sCriticalSection;
 
   // Status variables
-  static int m_nSslRefCount;
-  BOOL m_bSslInitialized;
+  static bool m_bSslInitialized;
   int m_nShutDown;
   int m_nNetworkError;
   int m_nSslAsyncNotifyId;