/* aes-ni.c */
  1. /*
  2. * Hardware-accelerated implementation of AES using x86 AES-NI.
  3. */
  4. #include "ssh.h"
  5. #include "aes.h"
  6. #ifndef WINSCP_VS
  7. bool aes_ni_available(void);
  8. ssh_cipher *aes_ni_new(const ssh_cipheralg *alg);
  9. void aes_ni_free(ssh_cipher *ciph);
  10. void aes_ni_setiv_cbc(ssh_cipher *ciph, const void *iv);
  11. void aes_ni_setkey(ssh_cipher *ciph, const void *vkey);
  12. void aes_ni_setiv_sdctr(ssh_cipher *ciph, const void *iv);
  13. #define NI_ENC_DEC_H(len) \
  14. void aes##len##_ni_cbc_encrypt( \
  15. ssh_cipher *ciph, void *vblk, int blklen); \
  16. void aes##len##_ni_cbc_decrypt( \
  17. ssh_cipher *ciph, void *vblk, int blklen); \
  18. void aes##len##_ni_sdctr( \
  19. ssh_cipher *ciph, void *vblk, int blklen); \
  20. NI_ENC_DEC_H(128)
  21. NI_ENC_DEC_H(192)
  22. NI_ENC_DEC_H(256)
#else /* WINSCP_VS: the vectorised part of the file */

#include <wmmintrin.h>   /* AES-NI intrinsics */
#include <smmintrin.h>   /* SSE4.1 intrinsics */

/* Portable wrapper for CPUID leaf 1: GCC/clang provide __cpuid as a
 * 5-argument macro in <cpuid.h>, whereas MSVC's takes (regs, leaf). */
#if defined(__clang__) || defined(__GNUC__)
#include <cpuid.h>
#define GET_CPU_ID(out) __cpuid(1, (out)[0], (out)[1], (out)[2], (out)[3])
#else
#define GET_CPU_ID(out) __cpuid(out, 1)
#endif
  32. /*static WINSCP*/ bool aes_ni_available(void)
  33. {
  34. /*
  35. * Determine if AES is available on this CPU, by checking that
  36. * both AES itself and SSE4.1 are supported.
  37. */
  38. unsigned int CPUInfo[4];
  39. GET_CPU_ID(CPUInfo);
  40. return (CPUInfo[2] & (1 << 25)) && (CPUInfo[2] & (1 << 19));
  41. }
/*
 * Core AES-NI encrypt/decrypt functions, one per length and direction.
 *
 * Each generated function XORs in the initial round key, runs the
 * middle rounds via the REPn repetition macro (9/11/13 middle rounds
 * for AES-128/192/256), and finishes with the 'last' form of the
 * round instruction, matching the AES round structure.
 */
