WinTLSSession.cc 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886
  1. /* <!-- copyright */
  2. /*
  3. * aria2 - The high speed download utility
  4. *
  5. * Copyright (C) 2013 Nils Maier
  6. *
  7. * This program is free software; you can redistribute it and/or modify
  8. * it under the terms of the GNU General Public License as published by
  9. * the Free Software Foundation; either version 2 of the License, or
  10. * (at your option) any later version.
  11. *
  12. * This program is distributed in the hope that it will be useful,
  13. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  14. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  15. * GNU General Public License for more details.
  16. *
  17. * You should have received a copy of the GNU General Public License
  18. * along with this program; if not, write to the Free Software
  19. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  20. *
  21. * In addition, as a special exception, the copyright holders give
  22. * permission to link the code of portions of this program with the
  23. * OpenSSL library under certain conditions as described in each
  24. * individual source file, and distribute linked combinations
  25. * including the two.
  26. * You must obey the GNU General Public License in all respects
  27. * for all of the code used other than OpenSSL. If you modify
  28. * file(s) with this exception, you may extend this exception to your
  29. * version of the file(s), but you are not obligated to do so. If you
  30. * do not wish to do so, delete this exception statement from your
  31. * version. If you delete this exception statement from all source
  32. * files in the program, then also delete it here.
  33. */
  34. /* copyright --> */
  #include "WinTLSSession.h"

  #include <algorithm>
  #include <cassert>
  #include <cstring>
  #include <limits>
  #include <sstream>

  #include "LogFactory.h"
  #include "a2functional.h"
  #include "fmt.h"
  #include "util.h"
  // Fallbacks for Schannel constants that some SDK / MinGW headers do not
  // provide. The values are intended to mirror the Windows SDK definitions.
  #if !defined(SECBUFFER_ALERT)
  #  define SECBUFFER_ALERT 17
  #endif

  // Size of the WCHAR name fields in SecPkgContext_CipherInfo.
  #if !defined(SZ_ALG_MAX_SIZE)
  #  define SZ_ALG_MAX_SIZE 64
  #endif

  // Structure version passed in dwVersion when querying cipher info.
  #if !defined(SECPKGCONTEXT_CIPHERINFO_V1)
  #  define SECPKGCONTEXT_CIPHERINFO_V1 1
  #endif

  // Attribute id for QueryContextAttributes(..., SECPKG_ATTR_CIPHER_INFO, ...).
  #if !defined(SECPKG_ATTR_CIPHER_INFO)
  #  define SECPKG_ATTR_CIPHER_INFO 0x64
  #endif
  54. namespace {
  55. using namespace aria2;
  56. struct WinSecPkgContext_CipherInfo
  57. {
  58. DWORD dwVersion;
  59. DWORD dwProtocol;
  60. DWORD dwCipherSuite;
  61. DWORD dwBaseCipherSuite;
  62. WCHAR szCipherSuite[SZ_ALG_MAX_SIZE];
  63. WCHAR szCipher[SZ_ALG_MAX_SIZE];
  64. DWORD dwCipherLen;
  65. DWORD dwCipherBlockLen; // in bytes
  66. WCHAR szHash[SZ_ALG_MAX_SIZE];
  67. DWORD dwHashLen;
  68. WCHAR szExchange[SZ_ALG_MAX_SIZE];
  69. DWORD dwMinExchangeLen;
  70. DWORD dwMaxExchangeLen;
  71. WCHAR szCertificate[SZ_ALG_MAX_SIZE];
  72. DWORD dwKeyType;
  73. };
  74. static const ULONG kReqFlags =
  75. ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY |
  76. ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_USE_SUPPLIED_CREDS | ISC_REQ_STREAM;
  77. static const ULONG kReqAFlags =
  78. ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY |
  79. ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM;
  80. class TLSBuffer : public ::SecBuffer
  81. {
  82. public:
  83. explicit TLSBuffer(ULONG type, ULONG size, void* data)
  84. {
  85. cbBuffer = size;
  86. BufferType = type;
  87. pvBuffer = data;
  88. }
  89. };
  90. class TLSBufferDesc : public ::SecBufferDesc
  91. {
  92. public:
  93. explicit TLSBufferDesc(SecBuffer* arr, ULONG buffers)
  94. {
  95. ulVersion = SECBUFFER_VERSION;
  96. cBuffers = buffers;
  97. pBuffers = arr;
  98. }
  99. };
  100. inline static std::string getCipherSuite(CtxtHandle* handle)
  101. {
  102. WinSecPkgContext_CipherInfo info = {SECPKGCONTEXT_CIPHERINFO_V1};
  103. if (QueryContextAttributes(handle, SECPKG_ATTR_CIPHER_INFO, &info) ==
  104. SEC_E_OK) {
  105. return wCharToUtf8(info.szCipherSuite);
  106. }
  107. return "Unknown";
  108. }
  109. inline static uint32_t getProtocolVersion(CtxtHandle* handle)
  110. {
  111. WinSecPkgContext_CipherInfo info = {SECPKGCONTEXT_CIPHERINFO_V1};
  112. if (QueryContextAttributes(handle, SECPKG_ATTR_CIPHER_INFO, &info) ==
  113. SEC_E_OK) {
  114. return info.dwProtocol;
  115. }
  116. // XXX Assume the best?!
  117. return std::numeric_limits<uint32_t>::max();
  118. }
  119. } // namespace
  120. namespace aria2 {
  121. TLSSession* TLSSession::make(TLSContext* ctx)
  122. {
  123. return new WinTLSSession(static_cast<WinTLSContext*>(ctx));
  124. }
  125. WinTLSSession::WinTLSSession(WinTLSContext* ctx)
  126. : sockfd_(0),
  127. side_(ctx->getSide()),
  128. cred_(ctx->getCredHandle()),
  129. writeBuffered_(0),
  130. state_(st_constructed),
  131. status_(SEC_E_OK)
  132. {
  133. memset(&handle_, 0, sizeof(handle_));
  134. }
  135. WinTLSSession::~WinTLSSession()
  136. {
  137. ::DeleteSecurityContext(&handle_);
  138. state_ = st_error;
  139. }
  140. int WinTLSSession::init(sock_t sockfd)
  141. {
  142. if (state_ != st_constructed) {
  143. status_ = SEC_E_INVALID_HANDLE;
  144. return TLS_ERR_ERROR;
  145. }
  146. sockfd_ = sockfd;
  147. state_ = st_initialized;
  148. return TLS_ERR_OK;
  149. }
  150. int WinTLSSession::setSNIHostname(const std::string& hostname)
  151. {
  152. if (state_ != st_initialized) {
  153. status_ = SEC_E_INVALID_HANDLE;
  154. return TLS_ERR_ERROR;
  155. }
  156. hostname_ = hostname;
  157. return TLS_ERR_OK;
  158. }
  159. int WinTLSSession::closeConnection()
  160. {
  161. if (state_ != st_connected && state_ != st_closing) {
  162. if (state_ != st_error) {
  163. status_ = SEC_E_INVALID_HANDLE;
  164. state_ = st_error;
  165. }
  166. A2_LOG_DEBUG("WinTLS: Cannot close connection");
  167. return TLS_ERR_ERROR;
  168. }
  169. if (state_ == st_connected) {
  170. A2_LOG_DEBUG("WinTLS: Closing connection");
  171. state_ = st_closing;
  172. DWORD dwShut = SCHANNEL_SHUTDOWN;
  173. TLSBuffer shut(SECBUFFER_TOKEN, sizeof(dwShut), &dwShut);
  174. TLSBufferDesc shutDesc(&shut, 1);
  175. status_ = ::ApplyControlToken(&handle_, &shutDesc);
  176. if (status_ != SEC_E_OK) {
  177. state_ = st_error;
  178. return TLS_ERR_ERROR;
  179. }
  180. TLSBuffer ctx(SECBUFFER_EMPTY, 0, nullptr);
  181. TLSBufferDesc desc(&ctx, 1);
  182. ULONG flags = 0;
  183. if (side_ == TLS_CLIENT) {
  184. SEC_CHAR* host = hostname_.empty() ?
  185. nullptr :
  186. const_cast<SEC_CHAR*>(hostname_.c_str());
  187. status_ = ::InitializeSecurityContext(cred_,
  188. &handle_,
  189. host,
  190. kReqFlags,
  191. 0,
  192. 0,
  193. nullptr,
  194. 0,
  195. &handle_,
  196. &desc,
  197. &flags,
  198. nullptr);
  199. }
  200. else {
  201. status_ = ::AcceptSecurityContext(cred_,
  202. &handle_,
  203. nullptr,
  204. kReqAFlags,
  205. 0,
  206. &handle_,
  207. &desc,
  208. &flags,
  209. nullptr);
  210. }
  211. if (status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) {
  212. size_t len = ctx.cbBuffer;
  213. ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
  214. ::FreeContextBuffer(ctx.pvBuffer);
  215. if (rv == TLS_ERR_WOULDBLOCK) {
  216. return rv;
  217. }
  218. // Alright data is sent or buffered
  219. if (rv - len != 0) {
  220. return TLS_ERR_WOULDBLOCK;
  221. }
  222. }
  223. }
  224. // Send remaining data.
  225. while (writeBuf_.size()) {
  226. int rv = writeData(nullptr, 0);
  227. if (rv == TLS_ERR_WOULDBLOCK) {
  228. return rv;
  229. }
  230. }
  231. A2_LOG_DEBUG("WinTLS: Closed Connection");
  232. state_ = st_closed;
  233. return TLS_ERR_OK;
  234. }
  235. int WinTLSSession::checkDirection()
  236. {
  237. if (state_ == st_handshake_write || state_ == st_handshake_write_last) {
  238. return TLS_WANT_WRITE;
  239. }
  240. if (state_ == st_handshake_read) {
  241. return TLS_WANT_READ;
  242. }
  243. if (readBuf_.size() || decBuf_.size()) {
  244. return TLS_WANT_READ;
  245. }
  246. if (writeBuf_.size()) {
  247. return TLS_WANT_WRITE;
  248. }
  249. return TLS_WANT_READ;
  250. }
  251. ssize_t WinTLSSession::writeData(const void* data, size_t len)
  252. {
  253. if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
  254. state_ == st_handshake_read) {
  255. // Renegotiating
  256. std::string hn, err;
  257. TLSVersion ver;
  258. auto connect = tlsConnect(hn, ver, err);
  259. if (connect != TLS_ERR_OK) {
  260. return connect;
  261. }
  262. // Continue.
  263. }
  264. if (state_ != st_connected && state_ != st_closing) {
  265. status_ = SEC_E_INVALID_HANDLE;
  266. return TLS_ERR_ERROR;
  267. }
  268. A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
  269. (uint64_t)len,
  270. (uint64_t)writeBuf_.size()));
  271. // Write remaining buffered data, if any.
  272. size_t written = 0;
  273. while (writeBuf_.size()) {
  274. written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
  275. errno = ::WSAGetLastError();
  276. if (written < 0 && errno == WSAEINTR) {
  277. continue;
  278. }
  279. if (written < 0 && errno == WSAEWOULDBLOCK) {
  280. return TLS_ERR_WOULDBLOCK;
  281. }
  282. if (written == 0) {
  283. return written;
  284. }
  285. if (written < 0) {
  286. status_ = SEC_E_INVALID_HANDLE;
  287. state_ = st_error;
  288. return TLS_ERR_ERROR;
  289. }
  290. writeBuf_.eat(written);
  291. }
  292. if (len == 0) {
  293. return 0;
  294. }
  295. if (!streamSizes_) {
  296. streamSizes_.reset(new SecPkgContext_StreamSizes());
  297. status_ = ::QueryContextAttributes(
  298. &handle_, SECPKG_ATTR_STREAM_SIZES, streamSizes_.get());
  299. if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) {
  300. state_ = st_error;
  301. return TLS_ERR_ERROR;
  302. }
  303. }
  304. size_t process = len;
  305. auto bytes = reinterpret_cast<const char*>(data);
  306. if (writeBuffered_) {
  307. // There was buffered data, hence we need to "remove" that data from the
  308. // incoming buffer to avoid writing it again
  309. if (len < writeBuffered_) {
  310. // We didn't get called with the same data again, obviously.
  311. status_ = SEC_E_INVALID_HANDLE;
  312. status_ = st_error;
  313. return TLS_ERR_ERROR;
  314. }
  315. // just advance the buffer by writeBuffered_ bytes
  316. bytes += writeBuffered_;
  317. process -= writeBuffered_;
  318. writeBuffered_ = 0;
  319. }
  320. if (!process) {
  321. // The buffer contained the full remainder. At this point, the buffer has
  322. // been written, so the request is done in its entirety;
  323. return len;
  324. }
  325. // Buffered data was already written ;)
  326. // If there was no buffered data, this will be len - len = 0.
  327. len = len - process;
  328. while (process) {
  329. // Set up an outgoing message, according to streamSizes_
  330. writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage);
  331. size_t dl =
  332. streamSizes_->cbHeader + writeBuffered_ + streamSizes_->cbTrailer;
  333. auto buf = make_unique<char[]>(dl);
  334. TLSBuffer buffers[] = {
  335. TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()),
  336. TLSBuffer(
  337. SECBUFFER_DATA, writeBuffered_, buf.get() + streamSizes_->cbHeader),
  338. TLSBuffer(SECBUFFER_STREAM_TRAILER,
  339. streamSizes_->cbTrailer,
  340. buf.get() + streamSizes_->cbHeader + writeBuffered_),
  341. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  342. };
  343. TLSBufferDesc desc(buffers, 4);
  344. memcpy(buffers[1].pvBuffer, bytes, writeBuffered_);
  345. status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
  346. if (status_ != SEC_E_OK) {
  347. A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
  348. getLastErrorString().c_str()));
  349. state_ = st_error;
  350. return TLS_ERR_ERROR;
  351. }
  352. // EncryptMessage may have truncated the buffers.
  353. // Should rarely happen, if ever, except for the trailer.
  354. dl = buffers[0].cbBuffer;
  355. if (dl < streamSizes_->cbHeader) {
  356. // Move message.
  357. memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer);
  358. }
  359. dl += buffers[1].cbBuffer;
  360. if (dl < streamSizes_->cbHeader + writeBuffered_) {
  361. // Move trailer.
  362. memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer);
  363. }
  364. dl += buffers[2].cbBuffer;
  365. // Write (or buffer) the message.
  366. char* p = buf.get();
  367. while (dl) {
  368. written = ::send(sockfd_, p, dl, 0);
  369. errno = ::WSAGetLastError();
  370. if (written < 0 && errno == WSAEINTR) {
  371. continue;
  372. }
  373. if (written < 0 && errno == WSAEWOULDBLOCK) {
  374. // Buffer the rest of the message...
  375. writeBuf_.write(p, dl);
  376. // and return...
  377. return len;
  378. }
  379. if (written == 0) {
  380. A2_LOG_ERROR("WinTLS: Connection closed while writing");
  381. status_ = SEC_E_INCOMPLETE_MESSAGE;
  382. state_ = st_error;
  383. return TLS_ERR_ERROR;
  384. }
  385. if (written < 0) {
  386. A2_LOG_ERROR("WinTLS: Connection error while writing");
  387. status_ = SEC_E_INCOMPLETE_MESSAGE;
  388. state_ = st_error;
  389. return TLS_ERR_ERROR;
  390. }
  391. dl -= written;
  392. p += written;
  393. }
  394. len += writeBuffered_;
  395. bytes += writeBuffered_;
  396. process -= writeBuffered_;
  397. writeBuffered_ = 0;
  398. }
  399. A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64,
  400. (uint64_t)len,
  401. (uint64_t)writeBuf_.size()));
  402. if (!len) {
  403. return TLS_ERR_WOULDBLOCK;
  404. }
  405. return len;
  406. }
  407. ssize_t WinTLSSession::readData(void* data, size_t len)
  408. {
  409. A2_LOG_DEBUG(fmt("WinTLS: Read request: %" PRIu64 " buffered: %" PRIu64,
  410. (uint64_t)len,
  411. (uint64_t)readBuf_.size()));
  412. if (len == 0) {
  413. return 0;
  414. }
  415. // Can be filled from decBuffer entirely?
  416. if (decBuf_.size() >= len) {
  417. A2_LOG_DEBUG("WinTLS: Fullfilling req from buffer");
  418. memcpy(data, decBuf_.data(), len);
  419. decBuf_.eat(len);
  420. return len;
  421. }
  422. if (state_ == st_closing || state_ == st_closed || state_ == st_error) {
  423. auto nread = decBuf_.size();
  424. if (nread) {
  425. assert(nread < len);
  426. memcpy(data, decBuf_.data(), nread);
  427. decBuf_.clear();
  428. A2_LOG_DEBUG("WinTLS: Sending out decrypted buffer after EOF");
  429. return nread;
  430. }
  431. A2_LOG_DEBUG("WinTLS: Read request aborted. Connection already closed");
  432. return state_ == st_error ? TLS_ERR_ERROR : 0;
  433. }
  434. if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
  435. state_ == st_handshake_read) {
  436. // Renegotiating
  437. std::string hn, err;
  438. TLSVersion ver;
  439. auto connect = tlsConnect(hn, ver, err);
  440. if (connect != TLS_ERR_OK) {
  441. return connect;
  442. }
  443. // Continue.
  444. }
  445. if (state_ != st_connected) {
  446. status_ = SEC_E_INVALID_HANDLE;
  447. return TLS_ERR_ERROR;
  448. }
  449. // Read as many bytes as available from the connection, up to len + 4k.
  450. readBuf_.resize(len + 4096);
  451. while (readBuf_.free()) {
  452. ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
  453. errno = ::WSAGetLastError();
  454. if (read < 0 && errno == WSAEINTR) {
  455. continue;
  456. }
  457. if (read < 0 && errno == WSAEWOULDBLOCK) {
  458. break;
  459. }
  460. if (read < 0) {
  461. status_ = SEC_E_INCOMPLETE_MESSAGE;
  462. state_ = st_error;
  463. return TLS_ERR_ERROR;
  464. }
  465. if (read == 0) {
  466. A2_LOG_DEBUG("WinTLS: Connection abruptly closed!");
  467. // At least try to gracefully close our write end.
  468. closeConnection();
  469. break;
  470. }
  471. readBuf_.advance(read);
  472. }
  473. // Try to decrypt as many messages as possible from the readBuf_.
  474. while (readBuf_.size()) {
  475. TLSBuffer bufs[] = {
  476. TLSBuffer(SECBUFFER_DATA, readBuf_.size(), readBuf_.data()),
  477. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  478. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  479. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  480. };
  481. TLSBufferDesc desc(bufs, 4);
  482. status_ = ::DecryptMessage(&handle_, &desc, 0, nullptr);
  483. if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
  484. // Need to stop now, and wait for more bytes to arrive on the socket.
  485. break;
  486. }
  487. if (status_ != SEC_E_OK && status_ != SEC_I_CONTEXT_EXPIRED &&
  488. status_ != SEC_I_RENEGOTIATE) {
  489. A2_LOG_ERROR(fmt("WinTLS: Failed to decrypt a message! %s",
  490. getLastErrorString().c_str()));
  491. state_ = st_error;
  492. return TLS_ERR_ERROR;
  493. }
  494. // Decrypted message successfully.
  495. bool ate = false;
  496. for (auto& buf : bufs) {
  497. if (buf.BufferType == SECBUFFER_DATA && buf.cbBuffer > 0) {
  498. decBuf_.write(buf.pvBuffer, buf.cbBuffer);
  499. }
  500. else if (buf.BufferType == SECBUFFER_EXTRA && buf.cbBuffer > 0) {
  501. readBuf_.eat(readBuf_.size() - buf.cbBuffer);
  502. ate = true;
  503. }
  504. }
  505. if (!ate) {
  506. readBuf_.clear();
  507. }
  508. if (status_ == SEC_I_RENEGOTIATE) {
  509. // Renegotiation basically means performing another handshake
  510. state_ = st_initialized;
  511. A2_LOG_INFO("WinTLS: Renegotiate");
  512. std::string hn, err;
  513. TLSVersion ver;
  514. auto connect = tlsConnect(hn, ver, err);
  515. if (connect == TLS_ERR_WOULDBLOCK) {
  516. break;
  517. }
  518. if (connect == TLS_ERR_ERROR) {
  519. return connect;
  520. }
  521. // Still good.
  522. }
  523. if (status_ == SEC_I_CONTEXT_EXPIRED) {
  524. // Connection is gone now, but the buffered bytes are still valid.
  525. A2_LOG_DEBUG("WinTLS: Connection gracefully closed!");
  526. closeConnection();
  527. break;
  528. }
  529. }
  530. len = std::min(decBuf_.size(), len);
  531. if (len == 0) {
  532. if (state_ != st_connected) {
  533. return state_ == st_error ? TLS_ERR_ERROR : 0;
  534. }
  535. return TLS_ERR_WOULDBLOCK;
  536. }
  537. memcpy(data, decBuf_.data(), len);
  538. decBuf_.eat(len);
  539. return len;
  540. }
  541. int WinTLSSession::tlsConnect(const std::string& hostname,
  542. TLSVersion& version,
  543. std::string& handshakeErr)
  544. {
  545. // Handshaking will require sending multiple read/write exchanges until the
  546. // handshake is actually done. The client will first generate the initial
  547. // handshake message, then write that to the server, read the response
  548. // message, and write and/or read additional messages until the handshake is
  549. // either complete and successful, or something went wrong.
  550. // The server works analog to that.
  551. A2_LOG_DEBUG("WinTLS: Starting/Resuming TLS Connect");
  552. ULONG flags = 0;
  553. restart:
  554. switch (state_) {
  555. default:
  556. A2_LOG_ERROR("WinTLS: Invalid state");
  557. status_ = SEC_E_INVALID_HANDLE;
  558. return TLS_ERR_ERROR;
  559. case st_initialized: {
  560. if (side_ == TLS_SERVER) {
  561. goto read;
  562. }
  563. if (!hostname.empty()) {
  564. setSNIHostname(hostname);
  565. }
  566. A2_LOG_DEBUG("WinTLS: Initializing handshake");
  567. TLSBuffer buf(SECBUFFER_EMPTY, 0, nullptr);
  568. TLSBufferDesc desc(&buf, 1);
  569. SEC_CHAR* host =
  570. hostname_.empty() ? nullptr : const_cast<SEC_CHAR*>(hostname_.c_str());
  571. status_ = ::InitializeSecurityContext(cred_,
  572. nullptr,
  573. host,
  574. kReqFlags,
  575. 0,
  576. 0,
  577. nullptr,
  578. 0,
  579. &handle_,
  580. &desc,
  581. &flags,
  582. nullptr);
  583. if (status_ != SEC_I_CONTINUE_NEEDED) {
  584. // Has to be SEC_I_CONTINUE_NEEDED, as we did not actually send data
  585. // at this point.
  586. state_ = st_error;
  587. return TLS_ERR_ERROR;
  588. }
  589. // Queue the initial message...
  590. writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
  591. FreeContextBuffer(buf.pvBuffer);
  592. // ... and start sending it
  593. state_ = st_handshake_write;
  594. }
  595. // Fall through
  596. case st_handshake_write_last:
  597. case st_handshake_write: {
  598. A2_LOG_DEBUG("WinTLS: Writing handshake");
  599. // Write the currently queued handshake message until all data is sent.
  600. while (writeBuf_.size()) {
  601. ssize_t writ = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
  602. errno = ::WSAGetLastError();
  603. if (writ < 0 && errno == WSAEINTR) {
  604. continue;
  605. }
  606. if (writ < 0 && errno == WSAEWOULDBLOCK) {
  607. return TLS_ERR_WOULDBLOCK;
  608. }
  609. if (writ <= 0) {
  610. status_ = SEC_E_INCOMPLETE_MESSAGE;
  611. state_ = st_error;
  612. return TLS_ERR_ERROR;
  613. }
  614. writeBuf_.eat(writ);
  615. }
  616. if (state_ == st_handshake_write_last) {
  617. state_ = st_handshake_done;
  618. goto restart;
  619. }
  620. // Have to read one or more response messages.
  621. state_ = st_handshake_read;
  622. }
  623. // Fall through
  624. case st_handshake_read: {
  625. read:
  626. A2_LOG_DEBUG("WinTLS: Reading handshake...");
  627. // All write buffered data is invalid at this point!
  628. writeBuf_.clear();
  629. // Read as many bytes as possible, up to 4k new bytes.
  630. // We do not know how many bytes will arrive from the server at this
  631. // point.
  632. readBuf_.resize(readBuf_.size() + 4096);
  633. while (readBuf_.free()) {
  634. ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
  635. errno = ::WSAGetLastError();
  636. if (read < 0 && errno == WSAEINTR) {
  637. continue;
  638. }
  639. if (read < 0 && errno == WSAEWOULDBLOCK) {
  640. break;
  641. }
  642. if (read <= 0) {
  643. status_ = SEC_E_INCOMPLETE_MESSAGE;
  644. state_ = st_error;
  645. return TLS_ERR_ERROR;
  646. }
  647. if (read == 0) {
  648. A2_LOG_DEBUG("WinTLS: Connection abruptly closed during handshake!");
  649. status_ = SEC_E_INCOMPLETE_MESSAGE;
  650. state_ = st_error;
  651. return TLS_ERR_ERROR;
  652. }
  653. readBuf_.advance(read);
  654. break;
  655. }
  656. if (!readBuf_.size()) {
  657. return TLS_ERR_WOULDBLOCK;
  658. }
  659. // Need to copy the data, as Schannel is free to mess with it. But we
  660. // might later need unmodified data from the original read buffer.
  661. auto bufcopy = make_unique<char[]>(readBuf_.size());
  662. memcpy(bufcopy.get(), readBuf_.data(), readBuf_.size());
  663. // Set up buffers. inbufs will be the raw bytes the library has to decode.
  664. // outbufs will contain generated responses, if any.
  665. TLSBuffer inbufs[] = {
  666. TLSBuffer(SECBUFFER_TOKEN, readBuf_.size(), bufcopy.get()),
  667. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  668. };
  669. TLSBufferDesc indesc(inbufs, 2);
  670. TLSBuffer outbufs[] = {
  671. TLSBuffer(SECBUFFER_TOKEN, 0, nullptr),
  672. TLSBuffer(SECBUFFER_ALERT, 0, nullptr),
  673. };
  674. TLSBufferDesc outdesc(outbufs, 2);
  675. if (side_ == TLS_CLIENT) {
  676. SEC_CHAR* host = hostname_.empty() ?
  677. nullptr :
  678. const_cast<SEC_CHAR*>(hostname_.c_str());
  679. status_ = ::InitializeSecurityContext(cred_,
  680. &handle_,
  681. host,
  682. kReqFlags,
  683. 0,
  684. 0,
  685. &indesc,
  686. 0,
  687. nullptr,
  688. &outdesc,
  689. &flags,
  690. nullptr);
  691. }
  692. else {
  693. status_ =
  694. ::AcceptSecurityContext(cred_,
  695. state_ == st_initialized ? nullptr : &handle_,
  696. &indesc,
  697. kReqAFlags,
  698. 0,
  699. state_ == st_initialized ? &handle_ : nullptr,
  700. &outdesc,
  701. &flags,
  702. nullptr);
  703. }
  704. if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
  705. // Not enough raw bytes read yet to decode a full message.
  706. return TLS_ERR_WOULDBLOCK;
  707. }
  708. if (status_ != SEC_E_OK && status_ != SEC_I_CONTINUE_NEEDED) {
  709. state_ = st_error;
  710. return TLS_ERR_ERROR;
  711. }
  712. // Raw bytes where not entirely consumed, i.e. readBuf_ still contains
  713. // unprocessed data from the next message?
  714. if (inbufs[1].BufferType == SECBUFFER_EXTRA && inbufs[1].cbBuffer > 0) {
  715. readBuf_.eat(readBuf_.size() - inbufs[1].cbBuffer);
  716. }
  717. else {
  718. readBuf_.clear();
  719. }
  720. // Check if the library produced a new outgoing message and queue it.
  721. for (auto& buf : outbufs) {
  722. if (buf.BufferType == SECBUFFER_TOKEN && buf.cbBuffer > 0) {
  723. writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
  724. FreeContextBuffer(buf.pvBuffer);
  725. state_ = st_handshake_write;
  726. }
  727. }
  728. // Need to read additional messages?
  729. if (status_ == SEC_I_CONTINUE_NEEDED) {
  730. A2_LOG_DEBUG("WinTLS: Continuing with handshake");
  731. goto restart;
  732. }
  733. if (side_ == TLS_CLIENT && flags != kReqFlags) {
  734. A2_LOG_ERROR(fmt("WinTLS: Channel setup failed. Schannel provider did "
  735. "not fulfill requested flags. "
  736. "Excepted: %lu Actual: %lu",
  737. kReqFlags,
  738. flags));
  739. status_ = SEC_E_INTERNAL_ERROR;
  740. state_ = st_error;
  741. return TLS_ERR_ERROR;
  742. }
  743. if (state_ == st_handshake_write) {
  744. A2_LOG_DEBUG("WinTLS: Continuing with handshake (last write)");
  745. state_ = st_handshake_write_last;
  746. goto restart;
  747. }
  748. }
  749. // Fall through
  750. case st_handshake_done:
  751. // All ready now :D
  752. state_ = st_connected;
  753. A2_LOG_INFO(
  754. fmt("WinTLS: connected with: %s", getCipherSuite(&handle_).c_str()));
  755. switch (getProtocolVersion(&handle_)) {
  756. case 0x300:
  757. version = TLS_PROTO_SSL3;
  758. break;
  759. case 0x301:
  760. version = TLS_PROTO_TLS10;
  761. break;
  762. case 0x302:
  763. version = TLS_PROTO_TLS11;
  764. break;
  765. case 0x303:
  766. version = TLS_PROTO_TLS12;
  767. break;
  768. default:
  769. version = TLS_PROTO_NONE;
  770. break;
  771. }
  772. return TLS_ERR_OK;
  773. }
  774. A2_LOG_ERROR("WinTLS: Unreachable reached during tlsConnect! This is a bug!");
  775. state_ = st_error;
  776. return TLS_ERR_ERROR;
  777. }
  778. int WinTLSSession::tlsAccept(TLSVersion& version)
  779. {
  780. std::string host, err;
  781. return tlsConnect(host, version, err);
  782. }
  783. std::string WinTLSSession::getLastErrorString()
  784. {
  785. std::stringstream ss;
  786. wchar_t* buf = nullptr;
  787. auto rv = FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER |
  788. FORMAT_MESSAGE_FROM_SYSTEM |
  789. FORMAT_MESSAGE_IGNORE_INSERTS,
  790. nullptr,
  791. status_,
  792. MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
  793. (LPWSTR) & buf,
  794. 1024,
  795. nullptr);
  796. if (rv && buf) {
  797. ss << "Error: " << wCharToUtf8(buf) << "(" << std::hex << status_ << ")";
  798. LocalFree(buf);
  799. }
  800. else {
  801. ss << "Error: " << std::hex << status_;
  802. }
  803. return ss.str();
  804. }
  805. } // namespace aria2