Browse Source

AppleTLS: Implement AppleTLS and Apple Message Digest

Nils Maier 12 years ago
parent
commit
0bcbd947b4

+ 60 - 8
configure.ac

@@ -24,6 +24,7 @@ esac
 AC_DEFINE_UNQUOTED([TARGET], ["$target"], [Define target-type])
 
 # Checks for arguments.
+ARIA2_ARG_WITHOUT([appletls])
 ARIA2_ARG_WITHOUT([gnutls])
 ARIA2_ARG_WITHOUT([libnettle])
 ARIA2_ARG_WITHOUT([libgmp])
@@ -145,7 +146,28 @@ if test "x$with_sqlite3" = "xyes"; then
   fi
 fi
 
-if test "x$with_gnutls" = "xyes"; then
+case "$host" in
+  *darwin*)
+    have_osx="yes"
+  ;;
+esac
+
+if test "x$with_appletls" = "xyes"; then
+  AC_MSG_CHECKING([whether to enable Mac OS X native SSL/TLS])
+  if test "x$have_osx" = "xyes"; then
+    AC_DEFINE([HAVE_APPLETLS], [1], [Define to 1 if you have Apple TLS])
+    LDFLAGS="$LDFLAGS -framework CoreFoundation -framework Security"
+    have_appletls="yes"
+    AC_MSG_RESULT(yes)
+  else
+    AC_MSG_RESULT(no)
+    if test "x$with_appletls_requested" = "xyes"; then
+      ARIA2_DEP_NOT_MET([appletls])
+    fi
+  fi
+fi
+
+if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes"; then
   # gnutls >= 2.8 doesn't have libgnutls-config anymore. We require
   # 2.2.0 because we use gnutls_priority_set_direct()
   PKG_CHECK_MODULES([LIBGNUTLS], [gnutls >= 2.2.0],
@@ -163,7 +185,7 @@ if test "x$with_gnutls" = "xyes"; then
   fi
 fi
 
-if test "x$with_openssl" = "xyes" && test "x$have_libgnutls" != "xyes"; then
+if test "x$with_openssl" = "xyes" && test "x$have_appletls" != "xyes" && test "x$have_libgnutls" != "xyes"; then
   PKG_CHECK_MODULES([OPENSSL], [openssl >= 0.9.8],
                     [have_openssl=yes], [have_openssl=no])
   if test "x$have_openssl" = "xyes"; then
@@ -235,8 +257,30 @@ if test "x$with_libcares" = "xyes"; then
   fi
 fi
 
+use_md=""
+if test "x$have_osx" == "xyes"; then
+  use_md="apple"
+  AC_DEFINE([USE_APPLE_MD], [1], [What message digest implementation to use])
+else
+  if test "x$have_libnettle" = "xyes"; then
+    AC_DEFINE([USE_LIBNETTLE_MD], [1], [What message digest implementation to use])
+    use_md="libnettle"
+  else
+    if test "x$have_libgcrypt" = "xyes"; then
+      AC_DEFINE([USE_LIBGCRYPT_MD], [1], [What message digest implementation to use])
+      use_md="libgcrypt"
+    else
+      if test = "x$have_openssl" = "xyes"; then
+        AC_DEFINE([USE_OPENSSL_MD], [1], [What message digest implementation to use])
+        use_md="openssl"
+      fi
+    fi
+  fi
+fi
+
 # Define variables based on the result of the checks for libraries.
-if test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then
+if test "x$have_appletls" = "xyes" || test "x$have_libgnutls" = "xyes" || test "x$have_openssl" = "xyes"; then
+  have_ssl="yes"
   AC_DEFINE([ENABLE_SSL], [1], [Define to 1 if ssl support is enabled.])
   AM_CONDITIONAL([ENABLE_SSL], true)
   AC_SUBST([ca_bundle])
@@ -244,14 +288,20 @@ else
   AM_CONDITIONAL([ENABLE_SSL], false)
 fi
 
+
+AM_CONDITIONAL([HAVE_OSX], [ test "x$have_osx" = "xyes" ])
+AM_CONDITIONAL([HAVE_APPLETLS], [ test "x$have_appletls" = "xyes" ])
+AM_CONDITIONAL([USE_APPLE_MD], [ test "x$use_md" = "xapple" ])
 AM_CONDITIONAL([HAVE_LIBGNUTLS], [ test "x$have_libgnutls" = "xyes" ])
 AM_CONDITIONAL([HAVE_LIBNETTLE], [ test "x$have_libnettle" = "xyes" ])
+AM_CONDITIONAL([USE_LIBNETTLE_MD], [ test "x$use_md" = "xlibnettle"])
 AM_CONDITIONAL([HAVE_LIBGMP], [ test "x$have_libgmp" = "xyes" ])
 AM_CONDITIONAL([HAVE_LIBGCRYPT], [ test "x$have_libgcrypt" = "xyes" ])
+AM_CONDITIONAL([USE_LIBGCRYPT_MD], [ test "x$use_md" = "xlibgcrypt"])
 AM_CONDITIONAL([HAVE_OPENSSL], [ test "x$have_openssl" = "xyes" ])
+AM_CONDITIONAL([USE_OPENSSL_MD], [ test "x$use_md" = "xopenssl"])
 
-if test "x$have_libnettle" = "xyes" || test "x$have_libgcrypt" = "xyes" ||
-   test "x$have_openssl" = "xyes"; then
+if test "x$use_md" != "x"; then
   AC_DEFINE([ENABLE_MESSAGE_DIGEST], [1],
             [Define to 1 if message digest support is enabled.])
   AM_CONDITIONAL([ENABLE_MESSAGE_DIGEST], true)
@@ -325,9 +375,9 @@ AM_CONDITIONAL([HAVE_SQLITE3], [test "x$have_sqlite3" = "xyes"])
 AC_SEARCH_LIBS([clock_gettime], [rt])
 
 case "$host" in
-	*solaris*)
-                AC_SEARCH_LIBS([getaddrinfo], [nsl socket])
-		;;
+  *solaris*)
+    AC_SEARCH_LIBS([getaddrinfo], [nsl socket])
+    ;;
 esac
 
 # Checks for header files.
@@ -670,6 +720,8 @@ echo "LDFLAGS:        $LDFLAGS"
 echo "LIBS:           $LIBS"
 echo "DEFS:           $DEFS"
 echo "SQLite3:        $have_sqlite3"
+echo "SSL Support:    $have_ssl"
+echo "AppleTLS:       $have_appletls"
 echo "GnuTLS:         $have_libgnutls"
 echo "OpenSSL:        $have_openssl"
 echo "CA Bundle:      $ca_bundle"

+ 153 - 0
src/AppleMessageDigestImpl.cc

