mlkem.c 41 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179
  1. /*
  2. * Implementation of ML-KEM, previously known as 'Crystals: Kyber'.
  3. */
  4. #include <stdio.h>
  5. #include <stdarg.h>
  6. #include <stdlib.h>
  7. #include <assert.h>
  8. #include "putty.h"
  9. #include "ssh.h"
  10. #include "mlkem.h"
  11. #include "smallmoduli.h"
  12. /* ----------------------------------------------------------------------
  13. * General definitions.
  14. */
  15. /*
  16. * Arithmetic in this system works mod 3329, which is prime, and
  17. * congruent to 1 mod 256 (in fact it's 13*256 + 1), meaning that
  18. * 256th roots of unity exist.
  19. */
  20. #define Q 3329
  21. /*
  22. * Parameter structure describing a particular instance of ML-KEM.
  23. */
  24. struct mlkem_params {
  25. int k; /* dimensions of the matrices used */
  26. int eta_1, eta_2; /* parameters for mlkem_matrix_poly_cbd calls */
  27. int d_u, d_v; /* bit counts to use in lossy compressed encoding */
  28. };
  29. /*
  30. * Specific parameter sets.
  31. */
  32. const mlkem_params mlkem_params_512 = {
  33. /*.k =*/ 2, /*.eta_1 =*/ 3, /*.eta_2 =*/ 2, /*.d_u =*/ 10, /*.d_v =*/ 4,
  34. };
  35. const mlkem_params mlkem_params_768 = {
  36. /*.k =*/ 3, /*.eta_1 =*/ 2, /*.eta_2 =*/ 2, /*.d_u =*/ 10, /*.d_v =*/ 4,
  37. };
  38. const mlkem_params mlkem_params_1024 = {
  39. /*.k =*/ 4, /*.eta_1 =*/ 2, /*.eta_2 =*/ 2, /*.d_u =*/ 11, /*.d_v =*/ 5,
  40. };
  41. #define KMAX 4
  42. /* ----------------------------------------------------------------------
  43. * Number-theoretic transform on ring elements.
  44. *
  45. * The ring R used by ML-KEM is (Z/qZ)[X] / <X^256+1> (where q=3329 as
  46. * above). If the quotient polynomial were X^256-1 then it would split
  47. * into 256 linear factors, so that R could be expressed as the direct
  48. * sum of 256 rings (Z/qZ)[X] / <X-zeta^i> (where zeta is some fixed
  49. * primitive 256th root of unity mod q), each isomorphic to Z/qZ
  50. * itself. But X^256+1 only splits into 128 _quadratic_ factors, and
  51. * hence we can only decompose R as the direct sum of rings of the
  52. * form (Z/qZ)[X] / <X^2-zeta^j> for odd j, each a quadratic extension
  53. * of Z/qZ, and all mutually nonisomorphic. This means the NTT runs
  54. * one pass fewer than you'd "normally" expect, and also, multiplying
  55. * two elements of R in their NTT representation is not quite as
  56. * trivial as it would normally be - within each component ring of the
  57. * direct sum you have to do the multiplication slightly differently
  58. * depending on the power of zeta in its quotient polynomial.
  59. *
  60. * We take zeta=17 to be the canonical primitive 256th root of unity
  61. * for NTT purposes.
  62. */
  63. /*
  64. * First 128 powers of zeta, reordered by bit-reversing the 7-bit
  65. * index. That is, the nth element of this array contains
  66. * zeta^(bitrev7(n)). Used by the NTT itself.
  67. */
  68. static const uint16_t powers_reversed_order[128] = {
  69. 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786,
  70. 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 1426, 2094,
  71. 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277,
  72. 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452,
  73. 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, 17,
  74. 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156,
  75. 3015, 3050, 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437,
  76. 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
  77. 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143,
  78. 2150, 2775, 886, 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
  79. };
  80. /*
  81. * First 128 _odd_ powers of zeta: the nth element is
  82. * zeta^(2*bitrev7(n)+1). Each of these is used for multiplication in
  83. * one of the 128 quadratic-extension rings in the NTT decomposition.
  84. */
  85. static const uint16_t powers_odd_reversed_order[128] = {
  86. 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 2288,
  87. 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 756, 2573,
  88. 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 2789, 540, 1789,
  89. 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 939, 2390, 2308, 1021,
  90. 2437, 892, 2388, 941, 733, 2596, 2337, 992, 268, 3061, 641, 2688, 1584,
  91. 1745, 2298, 1031, 2037, 1292, 3220, 109, 375, 2954, 2549, 780, 2090, 1239,
  92. 1645, 1684, 1063, 2266, 319, 3010, 2773, 556, 757, 2572, 2099, 1230, 561,
  93. 2768, 2466, 863, 2594, 735, 2804, 525, 1092, 2237, 403, 2926, 1026, 2303,
  94. 1143, 2186, 2150, 1179, 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874,
  95. 1455, 1029, 2300, 2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
  96. };
  97. /*
  98. * Convert a ring element into NTT representation.
  99. *
  100. * The input v is an array of 256 uint16_t, giving the coefficients of
  101. * a polynomial in X, with v[i] being the coefficient of X^i.
  102. *
  103. * v is modified in place. On output, adjacent pairs of elements of v
  104. * give the coefficients of a smaller polynomial in X, with the pair
  105. * v[2i],v[2i+1] being the coefficients of X^0 and X^1 respectively in
  106. * the ring (Z/qZ)[X] / <X^2 - k>, where k = powers_odd_reversed_order[i].
  107. */
  108. static void mlkem_ntt(uint16_t *v)
  109. {
  110. const uint64_t Qrecip = reciprocal_for_reduction(Q);
  111. size_t next_power = 1;
  112. size_t len; // WINSCP
  113. for (len = 128; len >= 2; len /= 2) {
  114. size_t start; // WINSCP
  115. for (start = 0; start < 256; start += 2*len) {
  116. uint16_t mult = powers_reversed_order[next_power++];
  117. size_t j; // WINSCP
  118. for (j = start; j < start + len; j++) {
  119. uint16_t t = reduce(mult * v[j + len], Q, Qrecip);
  120. v[j + len] = reduce(v[j] + Q - t, Q, Qrecip);
  121. v[j] = reduce(v[j] + t, Q, Qrecip);
  122. }
  123. }
  124. }
  125. }
  126. /*
  127. * Convert back from NTT representation. Exactly inverts mlkem_ntt().
  128. */
  129. static void mlkem_inverse_ntt(uint16_t *v)
  130. {
  131. const uint64_t Qrecip = reciprocal_for_reduction(Q);
  132. size_t next_power = 127;
  133. size_t len; // WINSCP
  134. for (len = 2; len <= 128; len *= 2) {
  135. size_t start; // WINSCP
  136. for (start = 0; start < 256; start += 2*len) {
  137. uint16_t mult = powers_reversed_order[next_power--];
  138. size_t j; // WINSCP
  139. for (j = start; j < start + len; j++) {
  140. uint16_t t = v[j];
  141. v[j] = reduce(t + v[j + len], Q, Qrecip);
  142. v[j + len] = reduce(mult * (v[j + len] + Q - t), Q, Qrecip);
  143. }
  144. }
  145. }
  146. { // WINSCP
  147. size_t i; // WINSCP
  148. for (i = 0; i < 256; i++)
  149. v[i] = reduce(v[i] * 3303, Q, Qrecip);
  150. } // WINSCP
  151. }
  152. /*
  153. * Multiply two elements of R in NTT representation.
  154. *
  155. * The output can alias an input completely, but mustn't alias one
  156. * partially.
  157. */
  158. static void mlkem_multiply_ntts(
  159. uint16_t *out, const uint16_t *a, const uint16_t *b)
  160. {
  161. const uint64_t Qrecip = reciprocal_for_reduction(Q);
  162. size_t i; // WINSCP
  163. for (i = 0; i < 128; i++) {
  164. uint16_t a0 = a[2*i], a1 = a[2*i+1];
  165. uint16_t b0 = b[2*i], b1 = b[2*i+1];
  166. uint16_t mult = powers_odd_reversed_order[i];
  167. uint16_t a1b1 = reduce(a1 * b1, Q, Qrecip);
  168. out[2*i] = reduce(a0 * b0 + a1b1 * mult, Q, Qrecip);
  169. out[2*i+1] = reduce(a0 * b1 + a1 * b0, Q, Qrecip);
  170. }
  171. }
  172. /* ----------------------------------------------------------------------
  173. * Operations on matrices over the ring R.
  174. *
  175. * Most of these don't mind whether the matrix contains ring elements
  176. * represented directly as polynomials, or in NTT form. The exception
  177. * is that mlkem_matrix_mul requires it to be in NTT form (because
  178. * multiplying is a huge pain in the ordinary representation).
  179. */
  180. typedef struct mlkem_matrix mlkem_matrix;
  181. struct mlkem_matrix {
  182. unsigned nrows, ncols;
  183. /*
  184. * (nrows * ncols * 256) 16-bit integers. Each 256-word block
  185. * contains an element of R; the blocks are in in row-major order,
  186. * so that (data + 256*(ncols*y + x)) points at the start of the
  187. * element in row y column x.
  188. */
  189. uint16_t *data;
  190. };
  191. /* Storage used for multiple matrices, to free all at once afterwards */
  192. typedef struct mlkem_matrix_storage mlkem_matrix_storage;
  193. struct mlkem_matrix_storage {
  194. uint16_t *data;
  195. size_t n; /* number of ring elements */
  196. };
  197. /*
  198. * Allocate space for multiple matrices. All the arrays of uint16_t
  199. * are allocated as a single big array. This makes it easy to free the
  200. * whole lot in one go afterwards.
  201. *
  202. * It also means that the arrays have a fixed memory relationship to
  203. * each other, which matters not at all during live use, but
  204. * eliminates spurious control-flow divergences in testsc based on
  205. * accidents of memory allocation when vectorised code checks two
  206. * memory regions to see if they alias. (The compiler-generated
  207. * aliasing check must do two comparisons, one for each direction, and
  208. * the order of those two regions in memory affects whether the first
  209. * comparison decides the second one is necessary.)
  210. *
  211. * The variadic arguments for this function consist of a sequence of
  212. * triples (mlkem_matrix *m, int nrows, int ncols), terminated by a
  213. * null matrix pointer.
  214. */
  215. static void mlkem_matrix_alloc(mlkem_matrix_storage *storage, ...)
  216. {
  217. va_list ap;
  218. mlkem_matrix *m;
  219. storage->n = 0;
  220. va_start(ap, storage);
  221. while ((m = va_arg(ap, mlkem_matrix *)) != NULL) {
  222. int nrows = va_arg(ap, int), ncols = va_arg(ap, int);
  223. storage->n += nrows * ncols;
  224. }
  225. va_end(ap);
  226. storage->data = snewn(256 * storage->n, uint16_t);
  227. { // WINSCP
  228. size_t pos = 0;
  229. va_start(ap, storage);
  230. while ((m = va_arg(ap, mlkem_matrix *)) != NULL) {
  231. int nrows = va_arg(ap, int), ncols = va_arg(ap, int);
  232. m->nrows = nrows;
  233. m->ncols = ncols;
  234. m->data = storage->data + 256 * pos;
  235. pos += nrows * ncols;
  236. }
  237. va_end(ap);
  238. } // WINSCP
  239. }
  240. /* Clear and free the storage allocated by mlkem_matrix_alloc. */
  241. static void mlkem_matrix_storage_free(mlkem_matrix_storage *storage)
  242. {
  243. smemclr(storage->data, 256 * storage->n * sizeof(uint16_t));
  244. sfree(storage->data);
  245. }
  246. /* Add two matrices. */
  247. static void mlkem_matrix_add(mlkem_matrix *out, const mlkem_matrix *left,
  248. const mlkem_matrix *right)
  249. {
  250. const uint64_t Qrecip = reciprocal_for_reduction(Q);
  251. assert(out->nrows == left->nrows);
  252. assert(out->ncols == left->ncols);
  253. assert(out->nrows == right->nrows);
  254. assert(out->ncols == right->ncols);
  255. { // WINSCP
  256. size_t i; // WINSCP
  257. for (i = 0; i < out->nrows; i++) {
  258. size_t j; // WINSCP
  259. for (j = 0; j < out->ncols; j++) {
  260. const uint16_t *lv = left->data + 256*(i * left->ncols + j);
  261. const uint16_t *rv = right->data + 256*(i * right->ncols + j);
  262. uint16_t *ov = out->data + 256*(i * out->ncols + j);
  263. size_t p; // WINSCP
  264. for (p = 0; p < 256; p++)
  265. ov[p] = reduce(lv[p] + rv[p] , Q, Qrecip);
  266. }
  267. }
  268. } // WINSCP
  269. }
  270. /* Subtract matrices. */
  271. static void mlkem_matrix_sub(mlkem_matrix *out, const mlkem_matrix *left,
  272. const mlkem_matrix *right)
  273. {
  274. const uint64_t Qrecip = reciprocal_for_reduction(Q);
  275. assert(out->nrows == left->nrows);
  276. assert(out->ncols == left->ncols);
  277. assert(out->nrows == right->nrows);
  278. assert(out->ncols == right->ncols);
  279. { // WINSCP
  280. size_t i; // WINSCP
  281. for (i = 0; i < out->nrows; i++) {
  282. size_t j; // WINSCP
  283. for (j = 0; j < out->ncols; j++) {
  284. const uint16_t *lv = left->data + 256*(i * left->ncols + j);
  285. const uint16_t *rv = right->data + 256*(i * right->ncols + j);
  286. uint16_t *ov = out->data + 256*(i * out->ncols + j);
  287. size_t p; // WINSCP
  288. for (p = 0; p < 256; p++)
  289. ov[p] = reduce(lv[p] + Q - rv[p] , Q, Qrecip);
  290. }
  291. }
  292. } // WINSCP
  293. }
  294. /* Convert every element of a matrix into NTT representation. */
  295. static void mlkem_matrix_ntt(mlkem_matrix *m)
  296. {
  297. size_t i; // WINSCP
  298. for (i = 0; i < m->nrows * m->ncols; i++)
  299. mlkem_ntt(m->data + i * 256);
  300. }
  301. /* Convert every element of a matrix out of NTT representation. */
  302. static void mlkem_matrix_inverse_ntt(mlkem_matrix *m)
  303. {
  304. size_t i; // WINSCP
  305. for (i = 0; i < m->nrows * m->ncols; i++)
  306. mlkem_inverse_ntt(m->data + i * 256);
  307. }
  308. /*
  309. * Multiply two matrices, assuming their elements to be currently in
  310. * NTT representation.
  311. *
  312. * The left input must have the same number of columns as the right
  313. * has rows, in the usual fashion. The output matrix is overwritten.
  314. *
  315. * If 'left_transposed' is true then the left matrix is used as if
  316. * transposed.
  317. */
  318. static void mlkem_matrix_mul(mlkem_matrix *out, const mlkem_matrix *left,
  319. const mlkem_matrix *right, bool left_transposed)
  320. {
  321. const uint64_t Qrecip = reciprocal_for_reduction(Q);
  322. size_t left_nrows = (left_transposed ? left->ncols : left->nrows);
  323. size_t left_ncols = (left_transposed ? left->nrows : left->ncols);
  324. assert(out->nrows == left_nrows);
  325. assert(left_ncols == right->nrows);
  326. assert(right->ncols == out->ncols);
  327. { // WINSCP
  328. uint16_t work[256];
  329. size_t i; // WINSCP
  330. for (i = 0; i < out->nrows; i++) {
  331. size_t j; // WINSCP
  332. for (j = 0; j < out->ncols; j++) {
  333. uint16_t *thisout = out->data + 256 * (i * out->ncols + j);
  334. memset(thisout, 0, 256 * sizeof(uint16_t));
  335. { // WINSCP
  336. size_t k; // WINSCP
  337. for (k = 0; k < right->nrows; k++) {
  338. size_t left_index = left_transposed ?
  339. k * left->ncols + i : i * left->ncols + k;
  340. const uint16_t *lv = left->data + 256*left_index;
  341. const uint16_t *rv = right->data + 256*(k * right->ncols + j);
  342. mlkem_multiply_ntts(work, lv, rv);
  343. { // WINSCP
  344. size_t p; // WINSCP
  345. for (p = 0; p < 256; p++)
  346. thisout[p] = reduce(thisout[p] + work[p], Q, Qrecip);
  347. } // WINSCP
  348. }
  349. } // WINSCP
  350. }
  351. }
  352. smemclr(work, sizeof(work));
  353. } // WINSCP
  354. }
  355. /* ----------------------------------------------------------------------
  356. * Random sampling functions to make up various kinds of randomised
  357. * matrix and vector.
  358. */
  359. static void mlkem_sample_ntt(uint16_t *output, ptrlen seed); /* forward ref */
  360. /*
  361. * Invent a matrix based on a 32-bit random seed rho.
  362. *
  363. * This matrix is logically part of the public (encryption) key: it's
  364. * not transmitted explicitly, but the seed is, so that the receiver
  365. * can reconstruct the same matrix. As a result, this function
  366. * _doesn't_ have to worry about side channel resistance, or even
  367. * leaving data lying around in arrays.
  368. */
  369. static void mlkem_matrix_from_seed(mlkem_matrix *m, const void *rho)
  370. {
  371. unsigned r; // WINSCP
  372. for (r = 0; r < m->nrows; r++) {
  373. unsigned c; // WINSCP
  374. for (c = 0; c < m->ncols; c++) {
  375. unsigned char seedbuf[34];
  376. memcpy(seedbuf, rho, 32);
  377. seedbuf[32] = c;
  378. seedbuf[33] = r;
  379. mlkem_sample_ntt(m->data + 256 * (r * m->nrows + c),
  380. make_ptrlen(seedbuf, sizeof(seedbuf)));
  381. }
  382. }
  383. }
  384. /*
  385. * Invent a single element of the ring R, uniformly at random, derived
  386. * in a specified way from the input random seed.
  387. *
  388. * Used as a subroutine of mlkem_matrix_from_seed() above. So, for the
  389. * same reasons, this doesn't have to worry about side channels,
  390. * making the 'rejection sampling' generation technique easy.
  391. *
  392. * The name SampleNTT (in the official spec) reflects the fact that
  393. * the output elements are regarded as being in NTT representation.
  394. * But since the NTT is a bijection, and the sampling is from the
  395. * uniform probability distribution over R, nothing in this function
  396. * actually needs to worry about that.
  397. */
  398. static void mlkem_sample_ntt(uint16_t *output, ptrlen seed)
  399. {
  400. ShakeXOF *sx = shake128_xof_from_input(seed);
  401. unsigned char bytebuf[4];
  402. bytebuf[3] = '\0';
  403. { // WINSCP
  404. size_t pos; // WINSCP
  405. for (pos = 0; pos < 256 ;) {
  406. /* Read 3 bytes into the low-order end of bytebuf. The fourth
  407. * byte is always 0, so this gives us a random 24-bit integer. */
  408. shake_xof_read(sx, &bytebuf, 3);
  409. { // WINSCP
  410. uint32_t random24 = GET_32BIT_LSB_FIRST(bytebuf);
  411. /*
  412. * Split that integer up into two 12-bit ones, and use each
  413. * one if it's in range (taking care for the second one that
  414. * we didn't just reach the end of the buffer).
  415. *
  416. * This function is only used for generating matrices from an
  417. * element of the public key, so we can use data-dependent
  418. * control flow here without worrying about giving away
  419. * secrets.
  420. */
  421. uint16_t d1 = random24 & 0xFFF;
  422. uint16_t d2 = random24 >> 12;
  423. if (d1 < Q)
  424. output[pos++] = d1;
  425. if (d2 < Q && pos < 256)
  426. output[pos++] = d2;
  427. } // WINSCP
  428. }
  429. shake_xof_free(sx);
  430. } // WINSCP
  431. }
  432. /*
  433. * Invent a random vector, with its elements _not_ in NTT
  434. * representation, and all the coefficients very small integers (a lot
  435. * smaller than q) of one sign or the other.
  436. *
  437. * eta is a parameter of the probability distribution, sigma is an
  438. * input 32-byte random seed. Each element of the vector is made by a
  439. * separate hash operation based on sigma plus a distinguishing
  440. * integer suffix; 'offset' indicates the starting point for those
  441. * suffixes, so that the ith output value has suffix (offset+i).
  442. */
  443. static void mlkem_matrix_poly_cbd(
  444. mlkem_matrix *v, int eta, const void *sigma, int offset)
  445. {
  446. const uint64_t Qrecip = reciprocal_for_reduction(Q);
  447. unsigned char seedbuf[33];
  448. memcpy(seedbuf, sigma, 32);
  449. { // WINSCP
  450. unsigned char *randombuf = snewn(eta * 64, unsigned char);
  451. unsigned r; // WINSCP
  452. for (r = 0; r < v->nrows * v->ncols; r++) {
  453. seedbuf[32] = r + offset;
  454. { // WINSCP
  455. ShakeXOF *sx = shake256_xof_from_input(make_ptrlen(seedbuf, 33));
  456. shake_xof_read(sx, randombuf, eta * 64);
  457. shake_xof_free(sx);
  458. { // WINSCP
  459. size_t i; // WINSCP
  460. for (i = 0; i < 256; i++) {
  461. unsigned x = 0, y = 0;
  462. size_t j; // WINSCP
  463. for (j = 0; j < eta; j++) {
  464. size_t bitpos = 2 * i * eta + j;
  465. x += 1 & ((randombuf[bitpos >> 3]) >> (bitpos & 7));
  466. }
  467. for (j = 0; j < eta; j++) {
  468. size_t bitpos = 2 * i * eta + eta + j;
  469. y += 1 & ((randombuf[bitpos >> 3]) >> (bitpos & 7));
  470. }
  471. v->data[256 * r + i] = reduce(x + Q - y, Q, Qrecip);
  472. }
  473. } // WINSCP
  474. } // WINSCP
  475. }
  476. smemclr(seedbuf, sizeof(seedbuf));
  477. smemclr(randombuf, eta * 64);
  478. sfree(randombuf);
  479. } // WINSCP
  480. }
  481. /* ----------------------------------------------------------------------
  482. * Byte-encoding and decoding functions.
  483. */
  484. /*
  485. * Losslessly encode one or more elements of the ring R.
  486. *
  487. * Each polynomial coefficient, in the range [0,q), is represented as
  488. * a 12-bit integer. So encoding an entire ring element requires
  489. * (256*12)/8 = 384 bytes, and if that 384-byte string were
  490. * interpreted as a little-endian 3072-bit integer D, then the
  491. * coefficient of X^i could be recovered as (D >> (12*i)) & 0xFFF.
  492. *
  493. * The input is expected to be an array of 256*n uint16_t (often the
  494. * 'data' pointer in an mlkem_matrix). The output is 384*n bytes.
  495. */
  496. static void mlkem_byte_encode_lossless(
  497. void *outv, const uint16_t *in, size_t n)
  498. {
  499. unsigned char *out = (unsigned char *)outv;
  500. uint32_t buffer = 0, bufbits = 0;
  501. size_t i; // WINSCP
  502. for (i = 0; i < 256*n; i++) {
  503. buffer |= (uint32_t) in[i] << bufbits;
  504. bufbits += 12;
  505. while (bufbits >= 8) {
  506. *out++ = buffer & 0xFF;
  507. buffer >>= 8;
  508. bufbits -= 8;
  509. }
  510. }
  511. }
  512. /*
  513. * Decode a string written by mlkem_byte_encode_lossless.
  514. *
  515. * Each 12-bit value extracted from the input data is checked to make
  516. * sure it's in the range [0,q); if it's out of range, the whole
  517. * function fails and returns false. (But it need not do so in
  518. * constant time, because that's an "abandon the whole connection"
  519. * error, not a "subtly make things not work for the attacker" error.)
  520. */
  521. static bool mlkem_byte_decode_lossless(
  522. uint16_t *out, const void *inv, size_t n)
  523. {
  524. const unsigned char *in = (const unsigned char *)inv;
  525. uint32_t buffer = 0, bufbits = 0;
  526. size_t i; // WINSCP
  527. for (i = 0; i < 384*n; i++) {
  528. buffer |= (uint32_t) in[i] << bufbits;
  529. bufbits += 8;
  530. while (bufbits >= 12) {
  531. uint16_t value = buffer & 0xFFF;
  532. if (value >= Q)
  533. return false;
  534. *out++ = value;
  535. buffer >>= 12;
  536. bufbits -= 12;
  537. }
  538. }
  539. return true;
  540. }
  541. /*
  542. * Lossily encode one or more elements of R, using d bits for each
  543. * polynomial coefficient, for some d < 12. Each output d-bit value is
  544. * obtained as if by regarding the input coefficient as an integer in
  545. * the range [0,q), multiplying by 2^d/q, and rounding to the nearest
  546. * integer. (Since q is odd, 'round to nearest' can't have a tie.)
  547. *
  548. * This means that a large enough input coefficient can round up to
  549. * 2^d itself. In that situation the output d-bit value is 0.
  550. */
  551. static void mlkem_byte_encode_compressed(
  552. void *outv, const uint16_t *in, unsigned d, size_t n)
  553. {
  554. const uint64_t Qrecip = reciprocal_for_reduction(2*Q);
  555. unsigned char *out = (unsigned char *)outv;
  556. uint32_t buffer = 0, bufbits = 0;
  557. size_t i; // WINSCP
  558. for (i = 0; i < 256*n; i++) {
  559. uint32_t dividend = ((uint32_t)in[i] << (d+1)) + Q;
  560. uint32_t quotient;
  561. reduce_with_quot(dividend, &quotient, 2*Q, Qrecip);
  562. buffer |= (uint32_t) (quotient & ((1 << d) - 1)) << bufbits;
  563. bufbits += d;
  564. while (bufbits >= 8) {
  565. *out++ = buffer & 0xFF;
  566. buffer >>= 8;
  567. bufbits -= 8;
  568. }
  569. }
  570. }
  571. /*
  572. * Decode the lossily encoded output of mlkem_byte_encode_compressed.
  573. *
  574. * Each d-bit chunk of the encoding is converted back into a
  575. * polynomial coefficient as if by multiplying by q/2^d and then
  576. * rounding to nearest. Unlike the rounding in the encode step, this
  577. * _can_ have a tie when an unrounded value is half way between two
  578. * integers. Ties are broken by rounding up (as if the whole rounding
  579. * were performed by the simple rounding method of adding 1/2 and then
  580. * truncating).
  581. *
  582. * Unlike the lossless decode function, this one can't fail input
  583. * validation, because any d-bit value generates some legal
  584. * coefficient.
  585. */
  586. static void mlkem_byte_decode_compressed(
  587. uint16_t *out, const void *inv, unsigned d, size_t n)
  588. {
  589. const unsigned char *in = (const unsigned char *)inv;
  590. uint32_t buffer = 0, bufbits = 0;
  591. size_t i; // WINSCP
  592. for (i = 0; i < 32*d*n; i++) {
  593. buffer |= (uint32_t) in[i] << bufbits;
  594. bufbits += 8;
  595. while (bufbits >= d) {
  596. uint32_t value = buffer & ((1 << d) - 1);
  597. *out++ = (value * (2*Q) + (1 << d)) >> (d + 1);;
  598. buffer >>= d;
  599. bufbits -= d;
  600. }
  601. }
  602. }
  603. /* ----------------------------------------------------------------------
  604. * The top-level ML-KEM functions.
  605. */
  606. /*
  607. * Innermost keygen function, exposed for side-channel testing, with
  608. * separate random values rho (public) and sigma (private), so that
  609. * testsc can vary sigma while leaving rho the same.
  610. */
  611. void mlkem_keygen_rho_sigma(
  612. BinarySink *ek_out, BinarySink *dk_out, const mlkem_params *params,
  613. const void *rho, const void *sigma, const void *z)
  614. {
  615. mlkem_matrix_storage storage[1];
  616. mlkem_matrix a[1], s[1], e[1], t[1];
  617. mlkem_matrix_alloc(storage,
  618. a, params->k, params->k,
  619. s, params->k, 1,
  620. e, params->k, 1,
  621. t, params->k, 1,
  622. (mlkem_matrix *)NULL);
  623. /*
  624. * Make a random k x k matrix A (regarded as in NTT form).
  625. */
  626. mlkem_matrix_from_seed(a, rho);
  627. /*
  628. * Make two column vectors s and e, with all components having
  629. * small polynomial coefficients, and then convert them _into_ NTT
  630. * form.
  631. */
  632. mlkem_matrix_poly_cbd(s, params->eta_1, sigma, 0);
  633. mlkem_matrix_poly_cbd(e, params->eta_1, sigma, params->k);
  634. mlkem_matrix_ntt(s);
  635. mlkem_matrix_ntt(e);
  636. /*
  637. * Compute the vector t = As + e.
  638. */
  639. mlkem_matrix_mul(t, a, s, false);
  640. mlkem_matrix_add(t, t, e);
  641. /*
  642. * The encryption key is the vector t, plus the random seed rho
  643. * from which anyone can reconstruct the matrix A.
  644. */
  645. { // WINSCP
  646. unsigned char ek[1568];
  647. mlkem_byte_encode_lossless(ek, t->data, params->k);
  648. memcpy(ek + 384 * params->k, rho, 32);
  649. { // WINSCP
  650. size_t eklen = 384 * params->k + 32;
  651. put_data(ek_out, ek, eklen);
  652. /*
  653. * The decryption key (for the internal "K-PKE" public-key system)
  654. * is the vector s.
  655. */
  656. { // WINSCP
  657. unsigned char dk[1536];
  658. mlkem_byte_encode_lossless(dk, s->data, params->k);
  659. { // WINSCP
  660. size_t dklen = 384 * params->k;
  661. /*
  662. * The decapsulation key, for the full ML-KEM, consists of
  663. * - the decryption key as above
  664. * - the encryption key
  665. * - an extra hash of the encryption key
  666. * - the random value z used for "implicit rejection", aka
  667. * constructing a useless output value if tampering is
  668. * detected. (I think so an attacker can't tell the difference
  669. * between "I was rumbled" and "I was undetected but my attempt
  670. * didn't generate the right key">)
  671. */
  672. put_data(dk_out, dk, dklen);
  673. put_data(dk_out, ek, eklen);
  674. { // WINSCP
  675. ssh_hash *h = ssh_hash_new(&ssh_sha3_256);
  676. put_data(h, ek, eklen);
  677. { // WINSCP
  678. unsigned char ekhash[32];
  679. ssh_hash_final(h, ekhash);
  680. put_data(dk_out, ekhash, 32);
  681. put_data(dk_out, z, 32);
  682. mlkem_matrix_storage_free(storage);
  683. smemclr(ek, sizeof(ek));
  684. smemclr(ekhash, sizeof(ekhash));
  685. smemclr(dk, sizeof(dk));
  686. } // WINSCP
  687. } // WINSCP
  688. } // WINSCP
  689. } // WINSCP
  690. } // WINSCP
  691. } // WINSCP
  692. }
  693. /*
  694. * Internal keygen function as described in the official spec, taking
  695. * random values d and z and deterministically constructing a key from
  696. * them. The test vectors are expressed in terms of this.
  697. */
  698. void mlkem_keygen_internal(
  699. BinarySink *ek, BinarySink *dk, const mlkem_params *params,
  700. const void *d, const void *z)
  701. {
  702. /* Hash the input randomness d to make two 32-byte values rho and sigma */
  703. unsigned char rho_sigma[64];
  704. ssh_hash *h = ssh_hash_new(&ssh_sha3_512);
  705. put_data(h, d, 32);
  706. put_byte(h, params->k);
  707. ssh_hash_final(h, rho_sigma);
  708. mlkem_keygen_rho_sigma(ek, dk, params, rho_sigma, rho_sigma + 32, z);
  709. smemclr(rho_sigma, sizeof(rho_sigma));
  710. }
  711. /*
  712. * Keygen function for live use, making up the values at random.
  713. */
  714. void mlkem_keygen(
  715. BinarySink *ek, BinarySink *dk, const mlkem_params *params)
  716. {
  717. unsigned char dz[64];
  718. random_read(dz, 64);
  719. mlkem_keygen_internal(ek, dk, params, dz, dz + 32);
  720. smemclr(dz, sizeof(dz));
  721. }
  722. /*
  723. * Internal encapsulation function from the official spec, taking a
  724. * random value m as input and behaving deterministically. Again used
  725. * for test vectors.
  726. */
  727. bool mlkem_encaps_internal(
  728. BinarySink *c_out, BinarySink *k_out,
  729. const mlkem_params *params, ptrlen ek, const void *m)
  730. {
  731. mlkem_matrix_storage storage[1];
  732. mlkem_matrix t[1], a[1], y[1], e1[1], e2[1], mu[1], u[1], v[1];
  733. mlkem_matrix_alloc(storage,
  734. t, params->k, 1,
  735. a, params->k, params->k,
  736. y, params->k, 1,
  737. e1, params->k, 1,
  738. e2, 1, 1,
  739. mu, 1, 1,
  740. u, params->k, 1,
  741. v, 1, 1,
  742. (mlkem_matrix *)NULL);
  743. /*
  744. * Validate input: ek must be the correct length, and its encoded
  745. * ring elements must not include any 16-bit integer intended to
  746. * represent a value mod q which is not in fact in the range [0,q).
  747. *
  748. * We test the latter property by decoding the matrix t, and
  749. * checking the success status returned by the decode.
  750. */
  751. if (ek.len != 384 * params->k + 32 ||
  752. !mlkem_byte_decode_lossless(t->data, ek.ptr, params->k)) {
  753. mlkem_matrix_storage_free(storage);
  754. return false;
  755. }
  756. /*
  757. * Regenerate the same matrix A used by key generation, from the
  758. * seed string rho at the end of ek.
  759. */
  760. mlkem_matrix_from_seed(a, (const unsigned char *)ek.ptr + 384 * params->k);
  761. /*
  762. * Hash the input randomness m, to get the value k we'll use as
  763. * the output shared secret, plus some randomness for making up
  764. * the vectors below.
  765. */
  766. { // WINSCP
  767. unsigned char kr[64];
  768. unsigned char ekhash[32];
  769. ssh_hash *h;
  770. /* Hash the encryption key */
  771. h = ssh_hash_new(&ssh_sha3_256);
  772. put_datapl(h, ek);
  773. ssh_hash_final(h, ekhash);
  774. /* Hash the input randomness m with that hash */
  775. h = ssh_hash_new(&ssh_sha3_512);
  776. put_data(h, m, 32);
  777. put_data(h, ekhash, 32);
  778. ssh_hash_final(h, kr);
  779. { // WINSCP
  780. const unsigned char *k = kr, *r = kr + 32;
  781. /*
  782. * Invent random k-element vectors y and e1, and a random scalar
  783. * e2 (here represented as a 1x1 matrix for the sake of not
  784. * proliferating internal helper functions). All are generated by
  785. * poly_cbd (i.e. their ring elements have polynomial coefficients
  786. * of small magnitude). y needs to be in NTT form.
  787. *
  788. * These generations all use r as their seed, which was the second
  789. * half of the 64-byte hash of the input m. We pass different
  790. * 'offset' values to mlkem_matrix_poly_cbd() to ensure the
  791. * generations are probabilistically independent.
  792. */
  793. mlkem_matrix_poly_cbd(y, params->eta_1, r, 0);
  794. mlkem_matrix_ntt(y);
  795. mlkem_matrix_poly_cbd(e1, params->eta_2, r, params->k);
  796. mlkem_matrix_poly_cbd(e2, params->eta_2, r, 2 * params->k);
  797. /*
  798. * Invent a random scalar mu (again imagined as a 1x1 matrix),
  799. * this time by doing lossy decompression of the random value m at
  800. * 1 bit per polynomial coefficient. That is, all the polynomial
  801. * coefficients of mu are either 0 or 1665 = (q+1)/2.
  802. *
  803. * This generation reuses the _input_ random value m, not either
  804. * half of the hash we made of it.
  805. */
  806. mlkem_byte_decode_compressed(mu->data, m, 1, 1);
  807. /*
  808. * Calculate a k-element vector u = A^T y + e1.
  809. *
  810. * A and y are in NTT representation, but e1 is not, and we don't
  811. * want the output to be in NTT form either. So we perform an
  812. * inverse NTT after the multiplication.
  813. */
  814. mlkem_matrix_mul(u, a, y, true); /* regard a as transposed */
  815. mlkem_matrix_inverse_ntt(u);
  816. mlkem_matrix_add(u, u, e1);
  817. /*
  818. * Calculate a scalar v = t^T y + e2 + mu.
  819. *
  820. * (t and y are column vectors, so t^T y is just a scalar - you
  821. * could think of it as the dot product t.y if you preferred.)
  822. *
  823. * Similarly to above, we multiply t and y which are in NTT
  824. * representation, and then perform an inverse NTT before adding
  825. * e2 and mu, which aren't.
  826. */
  827. mlkem_matrix_mul(v, t, y, true); /* regard t as transposed */
  828. mlkem_matrix_inverse_ntt(v);
  829. mlkem_matrix_add(v, v, e2);
  830. mlkem_matrix_add(v, v, mu);
  831. /*
  832. * The ciphertext consists of u and v, both encoded lossily, with
  833. * different numbers of bits retained per element.
  834. */
  835. { // WINSCP
  836. char c[1568];
  837. mlkem_byte_encode_compressed(c, u->data, params->d_u, params->k);
  838. mlkem_byte_encode_compressed(c + 32 * params->k * params->d_u,
  839. v->data, params->d_v, 1);
  840. put_data(c_out, c, 32 * (params->k * params->d_u + params->d_v));
  841. /*
  842. * The output shared secret is just half of the hash of m (the
  843. * first half, which we didn't use for generating vectors above).
  844. */
  845. put_data(k_out, k, 32);
  846. smemclr(kr, sizeof(kr));
  847. mlkem_matrix_storage_free(storage);
  848. return true;
  849. } // WINSCP
  850. } // WINSCP
  851. } // WINSCP
  852. }
  853. /*
  854. * Encapsulation function for live use, using the real RNG..
  855. */
  856. bool mlkem_encaps(BinarySink *ciphertext, BinarySink *kout,
  857. const mlkem_params *params, ptrlen ek)
  858. {
  859. unsigned char m[32];
  860. random_read(m, 32);
  861. { // WINSCP
  862. bool success = mlkem_encaps_internal(ciphertext, kout, params, ek, m);
  863. smemclr(m, sizeof(m));
  864. return success;
  865. } // WINSCP
  866. }
  867. /*
  868. * Decapsulation.
  869. */
  870. bool mlkem_decaps(BinarySink *k_out, const mlkem_params *params,
  871. ptrlen dk, ptrlen c)
  872. {
  873. /*
  874. * Validation: check the input strings are the right lengths.
  875. */
  876. if (dk.len != 768 * params->k + 96)
  877. return false;
  878. if (c.len != 32 * (params->d_u * params->k + params->d_v))
  879. return false;
  880. /*
  881. * Further validation: extract the encryption key from the middle
  882. * of dk, hash it, and check the hash matches.
  883. */
  884. { // WINSCP
  885. const unsigned char *dkp = (const unsigned char *)dk.ptr;
  886. const unsigned char *cp = (const unsigned char *)c.ptr;
  887. ptrlen ek = make_ptrlen(dkp + 384*params->k, 384*params->k + 32);
  888. ssh_hash *h;
  889. unsigned char ekhash[32];
  890. h = ssh_hash_new(&ssh_sha3_256);
  891. put_datapl(h, ek);
  892. ssh_hash_final(h, ekhash);
  893. if (!smemeq(ekhash, dkp + 768*params->k + 32, 32))
  894. return false;
  895. { // WINSCP
  896. mlkem_matrix_storage storage[1];
  897. mlkem_matrix u[1], v[1], s[1], w[1];
  898. mlkem_matrix_alloc(storage,
  899. u, params->k, 1,
  900. v, 1, 1,
  901. s, params->k, 1,
  902. w, 1, 1,
  903. (mlkem_matrix *)NULL);
  904. /*
  905. * Decode the vector u and the scalar v from the ciphertext. These
  906. * won't come out exactly the same as the originals, because of
  907. * the lossy compression.
  908. */
  909. mlkem_byte_decode_compressed(u->data, cp, params->d_u, params->k);
  910. mlkem_matrix_ntt(u);
  911. mlkem_byte_decode_compressed(v->data, cp + 32 * params->d_u * params->k,
  912. params->d_v, 1);
  913. /*
  914. * Decode the vector s from the private key.
  915. */
  916. mlkem_byte_decode_lossless(s->data, dkp, params->k);
  917. /*
  918. * Calculate the scalar w = v - s^T u.
  919. *
  920. * s and u are in NTT representation, but v isn't, so we
  921. * inverse-NTT the product before doing the subtraction. Therefore
  922. * w is not in NTT form either.
  923. */
  924. mlkem_matrix_mul(w, s, u, true); /* regard s as transposed */
  925. mlkem_matrix_inverse_ntt(w);
  926. mlkem_matrix_sub(w, v, w);
  927. /*
  928. * The aim is that this reconstructs something close enough to the
  929. * random vector mu that was made from the input secret m to
  930. * encapsulation, on the grounds that mu's polynomial coefficients
  931. * were very widely separated (on opposite sides of the cyclic
  932. * additive group of Z/qZ) and the noise added during encryption
  933. * all had _small_ polynomial coefficients.
  934. *
  935. * So we now re-encode this lossily at 1 bit per polynomial
  936. * coefficient, and hope that it reconstructs the actual string m.
  937. *
  938. * However, this _is_ only a hope! The ML-KEM decryption is not a
  939. * true mathematical inverse to encryption. With extreme bad luck,
  940. * the noise can add up enough that it flips a bit of m, and
  941. * everything fails. The parameters are chosen to make this happen
  942. * with negligible probability (the same kind of low probability
  943. * that makes you not worry about spontaneous hash collisions),
  944. * but it's not actually impossible.
  945. */
  946. { // WINSCP
  947. unsigned char m[32];
  948. mlkem_byte_encode_compressed(m, w->data, 1, 1);
  949. /*
  950. * Now do the key _encapsulation_ again from scratch, using that
  951. * secret m as input, and check that it generates the identical
  952. * ciphertext. This should catch the above theoretical failure,
  953. * but also, it's a defence against malicious intervention in the
  954. * key exchange.
  955. *
  956. * This is also where we get the output secret k from: the
  957. * encapsulation function creates it as half of the hash of m.
  958. */
  959. { // WINSCP
  960. unsigned char c_regen[1568], k[32];
  961. buffer_sink c_sink[1], k_sink[1];
  962. buffer_sink_init(c_sink, c_regen, sizeof(c_regen));
  963. buffer_sink_init(k_sink, k, sizeof(k));
  964. { // WINSCP
  965. bool success = mlkem_encaps_internal(
  966. BinarySink_UPCAST(c_sink), BinarySink_UPCAST(k_sink), params, ek, m);
  967. /* If any application of ML-KEM uses a dk given to it by someone
  968. * else, then perhaps they have to worry about being given an
  969. * invalid one? But in our application we always expect this to
  970. * succeed, because dk is generated and used at the same end of
  971. * the SSH connection, within the same process, and nobody is
  972. * interfering with it. */
  973. assert(success && "We generated this dk ourselves, how can it be bad?");
  974. /*
  975. * If mlkem_encaps_internal returned success but delivered the
  976. * wrong ciphertext, that's a failure, but we must be careful not
  977. * to let the attacker know exactly what went wrong. So we
  978. * generate a plausible but wrong substitute output secret.
  979. *
  980. * k_reject is that secret; for constant-time reasons we generate
  981. * it unconditionally.
  982. */
  983. { // WINSCP
  984. unsigned char k_reject[32];
  985. h = ssh_hash_new(&ssh_shake256_32bytes);
  986. put_data(h, dkp + 768 * params->k + 64, 32);
  987. put_datapl(h, c);
  988. ssh_hash_final(h, k_reject);
  989. /*
  990. * Now replace k with k_reject if the ciphertexts didn't match.
  991. */
  992. assert((void *)c_sink->out == (void *)(c_regen + c.len));
  993. { // WINSCP
  994. unsigned match = smemeq(c.ptr, c_regen, c.len);
  995. unsigned mask = match - 1;
  996. size_t i; // WINSCP
  997. for (i = 0; i < 32; i++)
  998. k[i] ^= mask & (k[i] ^ k_reject[i]);
  999. /*
  1000. * And we're done! Free everything and return whichever secret we
  1001. * chose.
  1002. */
  1003. put_data(k_out, k, 32);
  1004. mlkem_matrix_storage_free(storage);
  1005. smemclr(m, sizeof(m));
  1006. smemclr(c_regen, sizeof(c_regen));
  1007. smemclr(k, sizeof(k));
  1008. smemclr(k_reject, sizeof(k_reject));
  1009. return true;
  1010. } // WINSCP
  1011. } // WINSCP
  1012. } // WINSCP
  1013. } // WINSCP
  1014. } // WINSCP
  1015. } // WINSCP
  1016. } // WINSCP
  1017. }
  1018. /* ----------------------------------------------------------------------
  1019. * Implement the pq_kemalg vtable in terms of the above functions.
  1020. */
  1021. struct mlkem_dk {
  1022. strbuf *encoded;
  1023. pq_kem_dk dk;
  1024. };
  1025. static pq_kem_dk *mlkem_vt_keygen(const pq_kemalg *alg, BinarySink *ek)
  1026. {
  1027. struct mlkem_dk *mdk = snew(struct mlkem_dk);
  1028. mdk->dk.vt = alg;
  1029. mdk->encoded = strbuf_new_nm();
  1030. mlkem_keygen(ek, BinarySink_UPCAST(mdk->encoded), alg->extra);
  1031. return &mdk->dk;
  1032. }
  1033. static bool mlkem_vt_encaps(const pq_kemalg *alg, BinarySink *c, BinarySink *k,
  1034. ptrlen ek)
  1035. {
  1036. return mlkem_encaps(c, k, alg->extra, ek);
  1037. }
  1038. static bool mlkem_vt_decaps(pq_kem_dk *dk, BinarySink *k, ptrlen c)
  1039. {
  1040. struct mlkem_dk *mdk = container_of(dk, struct mlkem_dk, dk);
  1041. return mlkem_decaps(k, mdk->dk.vt->extra,
  1042. ptrlen_from_strbuf(mdk->encoded), c);
  1043. }
  1044. static void mlkem_vt_free_dk(pq_kem_dk *dk)
  1045. {
  1046. struct mlkem_dk *mdk = container_of(dk, struct mlkem_dk, dk);
  1047. strbuf_free(mdk->encoded);
  1048. sfree(mdk);
  1049. }
  1050. const pq_kemalg ssh_mlkem512 = {
  1051. /*.keygen =*/ mlkem_vt_keygen,
  1052. /*.encaps =*/ mlkem_vt_encaps,
  1053. /*.decaps =*/ mlkem_vt_decaps,
  1054. /*.free_dk =*/ mlkem_vt_free_dk,
  1055. /*.extra =*/ &mlkem_params_512,
  1056. /*.description =*/ "ML-KEM-512",
  1057. /*.ek_len =*/ 384 * 2 + 32,
  1058. /*.c_len =*/ 32 * (10 * 2 + 4),
  1059. };
  1060. const pq_kemalg ssh_mlkem768 = {
  1061. /*.keygen =*/ mlkem_vt_keygen,
  1062. /*.encaps =*/ mlkem_vt_encaps,
  1063. /*.decaps =*/ mlkem_vt_decaps,
  1064. /*.free_dk =*/ mlkem_vt_free_dk,
  1065. /*.extra =*/ &mlkem_params_768,
  1066. /*.description =*/ "ML-KEM-768",
  1067. /*.ek_len =*/ 384 * 3 + 32,
  1068. /*.c_len =*/ 32 * (10 * 3 + 4),
  1069. };
  1070. const pq_kemalg ssh_mlkem1024 = {
  1071. /*.keygen =*/ mlkem_vt_keygen,
  1072. /*.encaps =*/ mlkem_vt_encaps,
  1073. /*.decaps =*/ mlkem_vt_decaps,
  1074. /*.free_dk =*/ mlkem_vt_free_dk,
  1075. /*.extra =*/ &mlkem_params_1024,
  1076. /*.description =*/ "ML-KEM-1024",
  1077. /*.ek_len =*/ 384 * 4 + 32,
  1078. /*.c_len =*/ 32 * (11 * 4 + 5),
  1079. };