aes-ni.c 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. /*
  2. * Hardware-accelerated implementation of AES using x86 AES-NI.
  3. */
  4. #include "ssh.h"
  5. #include "aes.h"
  6. #ifndef WINSCP_VS
  7. bool aes_ni_available(void);
  8. ssh_cipher *aes_ni_new(const ssh_cipheralg *alg);
  9. void aes_ni_free(ssh_cipher *ciph);
  10. void aes_ni_setiv_cbc(ssh_cipher *ciph, const void *iv);
  11. void aes_ni_setkey(ssh_cipher *ciph, const void *vkey);
  12. void aes_ni_setiv_sdctr(ssh_cipher *ciph, const void *iv);
  13. void aes_ni_setiv_gcm(ssh_cipher *ciph, const void *iv);
  14. #define NI_ENC_DEC_H(len) \
  15. void aes##len##_ni_cbc_encrypt( \
  16. ssh_cipher *ciph, void *vblk, int blklen); \
  17. void aes##len##_ni_cbc_decrypt( \
  18. ssh_cipher *ciph, void *vblk, int blklen); \
  19. void aes##len##_ni_sdctr( \
  20. ssh_cipher *ciph, void *vblk, int blklen); \
  21. void aes##len##_ni_gcm( \
  22. ssh_cipher *ciph, void *vblk, int blklen); \
  23. void aes##len##_ni_encrypt_ecb_block( \
  24. ssh_cipher *ciph, void *vblk); \
  25. NI_ENC_DEC_H(128)
  26. NI_ENC_DEC_H(192)
  27. NI_ENC_DEC_H(256)
#else

#include <wmmintrin.h>
#include <smmintrin.h>

#if defined(__clang__) || defined(__GNUC__)
#include <cpuid.h>
/* GCC/clang spell the CPUID intrinsic differently from MSVC: leaf
 * number first, then the four output registers individually. */
#define GET_CPU_ID(out) __cpuid(1, (out)[0], (out)[1], (out)[2], (out)[3])
#else
/* MSVC form: output array first, then the leaf number. */
#define GET_CPU_ID(out) __cpuid(out, 1)
#endif
  37. /*static WINSCP*/ bool aes_ni_available(void)
  38. {
  39. /*
  40. * Determine if AES is available on this CPU, by checking that
  41. * both AES itself and SSE4.1 are supported.
  42. */
  43. unsigned int CPUInfo[4];
  44. GET_CPU_ID(CPUInfo);
  45. return (CPUInfo[2] & (1 << 25)) && (CPUInfo[2] & (1 << 19));
  46. }
/*
 * Core AES-NI encrypt/decrypt functions, one per length and direction.
 *
 * Each generated function consumes the whole round-key schedule:
 * XOR in the first round key (initial whitening), run the middle
 * rounds via AESENC/AESDEC (repmacro repeats the statement 9/11/13
 * times for AES-128/192/256), then finish with AESENCLAST/AESDECLAST.
 */