@@ -0,0 +1,153 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Nils Maier
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+#include "AppleMessageDigestImpl.h"
+
+#include <CommonCrypto/CommonDigest.h>
+
+#include "array_fun.h"
+#include "HashFuncEntry.h"
+
+namespace aria2 {
+
+template<size_t dlen,
+         typename ctx_t,
+         int (*init_fn)(ctx_t*),
+         int (*update_fn)(ctx_t*, const void*, CC_LONG),
+         int(*final_fn)(unsigned char*, ctx_t*)>
+class MessageDigestBase : public MessageDigestImpl {
+public:
+  MessageDigestBase() { reset(); }
+
+  virtual size_t getDigestLength() const {
+    return dlen;
+  }
+  virtual void reset() {
+    init_fn(&ctx_);
+  }
+  virtual void update(const void* data, size_t length) {
+    while (length) {
+      CC_LONG l = std::min(length, (size_t)std::numeric_limits<uint32_t>::max());
+      update_fn(&ctx_, data, l);
+      length -= l;
+    }
+  }
+  virtual void digest(unsigned char* md) {
+    final_fn(md, &ctx_);
+  }
+private:
+  ctx_t ctx_;
+};
+
+typedef MessageDigestBase<CC_MD5_DIGEST_LENGTH,
+                          CC_MD5_CTX,
+                          CC_MD5_Init,
+                          CC_MD5_Update,
+                          CC_MD5_Final>
+MessageDigestMD5;
+typedef MessageDigestBase<CC_SHA1_DIGEST_LENGTH,
+                          CC_SHA1_CTX,
+                          CC_SHA1_Init,
+                          CC_SHA1_Update,
+                          CC_SHA1_Final>
+MessageDigestSHA1;
+typedef MessageDigestBase<CC_SHA224_DIGEST_LENGTH,
+                          CC_SHA256_CTX,
+                          CC_SHA224_Init,
+                          CC_SHA224_Update,
+                          CC_SHA224_Final>
+MessageDigestSHA224;
+typedef MessageDigestBase<CC_SHA256_DIGEST_LENGTH,
+                          CC_SHA256_CTX,
+                          CC_SHA256_Init,
+                          CC_SHA256_Update,
+                          CC_SHA256_Final>
+MessageDigestSHA256;
+typedef MessageDigestBase<CC_SHA384_DIGEST_LENGTH,
+                          CC_SHA512_CTX,
+                          CC_SHA384_Init,
+                          CC_SHA384_Update,
+                          CC_SHA384_Final>
+MessageDigestSHA384;
+typedef MessageDigestBase<CC_SHA512_DIGEST_LENGTH,
+                          CC_SHA512_CTX,
+                          CC_SHA512_Init,
+                          CC_SHA512_Update,
+                          CC_SHA512_Final>
+MessageDigestSHA512;
+
+SharedHandle<MessageDigestImpl> MessageDigestImpl::sha1()
+{
+  return SharedHandle<MessageDigestImpl>(new MessageDigestSHA1());
+}
+
+SharedHandle<MessageDigestImpl> MessageDigestImpl::create
+(const std::string& hashType)
+{
+  if (hashType == "sha-1") {
+    return SharedHandle<MessageDigestImpl>(new MessageDigestSHA1());
+  }
+  if (hashType == "sha-224") {
+    return SharedHandle<MessageDigestImpl>(new MessageDigestSHA224());
+  }
+  if (hashType == "sha-256") {
+    return SharedHandle<MessageDigestImpl>(new MessageDigestSHA256());
+  }
+  if (hashType == "sha-384") {
+    return SharedHandle<MessageDigestImpl>(new MessageDigestSHA384());
+  }
+  if (hashType == "sha-512") {
+    return SharedHandle<MessageDigestImpl>(new MessageDigestSHA512());
+  }
+  if (hashType == "md5") {
+    return SharedHandle<MessageDigestImpl>(new MessageDigestMD5());
+  }
+  return SharedHandle<MessageDigestImpl>();
+}
+
+bool MessageDigestImpl::supports(const std::string& hashType)
+{
+  return hashType == "sha-1" || hashType == "sha-224" || hashType == "sha-256" || hashType == "sha-384" || hashType == "sha-512" || hashType == "md5";
+}
+
+size_t MessageDigestImpl::getDigestLength(const std::string& hashType)
+{
+  SharedHandle<MessageDigestImpl> impl = create(hashType);
+  if (!impl) {
+    return 0;
+  }
+  return impl->getDigestLength();
+}
+
+} // namespace aria2

+ 71 - 0
src/AppleMessageDigestImpl.h

@@ -0,0 +1,71 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Nils Maier
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+#ifndef D_APPLE_MESSAGE_DIGEST_IMPL_H
+#define D_APPLE_MESSAGE_DIGEST_IMPL_H
+
+#include "common.h"
+
+#include <string>
+
+#include "SharedHandle.h"
+
+namespace aria2 {
+
+class MessageDigestImpl {
+public:
+  static SharedHandle<MessageDigestImpl> sha1();
+  static SharedHandle<MessageDigestImpl> create(const std::string& hashType);
+
+  static bool supports(const std::string& hashType);
+  static size_t getDigestLength(const std::string& hashType);
+
+public:
+  virtual size_t getDigestLength() const = 0;
+  virtual void reset() = 0;
+  virtual void update(const void* data, size_t length) = 0;
+  virtual void digest(unsigned char* md) = 0;
+
+protected:
+  MessageDigestImpl() {}
+
+private:
+  MessageDigestImpl(const MessageDigestImpl&);
+  MessageDigestImpl& operator=(const MessageDigestImpl&);
+
+};
+
+} // namespace aria2
+
+#endif // D_APPLE_MESSAGE_DIGEST_IMPL_H

+ 22 - 15
src/TLSSessionConst.h → src/AppleTLSContext.cc

@@ -2,7 +2,7 @@
 /*
  * aria2 - The high speed download utility
  *
- * Copyright (C) 2013 Tatsuhiro Tsujikawa
+ * Copyright (C) 2013 Nils Maier
  *
  * This program is free software; you can redistribute it and/or modify
  * it under the terms of the GNU General Public License as published by
@@ -32,24 +32,31 @@
  * files in the program, then also delete it here.
  */
 /* copyright --> */
-#ifndef TLS_SESSION_CONST_H
-#define TLS_SESSION_CONST_H
+#include "AppleTLSContext.h"
 
-#include "common.h"
+#include "LogFactory.h"
+#include "Logger.h"
+#include "fmt.h"
+#include "message.h"
 
 namespace aria2 {
 
-enum TLSDirection {
-  TLS_WANT_READ = 1,
-  TLS_WANT_WRITE
-};
+TLSContext* TLSContext::make(TLSSessionSide side) {
+  return new AppleTLSContext(side);
+}
 
-enum TLSErrorCode {
-  TLS_ERR_OK = 0,
-  TLS_ERR_ERROR = -1,
-  TLS_ERR_WOULDBLOCK = -2
-};
+bool AppleTLSContext::addCredentialFile(const std::string& certfile,
+                                        const std::string& keyfile)
+{
+  A2_LOG_WARN("TLS credential files are not supported. Use the KeyChain to manage your certificates.");
+  return false;
+}
+
+bool AppleTLSContext::addTrustedCACertFile(const std::string& certfile)
+{
+  A2_LOG_WARN("TLS CA bundle files are not supported. Use the KeyChain to manage your certificates.");
+  return false;
+}
 
-} // namespace aria2
 
-#endif // TLS_SESSION_CONST_H
+} // namespace aria2

+ 90 - 0
src/AppleTLSContext.h

@@ -0,0 +1,90 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Nils Maier
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+#ifndef D_APPLE_TLS_CONTEXT_H
+#define D_APPLE_TLS_CONTEXT_H
+
+#include "common.h"
+
+#include <string>
+#include <Security/Security.h>
+#include <Security/SecureTransport.h>
+
+#include "TLSContext.h"
+#include "DlAbortEx.h"
+
+namespace aria2 {
+
+class AppleTLSContext : public TLSContext {
+public:
+  AppleTLSContext(TLSSessionSide side)
+    : side_(side),
+      verifyPeer_(true)
+  {}
+
+  virtual ~AppleTLSContext() {}
+
+  // private key `keyfile' must be decrypted.
+  virtual bool addCredentialFile(const std::string& certfile,
+                                 const std::string& keyfile);
+
+  virtual bool addSystemTrustedCACerts() {
+    return true;
+  }
+
+  // certfile can contain multiple certificates.
+  virtual bool addTrustedCACertFile(const std::string& certfile);
+
+  virtual bool good() const {
+    return true;
+  }
+  virtual TLSSessionSide getSide() const {
+    return side_;
+  }
+
+  virtual bool getVerifyPeer() const {
+    return verifyPeer_;
+  }
+  virtual void setVerifyPeer(bool verify) {
+    verifyPeer_ = verify;
+  }
+
+private:
+  TLSSessionSide side_;
+  bool verifyPeer_;
+};
+
+} // namespace aria2
+
+#endif // D_LIBSSL_TLS_CONTEXT_H

+ 354 - 0
src/AppleTLSSession.cc