#define NI_CIPHER(len, dir, dirlong, repmacro) \
    static inline __m128i aes_ni_##len##_##dir( \
        __m128i v, const __m128i *keysched) \
    { \
        v = _mm_xor_si128(v, *keysched++); \
        repmacro(v = _mm_aes##dirlong##_si128(v, *keysched++);); \
        return _mm_aes##dirlong##last_si128(v, *keysched); \
    }

NI_CIPHER(128, e, enc, REP9)
NI_CIPHER(128, d, dec, REP9)
NI_CIPHER(192, e, enc, REP11)
NI_CIPHER(192, d, dec, REP11)
NI_CIPHER(256, e, enc, REP13)
NI_CIPHER(256, d, dec, REP13)
/*
 * The main key expansion.
 *
 * Expands 'key' (key_words 32-bit words: 4, 6 or 8) into the round
 * key schedules. keysched_e receives the standard forward schedule;
 * keysched_d receives the same round keys in reverse order, with the
 * interior ones passed through AESIMC so that they suit the
 * Equivalent Inverse Cipher implemented by the AESDEC instructions.
 */
static void aes_ni_key_expand(
    const unsigned char *key, size_t key_words,
    __m128i *keysched_e, __m128i *keysched_d)
{
    size_t rounds = key_words + 6;
    size_t sched_words = (rounds + 1) * 4;

    /*
     * Store the key schedule as 32-bit integers during expansion, so
     * that it's easy to refer back to individual previous words. We
     * collect them into the final __m128i form at the end.
     */
    uint32_t sched[MAXROUNDKEYS * 4];

    unsigned rconpos = 0;

    for (size_t i = 0; i < sched_words; i++) {
        if (i < key_words) {
            /* The first key_words words are simply the key itself. */
            sched[i] = GET_32BIT_LSB_FIRST(key + 4 * i);
        } else {
            uint32_t temp = sched[i - 1];

            /* Every key_words-th word gets rotate + S-box + round
             * constant; AES-256 (key_words == 8) additionally applies
             * the S-box alone halfway between those points. */
            bool rotate_and_round_constant = (i % key_words == 0);
            bool only_sub = (key_words == 8 && i % 8 == 4);

            if (rotate_and_round_constant) {
                /* AESKEYGENASSIST with rcon=0: placing temp in lane 1
                 * makes lane 1 of the result the rotated+substituted
                 * word; we then XOR in our own round constant. */
                __m128i v = _mm_setr_epi32(0,temp,0,0);
                v = _mm_aeskeygenassist_si128(v, 0);
                temp = _mm_extract_epi32(v, 1);
                assert(rconpos < lenof(aes_key_setup_round_constants));
                temp ^= aes_key_setup_round_constants[rconpos++];
            } else if (only_sub) {
                /* Lane 0 of the AESKEYGENASSIST result is the
                 * substituted-only word. */
                __m128i v = _mm_setr_epi32(0,temp,0,0);
                v = _mm_aeskeygenassist_si128(v, 0);
                temp = _mm_extract_epi32(v, 0);
            }

            sched[i] = sched[i - key_words] ^ temp;
        }
    }

    /*
     * Combine the key schedule words into __m128i vectors and store
     * them in the output context.
     */
    for (size_t round = 0; round <= rounds; round++)
        keysched_e[round] = _mm_setr_epi32(
            sched[4*round  ], sched[4*round+1],
            sched[4*round+2], sched[4*round+3]);

    /* Scrub the intermediate key material from the stack. */
    smemclr(sched, sizeof(sched));

    /*
     * Now prepare the modified keys for the inverse cipher.
     */
    for (size_t eround = 0; eround <= rounds; eround++) {
        size_t dround = rounds - eround;
        __m128i rkey = keysched_e[eround];
        if (eround && dround) /* neither first nor last */
            rkey = _mm_aesimc_si128(rkey);
        keysched_d[dround] = rkey;
    }
}
// WINSCP
// WORKAROUND
// Cannot use _mm_setr_epi* - it results in the constant being stored in .rdata segment.
// objconv reports:
// Warning 1060: Different alignments specified for same segment, %s. Using highest alignment.rdata
// Despite that the code crashes.
// This macro is based on https://stackoverflow.com/q/35268036/850848
// It expands to a plain brace initializer list, so the constant is
// built without generating a separate .rdata literal.
#define _MM_SETR_EPI8(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, aa, ab, ac, ad, ae, af) \
    { (char)a0, (char)a1, (char)a2, (char)a3, (char)a4, (char)a5, (char)a6, (char)a7, \
      (char)a8, (char)a9, (char)aa, (char)ab, (char)ac, (char)ad, (char)ae, (char)af }
/*
 * Auxiliary routine to increment the 128-bit counter used in SDCTR
 * mode. The counter is held little-endian-first, so a carry out of
 * the low 64-bit half must propagate into the high half.
 */
static inline __m128i aes_ni_sdctr_increment(__m128i v)
{
    /* {1, 0, ..., 0}: built with the WINSCP initializer-list macro
     * (see _MM_SETR_EPI8 above) instead of _mm_setr_epi*. */
    const __m128i ONE = _MM_SETR_EPI8(1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0); // WINSCP
    const __m128i ZERO = _mm_setzero_si128();

    /* Increment the low-order 64 bits of v */
    v = _mm_add_epi64(v, ONE);
    /* Check if they've become zero */
    __m128i cmp = _mm_cmpeq_epi64(v, ZERO);
    /* If so, the low half of cmp is all 1s. Pack that into the high
     * half of addend with zero in the low half. */
    __m128i addend = _mm_unpacklo_epi64(ZERO, cmp);
    /* And subtract that from v, which increments the high 64 bits iff
     * the low 64 wrapped round. */
    v = _mm_sub_epi64(v, addend);
    return v;
}
  147. /*
  148. * Much simpler auxiliary routine to increment the counter for GCM
  149. * mode. This only has to increment the low word.
  150. */
  151. static inline __m128i aes_ni_gcm_increment(__m128i v)
  152. {
  153. const __m128i ONE = _mm_setr_epi32(1,0,0,0);
  154. return _mm_add_epi32(v, ONE);
  155. }
/*
 * Auxiliary routine to reverse the byte order of a vector, so that
 * the SDCTR IV can be made big-endian for feeding to the cipher.
 */
static inline __m128i aes_ni_sdctr_reverse(__m128i v)
{
    /* Shuffle control 15,14,...,0 reverses all 16 bytes via PSHUFB;
     * constant built with the WINSCP macro to stay out of .rdata. */
    const __m128i R = _MM_SETR_EPI8(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0); // WINSCP
    v = _mm_shuffle_epi8(
        v, R); // WINSCP
    return v;
}
/*
 * The SSH interface and the cipher modes.
 */
typedef struct aes_ni_context aes_ni_context;
struct aes_ni_context {
    /* Forward and inverse round-key schedules, plus the current IV /
     * counter block. These __m128i fields need 16-byte alignment,
     * which aes_ni_new() arranges manually. */
    __m128i keysched_e[MAXROUNDKEYS], keysched_d[MAXROUNDKEYS], iv;

    /* The raw pointer returned by smalloc (possibly unaligned), which
     * is what must actually be handed back to sfree. */
    void *pointer_to_free;

    ssh_cipher ciph;
};
  176. /*static WINSCP*/ ssh_cipher *aes_ni_new(const ssh_cipheralg *alg)
  177. {
  178. const struct aes_extra *extra = (const struct aes_extra *)alg->extra;
  179. if (!check_availability(extra))
  180. return NULL;
  181. /*
  182. * The __m128i variables in the context structure need to be
  183. * 16-byte aligned, but not all malloc implementations that this
  184. * code has to work with will guarantee to return a 16-byte
  185. * aligned pointer. So we over-allocate, manually realign the
  186. * pointer ourselves, and store the original one inside the
  187. * context so we know how to free it later.
  188. */
  189. void *allocation = smalloc(sizeof(aes_ni_context) + 15);
  190. uintptr_t alloc_address = (uintptr_t)allocation;
  191. uintptr_t aligned_address = (alloc_address + 15) & ~15;
  192. aes_ni_context *ctx = (aes_ni_context *)aligned_address;
  193. ctx->ciph.vt = alg;
  194. ctx->pointer_to_free = allocation;
  195. return &ctx->ciph;
  196. }
  197. /*static WINSCP*/ void aes_ni_free(ssh_cipher *ciph)
  198. {
  199. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  200. void *allocation = ctx->pointer_to_free;
  201. smemclr(ctx, sizeof(*ctx));
  202. sfree(allocation);
  203. }
  204. /*static WINSCP*/ void aes_ni_setkey(ssh_cipher *ciph, const void *vkey)
  205. {
  206. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  207. const unsigned char *key = (const unsigned char *)vkey;
  208. aes_ni_key_expand(key, ctx->ciph.vt->real_keybits / 32,
  209. ctx->keysched_e, ctx->keysched_d);
  210. }
  211. /*static WINSCP*/ void aes_ni_setiv_cbc(ssh_cipher *ciph, const void *iv)
  212. {
  213. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  214. ctx->iv = _mm_loadu_si128(iv);
  215. }
  216. /*static WINSCP*/ void aes_ni_setiv_sdctr(ssh_cipher *ciph, const void *iv)
  217. {
  218. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  219. __m128i counter = _mm_loadu_si128(iv);
  220. ctx->iv = aes_ni_sdctr_reverse(counter);
  221. }
  222. static void aes_ni_setiv_gcm(ssh_cipher *ciph, const void *iv)
  223. {
  224. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  225. __m128i counter = _mm_loadu_si128(iv);
  226. ctx->iv = aes_ni_sdctr_reverse(counter);
  227. ctx->iv = _mm_insert_epi32(ctx->iv, 1, 0);
  228. }
static void aes_ni_next_message_gcm(ssh_cipher *ciph)
{
    aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
    /* As stored here, lane 3 holds the fixed IV field, lanes 2..1
     * hold the 64-bit per-message counter (high word in lane 2), and
     * lane 0 is the per-block counter. */
    uint32_t fixed = _mm_extract_epi32(ctx->iv, 3);
    uint64_t msg_counter = _mm_extract_epi32(ctx->iv, 2);
    msg_counter <<= 32;
    msg_counter |= (uint32_t)_mm_extract_epi32(ctx->iv, 1);
    /* Advance to the next message, and reset the per-block counter
     * to 1 for the start of that message. */
    msg_counter++;
    ctx->iv = _mm_set_epi32(fixed, msg_counter >> 32, msg_counter, 1);
}

/* Function-pointer type matching the per-length core cipher functions
 * above, so each mode implementation can take one as a parameter. */
typedef __m128i (*aes_ni_fn)(__m128i v, const __m128i *keysched);
  240. static inline void aes_cbc_ni_encrypt(
  241. ssh_cipher *ciph, void *vblk, int blklen, aes_ni_fn encrypt)
  242. {
  243. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  244. for (uint8_t *blk = (uint8_t *)vblk, *finish = blk + blklen;
  245. blk < finish; blk += 16) {
  246. __m128i plaintext = _mm_loadu_si128((const __m128i *)blk);
  247. __m128i cipher_input = _mm_xor_si128(plaintext, ctx->iv);
  248. __m128i ciphertext = encrypt(cipher_input, ctx->keysched_e);
  249. _mm_storeu_si128((__m128i *)blk, ciphertext);
  250. ctx->iv = ciphertext;
  251. }
  252. }
  253. static inline void aes_cbc_ni_decrypt(
  254. ssh_cipher *ciph, void *vblk, int blklen, aes_ni_fn decrypt)
  255. {
  256. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  257. for (uint8_t *blk = (uint8_t *)vblk, *finish = blk + blklen;
  258. blk < finish; blk += 16) {
  259. __m128i ciphertext = _mm_loadu_si128((const __m128i *)blk);
  260. __m128i decrypted = decrypt(ciphertext, ctx->keysched_d);
  261. __m128i plaintext = _mm_xor_si128(decrypted, ctx->iv);
  262. _mm_storeu_si128((__m128i *)blk, plaintext);
  263. ctx->iv = ciphertext;
  264. }
  265. }
  266. static inline void aes_sdctr_ni(
  267. ssh_cipher *ciph, void *vblk, int blklen, aes_ni_fn encrypt)
  268. {
  269. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  270. for (uint8_t *blk = (uint8_t *)vblk, *finish = blk + blklen;
  271. blk < finish; blk += 16) {
  272. __m128i counter = aes_ni_sdctr_reverse(ctx->iv);
  273. __m128i keystream = encrypt(counter, ctx->keysched_e);
  274. __m128i input = _mm_loadu_si128((const __m128i *)blk);
  275. __m128i output = _mm_xor_si128(input, keystream);
  276. _mm_storeu_si128((__m128i *)blk, output);
  277. ctx->iv = aes_ni_sdctr_increment(ctx->iv);
  278. }
  279. }
  280. static inline void aes_encrypt_ecb_block_ni(
  281. ssh_cipher *ciph, void *blk, aes_ni_fn encrypt)
  282. {
  283. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  284. __m128i plaintext = _mm_loadu_si128(blk);
  285. __m128i ciphertext = encrypt(plaintext, ctx->keysched_e);
  286. _mm_storeu_si128(blk, ciphertext);
  287. }
  288. static inline void aes_gcm_ni(
  289. ssh_cipher *ciph, void *vblk, int blklen, aes_ni_fn encrypt)
  290. {
  291. aes_ni_context *ctx = container_of(ciph, aes_ni_context, ciph);
  292. for (uint8_t *blk = (uint8_t *)vblk, *finish = blk + blklen;
  293. blk < finish; blk += 16) {
  294. __m128i counter = aes_ni_sdctr_reverse(ctx->iv);
  295. __m128i keystream = encrypt(counter, ctx->keysched_e);
  296. __m128i input = _mm_loadu_si128((const __m128i *)blk);
  297. __m128i output = _mm_xor_si128(input, keystream);
  298. _mm_storeu_si128((__m128i *)blk, output);
  299. ctx->iv = aes_ni_gcm_increment(ctx->iv);
  300. }
  301. }
/*
 * Generate the full set of outward-facing entry points for one key
 * length, by plugging the matching core cipher function into each of
 * the mode implementations above.
 */
#define NI_ENC_DEC(len) \
    /*static WINSCP*/ void aes##len##_ni_cbc_encrypt( \
        ssh_cipher *ciph, void *vblk, int blklen) \
    { aes_cbc_ni_encrypt(ciph, vblk, blklen, aes_ni_##len##_e); } \
    /*static WINSCP*/ void aes##len##_ni_cbc_decrypt( \
        ssh_cipher *ciph, void *vblk, int blklen) \
    { aes_cbc_ni_decrypt(ciph, vblk, blklen, aes_ni_##len##_d); } \
    /*static WINSCP*/ void aes##len##_ni_sdctr( \
        ssh_cipher *ciph, void *vblk, int blklen) \
    { aes_sdctr_ni(ciph, vblk, blklen, aes_ni_##len##_e); } \
    static void aes##len##_ni_gcm( \
        ssh_cipher *ciph, void *vblk, int blklen) \
    { aes_gcm_ni(ciph, vblk, blklen, aes_ni_##len##_e); } \
    static void aes##len##_ni_encrypt_ecb_block( \
        ssh_cipher *ciph, void *vblk) \
    { aes_encrypt_ecb_block_ni(ciph, vblk, aes_ni_##len##_e); }

NI_ENC_DEC(128)
NI_ENC_DEC(192)
NI_ENC_DEC(256)
#endif // WINSCP_VS

/* Instantiate the extra-data block and the vtables that hook this
 * implementation into the SSH cipher abstraction. */
AES_EXTRA(_ni);
AES_ALL_VTABLES(_ni, "AES-NI accelerated");