#define NI_CIPHER(len, dir, dirlong, repmacro)                          \
    static inline __m128i aes_ni_##len##_##dir(                         \
        __m128i v, const __m128i *keysched)                             \
    {                                                                   \
        v = _mm_xor_si128(v, *keysched++);                              \
        repmacro(v = _mm_aes##dirlong##_si128(v, *keysched++););        \
        return _mm_aes##dirlong##last_si128(v, *keysched);              \
    }

NI_CIPHER(128, e, enc, REP9)
NI_CIPHER(128, d, dec, REP9)
NI_CIPHER(192, e, enc, REP11)
NI_CIPHER(192, d, dec, REP11)
NI_CIPHER(256, e, enc, REP13)
NI_CIPHER(256, d, dec, REP13)
/*
 * The main key expansion.
 *
 * Expands 'key' (key_words 32-bit words: 4, 6 or 8 for
 * AES-128/192/256) into the encryption schedule keysched_e, and
 * derives the decryption schedule keysched_d used by the AESDEC
 * instructions.
 */
static void aes_ni_key_expand(
    const unsigned char *key, size_t key_words,
    __m128i *keysched_e, __m128i *keysched_d)
{
    size_t rounds = key_words + 6;          /* 10 / 12 / 14 rounds */
    size_t sched_words = (rounds + 1) * 4;  /* one round key per round, plus one */

    /*
     * Store the key schedule as 32-bit integers during expansion, so
     * that it's easy to refer back to individual previous words. We
     * collect them into the final __m128i form at the end.
     */
    uint32_t sched[MAXROUNDKEYS * 4];

    unsigned rconpos = 0;

    for (size_t i = 0; i < sched_words; i++) {
        if (i < key_words) {
            /* The first key_words words are the key itself. */
            sched[i] = GET_32BIT_LSB_FIRST(key + 4 * i);
        } else {
            /* Later words combine the previous word with the one
             * key_words positions back. */
            uint32_t temp = sched[i - 1];

            /* First word of each key_words-sized group gets
             * RotWord + SubWord + round constant... */
            bool rotate_and_round_constant = (i % key_words == 0);
            /* ...and AES-256 additionally applies SubWord alone to
             * the middle word of each 8-word group. */
            bool only_sub = (key_words == 8 && i % 8 == 4);

            if (rotate_and_round_constant) {
                /* Put temp in lane 1: AESKEYGENASSIST's lane-1 output
                 * is RotWord(SubWord(lane 1)) XOR rcon. We pass rcon=0
                 * and XOR in our own round constant from the table. */
                __m128i v = _mm_setr_epi32(0,temp,0,0);
                v = _mm_aeskeygenassist_si128(v, 0);
                temp = _mm_extract_epi32(v, 1);
                assert(rconpos < lenof(aes_key_setup_round_constants));
                temp ^= aes_key_setup_round_constants[rconpos++];
            } else if (only_sub) {
                /* Lane 0 of the output is SubWord(lane 1), with no
                 * rotation and no round constant. */
                __m128i v = _mm_setr_epi32(0,temp,0,0);
                v = _mm_aeskeygenassist_si128(v, 0);
                temp = _mm_extract_epi32(v, 0);
            }

            sched[i] = sched[i - key_words] ^ temp;
        }
    }

    /*
     * Combine the key schedule words into __m128i vectors and store
     * them in the output context.
     */
    for (size_t round = 0; round <= rounds; round++)
        keysched_e[round] = _mm_setr_epi32(
            sched[4*round  ], sched[4*round+1],
            sched[4*round+2], sched[4*round+3]);

    /* Wipe the scalar copy of the schedule: it is key material. */
    smemclr(sched, sizeof(sched));

    /*
     * Now prepare the modified keys for the inverse cipher.
     */
    for (size_t eround = 0; eround <= rounds; eround++) {
        size_t dround = rounds - eround;    /* decryption uses keys in reverse */
        __m128i rkey = keysched_e[eround];
        if (eround && dround)               /* neither first nor last */
            rkey = _mm_aesimc_si128(rkey);  /* InvMixColumns, as AESDEC expects */
        keysched_d[dround] = rkey;
    }
}
// WINSCP
// WORKAROUND
// Cannot use _mm_setr_epi* - it results in the constant being stored in .rdata segment.
// objconv reports:
// Warning 1060: Different alignments specified for same segment, %s. Using highest alignment.rdata
// Despite that the code crashes.
// This macro is based on:
// Based on https://stackoverflow.com/q/35268036/850848
//
// The macro expands to a brace-enclosed aggregate initializer (one
// (char) per byte, lowest lane first), usable only as an initializer
// for a const __m128i local -- not as a general expression.
#define _MM_SETR_EPI8(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, aa, ab, ac, ad, ae, af) \
    { (char)a0, (char)a1, (char)a2, (char)a3, (char)a4, (char)a5, (char)a6, (char)a7, \
      (char)a8, (char)a9, (char)aa, (char)ab, (char)ac, (char)ad, (char)ae, (char)af }
  132. /*
  133. * Auxiliary routine to increment the 128-bit counter used in SDCTR
  134. * mode.
  135. */
  136. static inline __m128i aes_ni_sdctr_increment(__m128i v)
  137. {
  138. const __m128i ONE = _MM_SETR_EPI8(1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0); // WINSCP
  139. const __m128i ZERO = _mm_setzero_si128();
  140. /* Increment the low-order 64 bits of v */
  141. v = _mm_add_epi64(v, ONE);
  142. /* Check if they've become zero */
  143. __m128i cmp = _mm_cmpeq_epi64(v, ZERO);
  144. /* If so, the low half of cmp is all 1s. Pack that into the high
  145. * half of addend with zero in the low half. */
  146. __m128i addend = _mm_unpacklo_epi64(ZERO, cmp);
  147. /* And subtract that from v, which increments the high 64 bits iff
  148. * the low 64 wrapped round. */
  149. v = _mm_sub_epi64(v, addend);
  150. return v;
  151. }
  152. /*
  153. * Much simpler auxiliary routine to increment the counter for GCM
  154. * mode. This only has to increment the low word.
  155. */
  156. static inline __m128i aes_ni_gcm_increment(__m128i v)
  157. {
  158. const __m128i ONE = _MM_SETR_EPI8(1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0); // WINSCP
  159. return _mm_add_epi32(v, ONE);
  160. }
  161. /*
  162. * Auxiliary routine to reverse the byte order of a vector, so that
  163. * the SDCTR IV can be made big-endian for feeding to the cipher.
  164. */
  165. static inline __m128i aes_ni_sdctr_reverse(__m128i v)
  166. {
  167. const __m128i R = _MM_SETR_EPI8(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0); // WINSCP
  168. v = _mm_shuffle_epi8(
  169. v, R); // WINSCP
  170. return v;
  171. }
/*
 * The SSH interface and the cipher modes.
 */
typedef struct aes_ni_context aes_ni_context;
struct aes_ni_context {
    /* Encryption and decryption round-key schedules, plus the current
     * IV / counter block. __m128i fields need 16-byte alignment,
     * which is why allocation is handled specially in aes_ni_new. */
    __m128i keysched_e[MAXROUNDKEYS], keysched_d[MAXROUNDKEYS], iv;

    /* The original (possibly unaligned) allocation, kept so that
     * aes_ni_free can pass it back to sfree. */
    void *pointer_to_free;

    ssh_cipher ciph;
};
  181. /*static WINSCP*/ ssh_cipher *aes_ni_new(const ssh_cipheralg *alg)
  182. {
  183. const struct aes_extra *extra = (const struct aes_extra *)alg->extra;
  184. if (!check_availability(extra))
  185. return NULL;
  186. /*
  187. * The __m128i variables in the context structure need to be
  188. * 16-byte aligned, but not all malloc implementations that this
  189. * code has to work with will guarantee to return a 16-byte
  190. * aligned pointer. So we over-allocate, manually realign the
  191. * pointer ourselves, and store the original one inside the
  192. * context so we know how to free it later.
  193. */
  194. void *allocation = smalloc(sizeof(aes_ni_context) + 15);
  195. uintptr_t alloc_address = (uintptr_t)allocation;
  196. uintptr_t aligned_address = (alloc_address + 15) & ~15;
  197. aes_ni_context *ctx = (aes_ni_context *)aligned_address;
  198. ctx->ciph.vt = alg;
  199. ctx->pointer_to_free = allocation;
  200. return &ctx->ciph;
  201. }
  202. /*static WINSCP*/ void aes_ni_free(ssh_cipher *ciph)
  203. {
  204. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  205. void *allocation = ctx->pointer_to_free;
  206. smemclr(ctx, sizeof(*ctx));
  207. sfree(allocation);
  208. }
  209. /*static WINSCP*/ void aes_ni_setkey(ssh_cipher *ciph, const void *vkey)
  210. {
  211. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  212. const unsigned char *key = (const unsigned char *)vkey;
  213. aes_ni_key_expand(key, ctx->ciph.vt->real_keybits / 32,
  214. ctx->keysched_e, ctx->keysched_d);
  215. }
  216. /*static WINSCP*/ void aes_ni_setiv_cbc(ssh_cipher *ciph, const void *iv)
  217. {
  218. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  219. ctx->iv = _mm_loadu_si128(iv);
  220. }
  221. /*static WINSCP*/ void aes_ni_setiv_sdctr(ssh_cipher *ciph, const void *iv)
  222. {
  223. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  224. __m128i counter = _mm_loadu_si128(iv);
  225. ctx->iv = aes_ni_sdctr_reverse(counter);
  226. }
  227. /*WINSCP static*/ void aes_ni_setiv_gcm(ssh_cipher *ciph, const void *iv)
  228. {
  229. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  230. __m128i counter = _mm_loadu_si128(iv);
  231. ctx->iv = aes_ni_sdctr_reverse(counter);
  232. ctx->iv = _mm_insert_epi32(ctx->iv, 1, 0);
  233. }
/* Advance the GCM IV to the next message: increment the 64-bit
 * message counter held in lanes 1-2, keep the fixed field in lane 3
 * unchanged, and reset the per-block counter in lane 0 to 1. */
static void aes_ni_next_message_gcm(ssh_cipher *ciph)
{
    aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
    uint32_t fixed = _mm_extract_epi32(ctx->iv, 3);
    /* Reassemble the 64-bit message counter: lane 2 is the high half,
     * lane 1 the low half. */
    uint64_t msg_counter = _mm_extract_epi32(ctx->iv, 2);
    msg_counter <<= 32;
    msg_counter |= (uint32_t)_mm_extract_epi32(ctx->iv, 1);
    msg_counter++;
    /* _mm_set_epi32 takes its arguments highest lane first. */
    ctx->iv = _mm_set_epi32(fixed, (int)(msg_counter >> 32), (int)msg_counter, 1); // WINSCP
}
  244. typedef __m128i (*aes_ni_fn)(__m128i v, const __m128i *keysched);
  245. static inline void aes_cbc_ni_encrypt(
  246. ssh_cipher *ciph, void *vblk, int blklen, aes_ni_fn encrypt)
  247. {
  248. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  249. for (uint8_t *blk = (uint8_t *)vblk, *finish = blk + blklen;
  250. blk < finish; blk += 16) {
  251. __m128i plaintext = _mm_loadu_si128((const __m128i *)blk);
  252. __m128i cipher_input = _mm_xor_si128(plaintext, ctx->iv);
  253. __m128i ciphertext = encrypt(cipher_input, ctx->keysched_e);
  254. _mm_storeu_si128((__m128i *)blk, ciphertext);
  255. ctx->iv = ciphertext;
  256. }
  257. }
  258. static inline void aes_cbc_ni_decrypt(
  259. ssh_cipher *ciph, void *vblk, int blklen, aes_ni_fn decrypt)
  260. {
  261. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  262. for (uint8_t *blk = (uint8_t *)vblk, *finish = blk + blklen;
  263. blk < finish; blk += 16) {
  264. __m128i ciphertext = _mm_loadu_si128((const __m128i *)blk);
  265. __m128i decrypted = decrypt(ciphertext, ctx->keysched_d);
  266. __m128i plaintext = _mm_xor_si128(decrypted, ctx->iv);
  267. _mm_storeu_si128((__m128i *)blk, plaintext);
  268. ctx->iv = ciphertext;
  269. }
  270. }
  271. static inline void aes_sdctr_ni(
  272. ssh_cipher *ciph, void *vblk, int blklen, aes_ni_fn encrypt)
  273. {
  274. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  275. for (uint8_t *blk = (uint8_t *)vblk, *finish = blk + blklen;
  276. blk < finish; blk += 16) {
  277. __m128i counter = aes_ni_sdctr_reverse(ctx->iv);
  278. __m128i keystream = encrypt(counter, ctx->keysched_e);
  279. __m128i input = _mm_loadu_si128((const __m128i *)blk);
  280. __m128i output = _mm_xor_si128(input, keystream);
  281. _mm_storeu_si128((__m128i *)blk, output);
  282. ctx->iv = aes_ni_sdctr_increment(ctx->iv);
  283. }
  284. }
  285. static inline void aes_encrypt_ecb_block_ni(
  286. ssh_cipher *ciph, void *blk, aes_ni_fn encrypt)
  287. {
  288. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  289. __m128i plaintext = _mm_loadu_si128(blk);
  290. __m128i ciphertext = encrypt(plaintext, ctx->keysched_e);
  291. _mm_storeu_si128(blk, ciphertext);
  292. }
  293. // WINSCP (fixes linker alignment issues for the following function)
  294. const __m128i DUMMY; // WINSCP
  295. static inline void aes_gcm_ni(
  296. ssh_cipher *ciph, void *vblk, int blklen, aes_ni_fn encrypt)
  297. {
  298. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  299. for (uint8_t *blk = (uint8_t *)vblk, *finish = blk + blklen;
  300. blk < finish; blk += 16) {
  301. __m128i counter = aes_ni_sdctr_reverse(ctx->iv);
  302. __m128i keystream = encrypt(counter, ctx->keysched_e);
  303. __m128i input = _mm_loadu_si128((const __m128i *)blk);
  304. __m128i output = _mm_xor_si128(input, keystream);
  305. _mm_storeu_si128((__m128i *)blk, output);
  306. ctx->iv = aes_ni_gcm_increment(ctx->iv);
  307. }
  308. }
/*
 * Instantiate the public entry points for each key length by
 * delegating to the shared inline mode implementations above,
 * passing the key-length-specific core cipher function.
 */
#define NI_ENC_DEC(len)                                                 \
    /*static WINSCP*/ void aes##len##_ni_cbc_encrypt(                   \
        ssh_cipher *ciph, void *vblk, int blklen)                       \
    { aes_cbc_ni_encrypt(ciph, vblk, blklen, aes_ni_##len##_e); }       \
    /*static WINSCP*/ void aes##len##_ni_cbc_decrypt(                   \
        ssh_cipher *ciph, void *vblk, int blklen)                       \
    { aes_cbc_ni_decrypt(ciph, vblk, blklen, aes_ni_##len##_d); }       \
    /*static WINSCP*/ void aes##len##_ni_sdctr(                         \
        ssh_cipher *ciph, void *vblk, int blklen)                       \
    { aes_sdctr_ni(ciph, vblk, blklen, aes_ni_##len##_e); }             \
    /*static WINSCP*/ void aes##len##_ni_gcm(                           \
        ssh_cipher *ciph, void *vblk, int blklen)                       \
    { aes_gcm_ni(ciph, vblk, blklen, aes_ni_##len##_e); }               \
    /*static WINSCP*/ void aes##len##_ni_encrypt_ecb_block(             \
        ssh_cipher *ciph, void *vblk)                                   \
    { aes_encrypt_ecb_block_ni(ciph, vblk, aes_ni_##len##_e); }

NI_ENC_DEC(128)
NI_ENC_DEC(192)
NI_ENC_DEC(256)

#endif // WINSCP_VS

/* Define the extra data and cipher vtables for the AES-NI variants. */
AES_EXTRA(_ni);
AES_ALL_VTABLES(_ni, "AES-NI accelerated");