@@ -0,0 +1,354 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Nils Maier
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+
+#include "AppleTLSSession.h"
+
+#include <CoreFoundation/CoreFoundation.h>
+
+#include "fmt.h"
+#include "LogFactory.h"
+
+#define ioErr -36
+#define paramErr -50
+#define errSSLServerAuthCompleted -9841
+
+namespace {
+  static const SSLProtocol kTLSProtocol11_h = (SSLProtocol)(kSSLProtocolAll + 1);
+  static const SSLProtocol kTLSProtocol12_h = (SSLProtocol)(kSSLProtocolAll + 2);
+}
+
+namespace aria2 {
+
+TLSSession* TLSSession::make(TLSContext* ctx)
+{
+  return new AppleTLSSession(static_cast<AppleTLSContext*>(ctx));
+}
+
+AppleTLSSession::AppleTLSSession(AppleTLSContext* ctx)
+  : ctx_(ctx),
+    sslCtx_(0),
+    sockfd_(0),
+    state_(st_constructed),
+    lastError_(noErr),
+    writeBuffered_(0)
+{
+  lastError_ = SSLNewContext(ctx->getSide() == TLS_SERVER, &sslCtx_) == noErr;
+  if (lastError_ == noErr) {
+    state_ = st_error;
+    return;
+  }
+#if defined(__MAC_10_8)
+  (void)SSLSetProtocolVersionMin(sslCtx_, kSSLProtocol3);
+  (void)SSLSetProtocolVersionMax(sslCtx_, kTLSProtocol12);
+#else
+  (void)SSLSetProtocolVersionEnabled(sslCtx_, kSSLProtocolAll, false);
+  (void)SSLSetProtocolVersionEnabled(sslCtx_, kSSLProtocol3, true);
+  (void)SSLSetProtocolVersionEnabled(sslCtx_, kTLSProtocol1, true);
+  (void)SSLSetProtocolVersionEnabled(sslCtx_, kTLSProtocol11_h, true);
+  (void)SSLSetProtocolVersionEnabled(sslCtx_, kTLSProtocol12_h, true);
+#endif
+  (void)SSLSetEnableCertVerify(sslCtx_, ctx->getVerifyPeer());
+}
+
+AppleTLSSession::~AppleTLSSession()
+{
+  closeConnection();
+  if (sslCtx_) {
+    SSLDisposeContext(sslCtx_);
+    sslCtx_ = 0;
+  }
+  state_ = st_error;
+}
+
+int AppleTLSSession::init(sock_t sockfd)
+{
+  if (state_ != st_constructed) {
+    lastError_ = noErr;
+    return TLS_ERR_ERROR;
+  }
+  lastError_ = SSLSetIOFuncs(sslCtx_, SocketRead, SocketWrite);
+  if (lastError_ != noErr) {
+    state_ = st_error;
+    return TLS_ERR_ERROR;
+  }
+  lastError_ = SSLSetConnection(sslCtx_, this);
+  if (lastError_ != noErr) {
+    state_ = st_error;
+    return TLS_ERR_ERROR;
+  }
+  sockfd_ = sockfd;
+  state_ = st_initialized;
+  return TLS_ERR_OK;
+}
+
+int AppleTLSSession::setSNIHostname(const std::string& hostname)
+{
+  if (state_ != st_initialized) {
+    lastError_ = noErr;
+    return TLS_ERR_ERROR;
+  }
+  lastError_ = SSLSetPeerDomainName(sslCtx_, hostname.c_str(), hostname.length());
+  return (lastError_ != noErr) ? TLS_ERR_ERROR : TLS_ERR_OK;
+}
+
+int AppleTLSSession::closeConnection()
+{
+  if (state_ != st_connected) {
+    lastError_ = noErr;
+    return TLS_ERR_ERROR;
+  }
+  lastError_ = SSLClose(sslCtx_);
+  state_ = st_closed;
+  return lastError_ == noErr ?  TLS_ERR_OK : TLS_ERR_ERROR;
+}
+
+int AppleTLSSession::checkDirection() {
+  if (writeBuffered_) {
+    return TLS_WANT_WRITE;
+  }
+  if (state_ == st_connected) {
+    size_t buffered;
+    lastError_ = SSLGetBufferedReadSize(sslCtx_, &buffered);
+    if (lastError_ == noErr && buffered) {
+      return TLS_WANT_READ;
+    }
+  }
+  return 0;
+}
+
+ssize_t AppleTLSSession::writeData(const void* data, size_t len)
+{
+  if (state_ != st_connected) {
+    lastError_ = noErr;
+    return TLS_ERR_ERROR;
+  }
+  size_t processed = 0;
+  if (writeBuffered_) {
+    lastError_ = SSLWrite(sslCtx_, 0, 0, &processed);
+    switch (lastError_) {
+      case noErr:
+        processed = writeBuffered_;
+        writeBuffered_ = 0;
+        return processed;
+      case errSSLWouldBlock:
+        return TLS_ERR_WOULDBLOCK;
+      case errSSLClosedGraceful:
+      case errSSLClosedNoNotify:
+        closeConnection();
+        return TLS_ERR_ERROR;
+      default:
+        closeConnection();
+        state_ = st_error;
+        return TLS_ERR_ERROR;
+    }
+  }
+
+  lastError_ = SSLWrite(sslCtx_, data, len, &processed);
+  switch (lastError_) {
+    case noErr:
+      return processed;
+    case errSSLWouldBlock:
+      writeBuffered_ = len;
+      return TLS_ERR_WOULDBLOCK;
+    case errSSLClosedGraceful:
+    case errSSLClosedNoNotify:
+      closeConnection();
+      return TLS_ERR_ERROR;
+    default:
+      closeConnection();
+      state_ = st_error;
+      return TLS_ERR_ERROR;
+  }
+}
+OSStatus AppleTLSSession::sockWrite(const void* data, size_t* len)
+{
+  size_t remain = *len;
+  const uint8_t *buffer = static_cast<const uint8_t*>(data);
+  *len = 0;
+  while (remain) {
+    ssize_t w = write(sockfd_, buffer, remain);
+    if (w <= 0) {
+      switch (errno) {
+        case EAGAIN:
+          return errSSLWouldBlock;
+        default:
+          return errSSLClosedAbort;
+      }
+    }
+    remain -= w;
+    buffer += w;
+    *len += w;
+  }
+  return noErr;
+}
+ssize_t AppleTLSSession::readData(void* data, size_t len)
+{
+  if (state_ != st_connected) {
+    lastError_ = noErr;
+    return TLS_ERR_ERROR;
+  }
+  size_t processed = 0;
+  lastError_ = SSLRead(sslCtx_, data, len, &processed);
+  switch (lastError_) {
+    case noErr:
+      return processed;
+    case errSSLWouldBlock:
+      if (processed) {
+        return processed;
+      }
+      return TLS_ERR_WOULDBLOCK;
+    case errSSLClosedGraceful:
+    case errSSLClosedNoNotify:
+      closeConnection();
+      return TLS_ERR_ERROR;
+    default:
+      closeConnection();
+      state_ = st_error;
+      return TLS_ERR_ERROR;
+  }
+}
+
+OSStatus AppleTLSSession::sockRead(void* data, size_t* len)
+{
+  size_t remain = *len;
+  uint8_t *buffer = static_cast<uint8_t*>(data);
+  *len = 0;
+  while (remain) {
+    ssize_t r = read(sockfd_, buffer, remain);
+    if (r == 0) {
+      return errSSLClosedGraceful;
+    }
+    if (r < 0) {
+      switch (errno) {
+        case ENOENT:
+          return errSSLClosedGraceful;
+        case ECONNRESET:
+          return errSSLClosedAbort;
+        case EAGAIN:
+          return errSSLWouldBlock;
+        default:
+          return errSSLClosedAbort;
+      }
+    }
+    remain -= r;
+    buffer += r;
+    *len += r;
+  }
+  return noErr;
+}
+
+int AppleTLSSession::tlsConnect(const std::string& hostname, std::string& handshakeErr)
+{
+  if (state_ != st_initialized) {
+    return TLS_ERR_ERROR;
+  }
+  if (!hostname.empty()) {
+    setSNIHostname(hostname);
+  }
+  lastError_ = SSLHandshake(sslCtx_);
+  switch (lastError_) {
+    case noErr:
+      state_ = st_connected;
+      return TLS_ERR_OK;
+    case errSSLWouldBlock:
+      return TLS_ERR_WOULDBLOCK;
+    case errSSLServerAuthCompleted:
+      return tlsConnect(hostname, handshakeErr);
+    default:
+      handshakeErr = getLastErrorString();
+      return TLS_ERR_ERROR;
+  }
+}
+
+int AppleTLSSession::tlsAccept()
+{
+  std::string hostname, err;
+  return tlsConnect(hostname, err);
+}
+
+std::string AppleTLSSession::getLastErrorString()
+{
+  switch (lastError_) {
+    case errSSLProtocol:
+      return "Protocol error";
+    case errSSLNegotiation:
+      return "No common cipher suites";
+    case errSSLFatalAlert:
+      return "Received fatal alert";
+    case errSSLSessionNotFound:
+      return "Unknown session";
+    case errSSLClosedGraceful:
+      return "Closed gracefully";
+    case errSSLClosedAbort:
+      return "Connection aborted";
+    case errSSLXCertChainInvalid:
+      return "Invalid certificate chain";
+    case errSSLBadCert:
+      return "Invalid certificate format";
+    case errSSLCrypto:
+      return "Cryptographic error";
+    case paramErr:
+    case errSSLInternal:
+      return "Internal SSL error";
+    case errSSLUnknownRootCert:
+      return "Self-signed certificate";
+    case errSSLNoRootCert:
+      return "No root certificate";
+    case errSSLCertExpired:
+      return "Certificate expired";
+    case errSSLCertNotYetValid:
+      return "Certificate not yet valid";
+    case errSSLClosedNoNotify:
+      return "Closed without notification";
+    case errSSLBufferOverflow:
+      return "Buffer not large enough";
+    case errSSLBadCipherSuite:
+      return "Bad cipher suite";
+    case errSSLPeerUnexpectedMsg:
+      return "Unexpected peer message";
+    case errSSLPeerBadRecordMac:
+      return "Bad MAC";
+    case errSSLPeerDecryptionFail:
+      return "Decryption failure";
+    case errSSLHostNameMismatch:
+      return "Invalid hostname";
+    case errSSLConnectionRefused:
+      return "Connection refused";
+    default:
+      return fmt("Unspecified error %d", lastError_);
+  }
+}
+
+}

+ 127 - 0
src/AppleTLSSession.h

@@ -0,0 +1,127 @@
+/* <!-- copyright */
+/*
+ * aria2 - The high speed download utility
+ *
+ * Copyright (C) 2013 Nils Maier
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ *
+ * In addition, as a special exception, the copyright holders give
+ * permission to link the code of portions of this program with the
+ * OpenSSL library under certain conditions as described in each
+ * individual source file, and distribute linked combinations
+ * including the two.
+ * You must obey the GNU General Public License in all respects
+ * for all of the code used other than OpenSSL.  If you modify
+ * file(s) with this exception, you may extend this exception to your
+ * version of the file(s), but you are not obligated to do so.  If you
+ * do not wish to do so, delete this exception statement from your
+ * version.  If you delete this exception statement from all source
+ * files in the program, then also delete it here.
+ */
+/* copyright --> */
+#ifndef APPLE_TLS_SESSION_H
+#define APPLE_TLS_SESSION_H
+
+#include "common.h"
+#include "TLSSession.h"
+#include "AppleTLSContext.h"
+
+namespace aria2 {
+
+class AppleTLSSession : public TLSSession {
+  enum state_t {
+    st_constructed,
+    st_initialized,
+    st_connected,
+    st_closed,
+    st_error
+  };
+public:
+  AppleTLSSession(AppleTLSContext* ctx);
+
+  // MUST deallocate all resources
+  virtual ~AppleTLSSession();
+
+  // Initializes SSL/TLS session. The |sockfd| is the underlying
+  // tranport socket. This function returns TLS_ERR_OK if it
+  // succeeds, or TLS_ERR_ERROR.
+  virtual int init(sock_t sockfd);
+
+  // Sets |hostname| for TLS SNI extension. This is only meaningful for
+  // client side session. This function returns TLS_ERR_OK if it
+  // succeeds, or TLS_ERR_ERROR.
+  virtual int setSNIHostname(const std::string& hostname);
+
+  // Closes the SSL/TLS session. Don't close underlying transport
+  // socket. This function returns TLS_ERR_OK if it succeeds, or
+  // TLS_ERR_ERROR.
+  virtual int closeConnection();
+
+  // Returns TLS_WANT_READ if SSL/TLS session needs more data from
+  // remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session
+  // needs to write more data to proceed. If SSL/TLS session needs
+  // neither read nor write data at the moment, return value is
+  // undefined.
+  virtual int checkDirection();
+
+  // Sends |data| with length |len|. This function returns the number
+  // of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the
+  // underlying tranport blocks, or TLS_ERR_ERROR.
+  virtual ssize_t writeData(const void* data, size_t len);
+
+  // Receives data into |data| with length |len|. This function returns
+  // the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK
+  // if the underlying tranport blocks, or TLS_ERR_ERROR.
+  virtual ssize_t readData(void* data, size_t len);
+
+  // Performs client side handshake. The |hostname| is the hostname of
+  // the remote endpoint and is used to verify its certificate. This
+  // function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK
+  // if the underlying transport blocks, or TLS_ERR_ERROR.
+  // When returning TLS_ERR_ERROR, provide certificate validation error
+  // in |handshakeErr|.
+  virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr);
+
+  // Performs server side handshake. This function returns TLS_ERR_OK
+  // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
+  // blocks, or TLS_ERR_ERROR.
+  virtual int tlsAccept();
+
+  // Returns last error string
+  virtual std::string getLastErrorString();
+
+private:
+  static OSStatus SocketWrite(SSLConnectionRef conn, const void* data, size_t* len) {
+    return ((AppleTLSSession*)conn)->sockWrite(data, len);
+  }
+  static OSStatus SocketRead(SSLConnectionRef conn, void* data, size_t* len) {
+    return ((AppleTLSSession*)conn)->sockRead(data, len);
+  }
+
+  AppleTLSContext *ctx_;
+  SSLContextRef sslCtx_;
+  sock_t sockfd_;
+  state_t state_;
+  OSStatus lastError_;
+  size_t writeBuffered_;
+
+  OSStatus sockWrite(const void* data, size_t* len);
+  OSStatus sockRead(void* data, size_t* len);
+};
+
+}
+
+#endif // TLS_SESSION_H

+ 14 - 28
src/LibgnutlsTLSContext.cc

@@ -45,10 +45,15 @@
 
 namespace aria2 {
 
-TLSContext::TLSContext(TLSSessionSide side)
+TLSContext* TLSContext::make(TLSSessionSide side)
+{
+  return new GnuTLSContext(side);
+}
+
+GnuTLSContext::GnuTLSContext(TLSSessionSide side)
   : certCred_(0),
     side_(side),
-    peerVerificationEnabled_(false)
+    verifyPeer_(true)
 {
   int r = gnutls_certificate_allocate_credentials(&certCred_);
   if(r == GNUTLS_E_SUCCESS) {
@@ -63,24 +68,19 @@ TLSContext::TLSContext(TLSSessionSide side)
   }
 }
 
-TLSContext::~TLSContext()
+GnuTLSContext::~GnuTLSContext()
 {
   if(certCred_) {
     gnutls_certificate_free_credentials(certCred_);
   }
 }
 
-bool TLSContext::good() const
+bool GnuTLSContext::good() const
 {
   return good_;
 }
 
-bool TLSContext::bad() const
-{
-  return !good_;
-}
-
-bool TLSContext::addCredentialFile(const std::string& certfile,
+bool GnuTLSContext::addCredentialFile(const std::string& certfile,
                                    const std::string& keyfile)
 {
   int ret = gnutls_certificate_set_x509_key_file(certCred_,
@@ -101,7 +101,7 @@ bool TLSContext::addCredentialFile(const std::string& certfile,
   }
 }
 
-bool TLSContext::addSystemTrustedCACerts()
+bool GnuTLSContext::addSystemTrustedCACerts()
 {
 #ifdef HAVE_GNUTLS_CERTIFICATE_SET_X509_SYSTEM_TRUST
   int ret = gnutls_certificate_set_x509_system_trust(certCred_);
@@ -114,11 +114,12 @@ bool TLSContext::addSystemTrustedCACerts()
     return true;
   }
 #else
+  A2_LOG_WARN("System certificates not supported");
   return false;
 #endif
 }
 
-bool TLSContext::addTrustedCACertFile(const std::string& certfile)
+bool GnuTLSContext::addTrustedCACertFile(const std::string& certfile)
 {
   int ret = gnutls_certificate_set_x509_trust_file(certCred_,
                                                    certfile.c_str(),
@@ -133,24 +134,9 @@ bool TLSContext::addTrustedCACertFile(const std::string& certfile)
   }
 }
 
-gnutls_certificate_credentials_t TLSContext::getCertCred() const
+gnutls_certificate_credentials_t GnuTLSContext::getCertCred() const
 {
   return certCred_;
 }
 
-void TLSContext::enablePeerVerification()
-{
-  peerVerificationEnabled_ = true;
-}
-
-void TLSContext::disablePeerVerification()
-{
-  peerVerificationEnabled_ = false;
-}
-
-bool TLSContext::peerVerificationEnabled() const
-{
-  return peerVerificationEnabled_;
-}
-
 } // namespace aria2

+ 21 - 27
src/LibgnutlsTLSContext.h

@@ -37,8 +37,6 @@
 
 #include "common.h"
 
-#include <string>
-
 #include <gnutls/gnutls.h>
 
 #include "TLSContext.h"
@@ -46,45 +44,41 @@
 
 namespace aria2 {
 
-class TLSContext {
-private:
-  gnutls_certificate_credentials_t certCred_;
-
-  TLSSessionSide side_;
-
-  bool good_;
-
-  bool peerVerificationEnabled_;
+class GnuTLSContext : public TLSContext {
 public:
-  TLSContext(TLSSessionSide side);
+  GnuTLSContext(TLSSessionSide side);
 
-  ~TLSContext();
+  virtual ~GnuTLSContext();
 
   // private key `keyfile' must be decrypted.
-  bool addCredentialFile(const std::string& certfile,
-                         const std::string& keyfile);
+  virtual bool addCredentialFile(const std::string& certfile,
+                                 const std::string& keyfile);
 
-  bool addSystemTrustedCACerts();
+  virtual bool addSystemTrustedCACerts();
 
   // certfile can contain multiple certificates.
-  bool addTrustedCACertFile(const std::string& certfile);
-
-  bool good() const;
+  virtual bool addTrustedCACertFile(const std::string& certfile);
 
-  bool bad() const;
+  virtual bool good() const;
 
-  gnutls_certificate_credentials_t getCertCred() const;
-
-  TLSSessionSide getSide() const
-  {
+  virtual TLSSessionSide getSide() const {
     return side_;
   }
 
-  void enablePeerVerification();
+  virtual bool getVerifyPeer() const {
+    return verifyPeer_;
+  }
+  virtual void setVerifyPeer(bool verify) {
+    verifyPeer_ = verify;
+  }
 
-  void disablePeerVerification();
+  gnutls_certificate_credentials_t getCertCred() const;
 
-  bool peerVerificationEnabled() const;
+private:
+  gnutls_certificate_credentials_t certCred_;
+  TLSSessionSide side_;
+  bool good_;
+  bool verifyPeer_;
 };
 
 } // namespace aria2

+ 17 - 12
src/LibgnutlsTLSSession.cc

@@ -42,20 +42,25 @@
 
 namespace aria2 {
 
-TLSSession::TLSSession(TLSContext* tlsContext)
+TLSSession* TLSSession::make(TLSContext* ctx)
+{
+  return new GnuTLSSession(static_cast<GnuTLSContext*>(ctx));
+}
+
+GnuTLSSession::GnuTLSSession(GnuTLSContext* tlsContext)
   : sslSession_(0),
     tlsContext_(tlsContext),
     rv_(0)
 {}
 
-TLSSession::~TLSSession()
+GnuTLSSession::~GnuTLSSession()
 {
   if(sslSession_) {
     gnutls_deinit(sslSession_);
   }
 }
 
-int TLSSession::init(sock_t sockfd)
+int GnuTLSSession::init(sock_t sockfd)
 {
   rv_ = gnutls_init(&sslSession_,
                     tlsContext_->getSide() == TLS_CLIENT ?
@@ -89,7 +94,7 @@ int TLSSession::init(sock_t sockfd)
   return TLS_ERR_OK;
 }
 
-int TLSSession::setSNIHostname(const std::string& hostname)
+int GnuTLSSession::setSNIHostname(const std::string& hostname)
 {
   // TLS extensions: SNI
   rv_ = gnutls_server_name_set(sslSession_, GNUTLS_NAME_DNS,
@@ -100,7 +105,7 @@ int TLSSession::setSNIHostname(const std::string& hostname)
   return TLS_ERR_OK;
 }
 
-int TLSSession::closeConnection()
+int GnuTLSSession::closeConnection()
 {
   rv_ = gnutls_bye(sslSession_, GNUTLS_SHUT_WR);
   if(rv_ == GNUTLS_E_SUCCESS) {
@@ -112,13 +117,13 @@ int TLSSession::closeConnection()
   }
 }
 
-int TLSSession::checkDirection()
+int GnuTLSSession::checkDirection()
 {
   int direction = gnutls_record_get_direction(sslSession_);
   return direction == 0 ? TLS_WANT_READ : TLS_WANT_WRITE;
 }
 
-ssize_t TLSSession::writeData(const void* data, size_t len)
+ssize_t GnuTLSSession::writeData(const void* data, size_t len)
 {
   while((rv_ = gnutls_record_send(sslSession_, data, len)) ==
         GNUTLS_E_INTERRUPTED);
@@ -133,7 +138,7 @@ ssize_t TLSSession::writeData(const void* data, size_t len)
   }
 }
 
-ssize_t TLSSession::readData(void* data, size_t len)
+ssize_t GnuTLSSession::readData(void* data, size_t len)
 {
   while((rv_ = gnutls_record_recv(sslSession_, data, len)) ==
         GNUTLS_E_INTERRUPTED);
@@ -148,7 +153,7 @@ ssize_t TLSSession::readData(void* data, size_t len)
   }
 }
 
-int TLSSession::tlsConnect(const std::string& hostname,
+int GnuTLSSession::tlsConnect(const std::string& hostname,
                            std::string& handshakeErr)
 {
   handshakeErr = "";
@@ -160,7 +165,7 @@ int TLSSession::tlsConnect(const std::string& hostname,
       return TLS_ERR_ERROR;
     }
   }
-  if(tlsContext_->peerVerificationEnabled()) {
+  if(tlsContext_->getVerifyPeer()) {
     // verify peer
     unsigned int status;
     rv_ = gnutls_certificate_verify_peers2(sslSession_, &status);
@@ -246,7 +251,7 @@ int TLSSession::tlsConnect(const std::string& hostname,
   return TLS_ERR_OK;
 }
 
-int TLSSession::tlsAccept()
+int GnuTLSSession::tlsAccept()
 {
   rv_ = gnutls_handshake(sslSession_);
   if(rv_ == GNUTLS_E_SUCCESS) {
@@ -258,7 +263,7 @@ int TLSSession::tlsAccept()
   }
 }
 
-std::string TLSSession::getLastErrorString()
+std::string GnuTLSSession::getLastErrorString()
 {
   return gnutls_strerror(rv_);
 }

+ 15 - 18
src/LibgnutlsTLSSession.h

@@ -39,31 +39,28 @@
 
 #include <gnutls/gnutls.h>
 
-#include <string>
-
-#include "TLSSessionConst.h"
+#include "LibgnutlsTLSContext.h"
+#include "TLSSession.h"
 #include "a2netcompat.h"
 
 namespace aria2 {
 
-class TLSContext;
-
-class TLSSession {
+class GnuTLSSession : public TLSSession {
 public:
-  TLSSession(TLSContext* tlsContext);
-  ~TLSSession();
-  int init(sock_t sockfd);
-  int setSNIHostname(const std::string& hostname);
-  int closeConnection();
-  int checkDirection();
-  ssize_t writeData(const void* data, size_t len);
-  ssize_t readData(void* data, size_t len);
-  int tlsConnect(const std::string& hostname, std::string& handshakeErr);
-  int tlsAccept();
-  std::string getLastErrorString();
+  GnuTLSSession(GnuTLSContext* tlsContext);
+  ~GnuTLSSession();
+  virtual int init(sock_t sockfd);
+  virtual int setSNIHostname(const std::string& hostname);
+  virtual int closeConnection();
+  virtual int checkDirection();
+  virtual ssize_t writeData(const void* data, size_t len);
+  virtual ssize_t readData(void* data, size_t len);
+  virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr);
+  virtual int tlsAccept();
+  virtual std::string getLastErrorString();
 private:
   gnutls_session_t sslSession_;
-  TLSContext* tlsContext_;
+  GnuTLSContext* tlsContext_;
   // Last error code from gnutls library functions
   int rv_;
 };

+ 12 - 22
src/LibsslTLSContext.cc

@@ -43,10 +43,15 @@
 
 namespace aria2 {
 
-TLSContext::TLSContext(TLSSessionSide side)
+TLSContext* TLSContext::make(TLSSessionSide side)
+{
+  return new OpenSSLTLSContext(side);
+}
+
+OpenSSLTLSContext::OpenSSLTLSContext(TLSSessionSide side)
   : sslCtx_(0),
     side_(side),
-    peerVerificationEnabled_(false)
+    verifyPeer_(true)
 {
   sslCtx_ = SSL_CTX_new(SSLv23_method());
   if(sslCtx_) {
@@ -70,22 +75,17 @@ TLSContext::TLSContext(TLSSessionSide side)
   #endif
 }
 
-TLSContext::~TLSContext()
+OpenSSLTLSContext::~OpenSSLTLSContext()
 {
   SSL_CTX_free(sslCtx_);
 }
 
-bool TLSContext::good() const
+bool OpenSSLTLSContext::good() const
 {
   return good_;
 }
 
-bool TLSContext::bad() const
-{
-  return !good_;
-}
-
-bool TLSContext::addCredentialFile(const std::string& certfile,
+bool OpenSSLTLSContext::addCredentialFile(const std::string& certfile,
                                    const std::string& keyfile)
 {
   if(SSL_CTX_use_PrivateKey_file(sslCtx_, keyfile.c_str(),
@@ -107,7 +107,7 @@ bool TLSContext::addCredentialFile(const std::string& certfile,
   return true;
 }
 
-bool TLSContext::addSystemTrustedCACerts()
+bool OpenSSLTLSContext::addSystemTrustedCACerts()
 {
   if(SSL_CTX_set_default_verify_paths(sslCtx_) != 1) {
     A2_LOG_INFO(fmt(MSG_LOADING_SYSTEM_TRUSTED_CA_CERTS_FAILED,
@@ -119,7 +119,7 @@ bool TLSContext::addSystemTrustedCACerts()
   }
 }
 
-bool TLSContext::addTrustedCACertFile(const std::string& certfile)
+bool OpenSSLTLSContext::addTrustedCACertFile(const std::string& certfile)
 {
   if(SSL_CTX_load_verify_locations(sslCtx_, certfile.c_str(), 0) != 1) {
     A2_LOG_ERROR(fmt(MSG_LOADING_TRUSTED_CA_CERT_FAILED,
@@ -132,14 +132,4 @@ bool TLSContext::addTrustedCACertFile(const std::string& certfile)
   }
 }
 
-void TLSContext::enablePeerVerification()
-{
-  peerVerificationEnabled_ = true;
-}
-
-void TLSContext::disablePeerVerification()
-{
-  peerVerificationEnabled_ = false;
-}
-
 } // namespace aria2

+ 22 - 31
src/LibsslTLSContext.h

@@ -46,52 +46,43 @@
 
 namespace aria2 {
 
-class TLSContext {
-private:
-  SSL_CTX* sslCtx_;
-
-  TLSSessionSide side_;
-
-  bool good_;
-
-  bool peerVerificationEnabled_;
+class OpenSSLTLSContext : public TLSContext {
 public:
-  TLSContext(TLSSessionSide side);
+  OpenSSLTLSContext(TLSSessionSide side);
 
-  ~TLSContext();
+  ~OpenSSLTLSContext();
 
   // private key `keyfile' must be decrypted.
-  bool addCredentialFile(const std::string& certfile,
-                         const std::string& keyfile);
+  virtual bool addCredentialFile(const std::string& certfile,
+                                 const std::string& keyfile);
 
-  bool addSystemTrustedCACerts();
+  virtual bool addSystemTrustedCACerts();
 
   // certfile can contain multiple certificates.
-  bool addTrustedCACertFile(const std::string& certfile);
-
-  bool good() const;
+  virtual bool addTrustedCACertFile(const std::string& certfile);
 
-  bool bad() const;
+  virtual bool good() const;
 
-  SSL_CTX* getSSLCtx() const
-  {
-    return sslCtx_;
-  }
-
-  TLSSessionSide getSide() const
-  {
+  virtual TLSSessionSide getSide() const {
     return side_;
   }
 
-  void enablePeerVerification();
-
-  void disablePeerVerification();
+  virtual bool getVerifyPeer() const {
+    return verifyPeer_;
+  }
+  virtual void setVerifyPeer(bool verify) {
+    verifyPeer_ = verify;
+  }
 
-  bool peerVerificationEnabled() const
-  {
-    return peerVerificationEnabled_;
+  SSL_CTX* getSSLCtx() const {
+    return sslCtx_;
   }
 
+private:
+  SSL_CTX* sslCtx_;
+  TLSSessionSide side_;
+  bool good_;
+  bool verifyPeer_;
 };
 
 } // namespace aria2

+ 19 - 14
src/LibsslTLSSession.cc

@@ -38,26 +38,31 @@
 #include <openssl/x509.h>
 #include <openssl/x509v3.h>
 
-#include "TLSContext.h"
+#include "LogFactory.h"
 #include "util.h"
 #include "SocketCore.h"
 
 namespace aria2 {
 
-TLSSession::TLSSession(TLSContext* tlsContext)
+TLSSession* TLSSession::make(TLSContext* ctx)
+{
+  return new OpenSSLTLSSession(static_cast<OpenSSLTLSContext*>(ctx));
+}
+
+OpenSSLTLSSession::OpenSSLTLSSession(OpenSSLTLSContext* tlsContext)
   : ssl_(0),
     tlsContext_(tlsContext),
     rv_(1)
 {}
 
-TLSSession::~TLSSession()
+OpenSSLTLSSession::~OpenSSLTLSSession()
 {
   if(ssl_) {
     SSL_shutdown(ssl_);
   }
 }
 
-int TLSSession::init(sock_t sockfd)
+int OpenSSLTLSSession::init(sock_t sockfd)
 {
   ERR_clear_error();
   ssl_ = SSL_new(tlsContext_->getSSLCtx());
@@ -71,7 +76,7 @@ int TLSSession::init(sock_t sockfd)
   return TLS_ERR_OK;
 }
 
-int TLSSession::setSNIHostname(const std::string& hostname)
+int OpenSSLTLSSession::setSNIHostname(const std::string& hostname)
 {
 #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME
   ERR_clear_error();
@@ -83,7 +88,7 @@ int TLSSession::setSNIHostname(const std::string& hostname)
   return TLS_ERR_OK;
 }
 
-int TLSSession::closeConnection()
+int OpenSSLTLSSession::closeConnection()
 {
   ERR_clear_error();
   SSL_shutdown(ssl_);
@@ -91,7 +96,7 @@ int TLSSession::closeConnection()
   return TLS_ERR_OK;
 }
 
-int TLSSession::checkDirection()
+int OpenSSLTLSSession::checkDirection()
 {
   int error = SSL_get_error(ssl_, rv_);
   if(error == SSL_ERROR_WANT_WRITE) {
@@ -110,7 +115,7 @@ bool wouldblock(SSL* ssl, int rv)
 }
 } // namespace
 
-ssize_t TLSSession::writeData(const void* data, size_t len)
+ssize_t OpenSSLTLSSession::writeData(const void* data, size_t len)
 {
   ERR_clear_error();
   rv_ = SSL_write(ssl_, data, len);
@@ -127,7 +132,7 @@ ssize_t TLSSession::writeData(const void* data, size_t len)
   }
 }
 
-ssize_t TLSSession::readData(void* data, size_t len)
+ssize_t OpenSSLTLSSession::readData(void* data, size_t len)
 {
   ERR_clear_error();
   rv_ = SSL_read(ssl_, data, len);
@@ -144,7 +149,7 @@ ssize_t TLSSession::readData(void* data, size_t len)
   }
 }
 
-int TLSSession::handshake()
+int OpenSSLTLSSession::handshake()
 {
   ERR_clear_error();
   if(tlsContext_->getSide() == TLS_CLIENT) {
@@ -171,7 +176,7 @@ int TLSSession::handshake()
   return TLS_ERR_OK;
 }
 
-int TLSSession::tlsConnect(const std::string& hostname,
+int OpenSSLTLSSession::tlsConnect(const std::string& hostname,
                            std::string& handshakeErr)
 {
   handshakeErr = "";
@@ -181,7 +186,7 @@ int TLSSession::tlsConnect(const std::string& hostname,
     return ret;
   }
   if(tlsContext_->getSide() == TLS_CLIENT &&
-     tlsContext_->peerVerificationEnabled()) {
+     tlsContext_->getVerifyPeer()) {
     // verify peer
     X509* peerCert = SSL_get_peer_certificate(ssl_);
     if(!peerCert) {
@@ -256,12 +261,12 @@ int TLSSession::tlsConnect(const std::string& hostname,
   return TLS_ERR_OK;
 }
 
-int TLSSession::tlsAccept()
+int OpenSSLTLSSession::tlsAccept()
 {
   return handshake();
 }
 
-std::string TLSSession::getLastErrorString()
+std::string OpenSSLTLSSession::getLastErrorString()
 {
   if(rv_ <= 0) {
     int sslError = SSL_get_error(ssl_, rv_);

+ 15 - 18
src/LibsslTLSSession.h

@@ -39,32 +39,29 @@
 
 #include <openssl/ssl.h>
 
-#include <string>
-
-#include "TLSSessionConst.h"
+#include "LibsslTLSContext.h"
+#include "TLSSession.h"
 #include "a2netcompat.h"
 
 namespace aria2 {
 
-class TLSContext;
-
-class TLSSession {
+class OpenSSLTLSSession : public TLSSession {
 public:
-  TLSSession(TLSContext* tlsContext);
-  ~TLSSession();
-  int init(sock_t sockfd);
-  int setSNIHostname(const std::string& hostname);
-  int closeConnection();
-  int checkDirection();
-  ssize_t writeData(const void* data, size_t len);
-  ssize_t readData(void* data, size_t len);
-  int tlsConnect(const std::string& hostname, std::string& handshakeErr);
-  int tlsAccept();
-  std::string getLastErrorString();
+  OpenSSLTLSSession(OpenSSLTLSContext* tlsContext);
+  virtual ~OpenSSLTLSSession();
+  virtual int init(sock_t sockfd);
+  virtual int setSNIHostname(const std::string& hostname);
+  virtual int closeConnection();
+  virtual int checkDirection();
+  virtual ssize_t writeData(const void* data, size_t len);
+  virtual ssize_t readData(void* data, size_t len);
+  virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr);
+  virtual int tlsAccept();
+  virtual std::string getLastErrorString();
 private:
   int handshake();
   SSL* ssl_;
-  TLSContext* tlsContext_;
+  OpenSSLTLSContext* tlsContext_;
   // Last error code from openSSL library functions
   int rv_;
 };

+ 32 - 17
src/Makefile.am

@@ -299,38 +299,53 @@ SRCS += EpollEventPoll.cc EpollEventPoll.h
 endif # HAVE_EPOLL
 
 if ENABLE_SSL
-SRCS += TLSContext.h\
-	TLSSession.h\
-	TLSSessionConst.h
+SRCS += TLSSession.h TLSSessionConst.h
 endif # ENABLE_SSL
 
+if USE_APPLE_MD
+SRCS += AppleMessageDigestImpl.cc AppleMessageDigestImpl.h
+endif
+
+if HAVE_APPLETLS
+SRCS += AppleTLSContext.cc AppleTLSContext.h \
+        AppleTLSSession.cc AppleTLSSession.h
+endif
+
 if HAVE_LIBGNUTLS
-SRCS += LibgnutlsTLSContext.cc LibgnutlsTLSContext.h\
-	LibgnutlsTLSSession.cc LibgnutlsTLSSession.h
+SRCS += LibgnutlsTLSContext.cc LibgnutlsTLSContext.h \
+        LibgnutlsTLSSession.cc LibgnutlsTLSSession.h
 endif # HAVE_LIBGNUTLS
 
 if HAVE_LIBGCRYPT
-SRCS += LibgcryptMessageDigestImpl.cc LibgcryptMessageDigestImpl.h\
-	LibgcryptARC4Encryptor.cc LibgcryptARC4Encryptor.h\
-	LibgcryptDHKeyExchange.cc LibgcryptDHKeyExchange.h
+SRCS += LibgcryptARC4Encryptor.cc LibgcryptARC4Encryptor.h \
+        LibgcryptDHKeyExchange.cc LibgcryptDHKeyExchange.h
+if USE_LIBGCRYPT_MD
+SRCS += LibgcryptMessageDigestImpl.cc LibgcryptMessageDigestImpl.h
+endif
 endif # HAVE_LIBGCRYPT
 
 if HAVE_LIBNETTLE
-SRCS += LibnettleMessageDigestImpl.cc LibnettleMessageDigestImpl.h\
-	LibnettleARC4Encryptor.cc LibnettleARC4Encryptor.h
+SRCS += LibnettleARC4Encryptor.cc LibnettleARC4Encryptor.h
+if USE_LIBNETTLE_MD
+SRCS += LibnettleMessageDigestImpl.cc LibnettleMessageDigestImpl.h
+endif
 endif # HAVE_LIBNETTLE
 
 if HAVE_LIBGMP
-SRCS += a2gmp.cc a2gmp.h\
-	LibgmpDHKeyExchange.cc LibgmpDHKeyExchange.h
+SRCS += a2gmp.cc a2gmp.h \
+        LibgmpDHKeyExchange.cc LibgmpDHKeyExchange.h
 endif # HAVE_LIBGMP
 
 if HAVE_OPENSSL
-SRCS += LibsslTLSContext.cc LibsslTLSContext.h\
-	LibsslTLSSession.cc LibsslTLSSession.h\
-	LibsslMessageDigestImpl.cc LibsslMessageDigestImpl.h\
-	LibsslARC4Encryptor.cc LibsslARC4Encryptor.h\
-	LibsslDHKeyExchange.cc LibsslDHKeyExchange.h
+SRCS += LibsslARC4Encryptor.cc LibsslARC4Encryptor.h \
+        LibsslDHKeyExchange.cc LibsslDHKeyExchange.h
+if !HAVE_APPLETLS
+SRCS += LibsslTLSContext.cc LibsslTLSContext.h \
+        LibsslTLSSession.cc LibsslTLSSession.h
+endif
+if USE_OPENSSL_MD
+SRCS += LibsslMessageDigestImpl.cc LibsslMessageDigestImpl.h
+endif
 endif # HAVE_OPENSSL
 
 if HAVE_ZLIB

+ 7 - 4
src/MessageDigestImpl.h

@@ -35,12 +35,15 @@
 #ifndef D_MESSAGE_DIGEST_IMPL_H
 #define D_MESSAGE_DIGEST_IMPL_H
 
-#ifdef HAVE_LIBNETTLE
+
+#ifdef USE_APPLE_MD
+# include "AppleMessageDigestImpl.h"
+#elif defined(USE_LIBNETTLE_MD)
 # include "LibnettleMessageDigestImpl.h"
-#elif HAVE_LIBGCRYPT
+#elif defined(USE_LIBGCRYPT_MD)
 # include "LibgcryptMessageDigestImpl.h"
-#elif HAVE_OPENSSL
+#elif defined(USE_OPENSSL_MD)
 # include "LibsslMessageDigestImpl.h"
-#endif // HAVE_OPENSSL
+#endif
 
 #endif // D_MESSAGE_DIGEST_IMPL_H

+ 3 - 5
src/MultiUrlRequestInfo.cc

@@ -145,7 +145,7 @@ error_code::Value MultiUrlRequestInfo::execute()
          !option_->blank(PREF_RPC_PRIVATE_KEY)) {
         // We set server TLS context to the SocketCore before creating
         // DownloadEngine instance.
-        SharedHandle<TLSContext> svTlsContext(new TLSContext(TLS_SERVER));
+        SharedHandle<TLSContext> svTlsContext(TLSContext::make(TLS_SERVER));
         svTlsContext->addCredentialFile(option_->get(PREF_RPC_CERTIFICATE),
                                         option_->get(PREF_RPC_PRIVATE_KEY));
         SocketCore::setServerTLSContext(svTlsContext);
@@ -194,7 +194,7 @@ error_code::Value MultiUrlRequestInfo::execute()
     e->setAuthConfigFactory(authConfigFactory);
 
 #ifdef ENABLE_SSL
-    SharedHandle<TLSContext> clTlsContext(new TLSContext(TLS_CLIENT));
+    SharedHandle<TLSContext> clTlsContext(TLSContext::make(TLS_CLIENT));
     if(!option_->blank(PREF_CERTIFICATE) &&
        !option_->blank(PREF_PRIVATE_KEY)) {
       clTlsContext->addCredentialFile(option_->get(PREF_CERTIFICATE),
@@ -211,9 +211,7 @@ error_code::Value MultiUrlRequestInfo::execute()
         A2_LOG_INFO(MSG_WARN_NO_CA_CERT);
       }
     }
-    if(option_->getAsBool(PREF_CHECK_CERTIFICATE)) {
-      clTlsContext->enablePeerVerification();
-    }
+    clTlsContext->setVerifyPeer(option_->getAsBool(PREF_CHECK_CERTIFICATE));
     SocketCore::setClientTLSContext(clTlsContext);
 #endif
 #ifdef HAVE_ARES_ADDR_NODE

+ 1 - 1
src/SocketCore.cc

@@ -819,7 +819,7 @@ bool SocketCore::tlsHandshake(TLSContext* tlsctx, const std::string& hostname)
   wantWrite_ = false;
   switch(secure_) {
   case A2_TLS_NONE:
-    tlsSession_.reset(new TLSSession(tlsctx));
+    tlsSession_.reset(TLSSession::make(tlsctx));
     rv = tlsSession_->init(sockfd_);
     if(rv != TLS_ERR_OK) {
       std::string error = tlsSession_->getLastErrorString();

+ 23 - 6
src/TLSContext.h

@@ -35,6 +35,8 @@
 #ifndef D_TLS_CONTEXT_H
 #define D_TLS_CONTEXT_H
 
+#include <string>
+
 #include "common.h"
 
 namespace aria2 {
@@ -44,12 +46,27 @@ enum TLSSessionSide {
   TLS_SERVER
 };
 
-} // namespace aria2
+class TLSContext {
+public:
+  static TLSContext* make(TLSSessionSide side);
+  virtual ~TLSContext() {}
+
+  // private key `keyfile' must be decrypted.
+  virtual bool addCredentialFile(const std::string& certfile,
+                                 const std::string& keyfile) = 0;
+
+  virtual bool addSystemTrustedCACerts() = 0;
+
+  // certfile can contain multiple certificates.
+  virtual bool addTrustedCACertFile(const std::string& certfile) = 0;
 
-#ifdef HAVE_OPENSSL
-# include "LibsslTLSContext.h"
-#elif HAVE_LIBGNUTLS
-# include "LibgnutlsTLSContext.h"
-#endif // HAVE_LIBGNUTLS
+  virtual bool good() const = 0;
+
+  virtual TLSSessionSide getSide() const = 0;
+  virtual bool getVerifyPeer() const = 0;
+  virtual void setVerifyPeer(bool) = 0;
+};
+
+} // namespace aria2
 
 #endif // D_TLS_CONTEXT_H

+ 78 - 61
src/TLSSession.h

@@ -36,69 +36,86 @@
 #define TLS_SESSION_H
 
 #include "common.h"
+#include "a2netcompat.h"
+#include "TLSContext.h"
+
+namespace aria2 {
+
+enum TLSDirection {
+  TLS_WANT_READ = 1,
+  TLS_WANT_WRITE
+};
+
+enum TLSErrorCode {
+  TLS_ERR_OK = 0,
+  TLS_ERR_ERROR = -1,
+  TLS_ERR_WOULDBLOCK = -2
+};
 
 // To create another SSL/TLS backend, implement TLSSession class below.
 //
-// class TLSSession {
-// public:
-//   TLSSession(TLSContext* tlsContext);
-//
-//   // MUST deallocate all resources
-//   ~TLSSession();
-//
-//   // Initializes SSL/TLS session. The |sockfd| is the underlying
-//   // tranport socket. This function returns TLS_ERR_OK if it
-//   // succeeds, or TLS_ERR_ERROR.
-//   int init(sock_t sockfd);
-//
-//   // Sets |hostname| for TLS SNI extension. This is only meaningful for
-//   // client side session. This function returns TLS_ERR_OK if it
-//   // succeeds, or TLS_ERR_ERROR.
-//   int setSNIHostname(const std::string& hostname);
-//
-//   // Closes the SSL/TLS session. Don't close underlying transport
-//   // socket. This function returns TLS_ERR_OK if it succeeds, or
-//   // TLS_ERR_ERROR.
-//   int closeConnection();
-//
-//   // Returns TLS_WANT_READ if SSL/TLS session needs more data from
-//   // remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session
-//   // needs to write more data to proceed. If SSL/TLS session needs
-//   // neither read nor write data at the moment, return value is
-//   // undefined.
-//   int checkDirection();
-//
-//   // Sends |data| with length |len|. This function returns the number
-//   // of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the
-//   // underlying tranport blocks, or TLS_ERR_ERROR.
-//   ssize_t writeData(const void* data, size_t len);
-//
-//   // Receives data into |data| with length |len|. This function returns
-//   // the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK
-//   // if the underlying tranport blocks, or TLS_ERR_ERROR.
-//   ssize_t readData(void* data, size_t len);
-//
-//   // Performs client side handshake. The |hostname| is the hostname of
-//   // the remote endpoint and is used to verify its certificate. This
-//   // function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK
-//   // if the underlying transport blocks, or TLS_ERR_ERROR.
-//   // When returning TLS_ERR_ERROR, provide certificate validation error
-//   // in |handshakeErr|.
-//   int tlsConnect(const std::string& hostname, std::string& handshakeErr);
-//
-//   // Performs server side handshake. This function returns TLS_ERR_OK
-//   // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
-//   // blocks, or TLS_ERR_ERROR.
-//   int tlsAccept();
-//
-//   // Returns last error string
-//   std::string getLastErrorString();
-// };
-
-#ifdef HAVE_OPENSSL
-# include "LibsslTLSSession.h"
-#elif defined HAVE_LIBGNUTLS
-# include "LibgnutlsTLSSession.h"
-#endif
+class TLSSession {
+public:
+  static TLSSession* make(TLSContext* ctx);
+
+  // MUST deallocate all resources
+  virtual ~TLSSession() {}
+
+  // Initializes SSL/TLS session. The |sockfd| is the underlying
+  // tranport socket. This function returns TLS_ERR_OK if it
+  // succeeds, or TLS_ERR_ERROR.
+  virtual int init(sock_t sockfd) = 0;
+
+  // Sets |hostname| for TLS SNI extension. This is only meaningful for
+  // client side session. This function returns TLS_ERR_OK if it
+  // succeeds, or TLS_ERR_ERROR.
+  virtual int setSNIHostname(const std::string& hostname) = 0;
+
+  // Closes the SSL/TLS session. Don't close underlying transport
+  // socket. This function returns TLS_ERR_OK if it succeeds, or
+  // TLS_ERR_ERROR.
+  virtual int closeConnection() = 0;
+
+  // Returns TLS_WANT_READ if SSL/TLS session needs more data from
+  // remote endpoint to proceed, or TLS_WANT_WRITE if SSL/TLS session
+  // needs to write more data to proceed. If SSL/TLS session needs
+  // neither read nor write data at the moment, return value is
+  // undefined.
+  virtual int checkDirection() = 0;
+
+  // Sends |data| with length |len|. This function returns the number
+  // of bytes sent if it succeeds, or TLS_ERR_WOULDBLOCK if the
+  // underlying tranport blocks, or TLS_ERR_ERROR.
+  virtual ssize_t writeData(const void* data, size_t len) = 0;
+
+  // Receives data into |data| with length |len|. This function returns
+  // the number of bytes received if it succeeds, or TLS_ERR_WOULDBLOCK
+  // if the underlying tranport blocks, or TLS_ERR_ERROR.
+  virtual ssize_t readData(void* data, size_t len) = 0;
+
+  // Performs client side handshake. The |hostname| is the hostname of
+  // the remote endpoint and is used to verify its certificate. This
+  // function returns TLS_ERR_OK if it succeeds, or TLS_ERR_WOULDBLOCK
+  // if the underlying transport blocks, or TLS_ERR_ERROR.
+  // When returning TLS_ERR_ERROR, provide certificate validation error
+  // in |handshakeErr|.
+  virtual int tlsConnect(const std::string& hostname, std::string& handshakeErr) = 0;
+
+  // Performs server side handshake. This function returns TLS_ERR_OK
+  // if it succeeds, or TLS_ERR_WOULDBLOCK if the underlying transport
+  // blocks, or TLS_ERR_ERROR.
+  virtual int tlsAccept() = 0;
+
+  // Returns last error string
+  virtual std::string getLastErrorString() = 0;
+
+protected:
+  TLSSession() {}
+private:
+  TLSSession(const TLSSession&);
+  TLSSession& operator=(const TLSSession&);
+};
+
+}
 
 #endif // TLS_SESSION_H