ml_dsa_sample.c 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. /*
  2. * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License 2.0 (the "License"). You may not use
  5. * this file except in compliance with the License. You can obtain a copy
  6. * in the file LICENSE in the source distribution or at
  7. * https://www.openssl.org/source/license.html
  8. */
  9. #include <openssl/byteorder.h>
  10. #include "ml_dsa_local.h"
  11. #include "ml_dsa_vector.h"
  12. #include "ml_dsa_matrix.h"
  13. #include "ml_dsa_hash.h"
  14. #include "internal/sha3.h"
  15. #include "internal/packet.h"
  16. #define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
  17. #define SHAKE256_BLOCKSIZE SHA3_BLOCKSIZE(256)
  /*
   * Constant time replacement for n % 5 (no division, no branches).
   *
   * The multiplier 0x3335 equals (0xFFFF / 5) + 2. Divided by 2^16 it slightly
   * over-estimates 1/5, so ((n * 0x3335) >> 16) equals n / 5 for the nibble
   * values (0..14) this file uses. Subtracting 5 times that quotient leaves
   * the remainder. Note that the argument is evaluated twice.
   */
  #define MOD5(n) ((n) - 5 * (((n) * ((0xFFFF / 5) + 2)) >> 16))
  /* rej_ntt_poly() consumes whole SHAKE128 blocks in 3 byte groups */
  #if (SHAKE128_BLOCKSIZE % 3) != 0
  # error "rej_ntt_poly() requires SHAKE128_BLOCKSIZE to be a multiple of 3"
  #endif

  /*
   * Signature shared by the eta specific CoeffFromHalfByte() variants.
   * Each one either stores a coefficient derived from |nibble| in |out| and
   * returns 1, or rejects the nibble and returns 0.
   */
  typedef int COEFF_FROM_NIBBLE_FUNC(uint32_t nibble, uint32_t *out);

  static COEFF_FROM_NIBBLE_FUNC coeff_from_nibble_4;
  static COEFF_FROM_NIBBLE_FUNC coeff_from_nibble_2;
  30. /**
  31. * @brief Combine 3 bytes to form an coefficient.
  32. * See FIPS 204, Algorithm 14, CoeffFromThreeBytes()
  33. *
  34. * This is not constant time as it is used to generate the matrix A which is public.
  35. *
  36. * @param s A byte array of 3 uniformly distributed bytes.
  37. * @param out The returned coefficient in the range 0..q-1.
  38. * @returns 1 if the value is less than q or 0 otherwise.
  39. * This is used for rejection sampling.
  40. */
  41. static ossl_inline int coeff_from_three_bytes(const uint8_t *s, uint32_t *out)
  42. {
  43. /* Zero out the top bit of the 3rd byte to get a value in the range 0..2^23-1) */
  44. *out = (uint32_t)s[0] | ((uint32_t)s[1] << 8) | (((uint32_t)s[2] & 0x7f) << 16);
  45. return *out < ML_DSA_Q;
  46. }
  47. /**
  48. * @brief Generate a value in the range (q-4..0..4)
  49. * See FIPS 204, Algorithm 15, CoeffFromHalfByte() where eta = 4
  50. * Note the FIPS 204 code uses the range -4..4 (whereas this code adds q to the
  51. * negative numbers).
  52. *
  53. * @param nibble A value in the range 0..15
  54. * @param out The returned value if the range (q-4)..0..4 if nibble is < 9
  55. * @returns 1 nibble was in range, or 0 if the nibble was rejected.
  56. */
  57. static ossl_inline int coeff_from_nibble_4(uint32_t nibble, uint32_t *out)
  58. {
  59. /*
  60. * This is not constant time but will not leak any important info since
  61. * the value is either chosen or thrown away.
  62. */
  63. if (value_barrier_32(nibble < 9)) {
  64. *out = mod_sub(4, nibble);
  65. return 1;
  66. }
  67. return 0;
  68. }
  69. /**
  70. * @brief Generate a value in the range (q-2..0..2)
  71. * See FIPS 204, Algorithm 15, CoeffFromHalfByte() where eta = 2
  72. * Note the FIPS 204 code uses the range -2..2 (whereas this code adds q to the
  73. * negative numbers).
  74. *
  75. * @param nibble A value in the range 0..15
  76. * @param out The returned value if the range (q-2)..0..2 if nibble is < 15
  77. * @returns 1 nibble was in range, or 0 if the nibble was rejected.
  78. */
  79. static ossl_inline int coeff_from_nibble_2(uint32_t nibble, uint32_t *out)
  80. {
  81. if (value_barrier_32(nibble < 15)) {
  82. *out = mod_sub(2, MOD5(nibble));
  83. return 1;
  84. }
  85. return 0;
  86. }
  87. /**
  88. * @brief Use a seed value to generate a polynomial with coefficients in the
  89. * range of 0..q-1 using rejection sampling.
  90. * SHAKE128 is used to absorb the seed, and then sequences of 3 sample bytes are
  91. * squeezed to try to produce coefficients.
  92. * The SHAKE128 stream is used to get uniformly distributed elements.
  93. * This algorithm is used for matrix expansion and only operates on public inputs.
  94. *
  95. * See FIPS 204, Algorithm 30, RejNTTPoly()
  96. *
  97. * @param g_ctx A EVP_MD_CTX object used for sampling the seed.
  98. * @param md A pre-fetched SHAKE128 object.
  99. * @param seed The seed to use for sampling.
  100. * @param seed_len The size of |seed|
  101. * @param out The returned polynomial with coefficients in the range of
  102. * 0..q-1. This range is required for NTT.
  103. * @returns 1 if the polynomial was successfully generated, or 0 if any of the
  104. * digest operations failed.
  105. */
  106. static int rej_ntt_poly(EVP_MD_CTX *g_ctx, const EVP_MD *md,
  107. const uint8_t *seed, size_t seed_len, POLY *out)
  108. {
  109. int j = 0;
  110. uint8_t blocks[SHAKE128_BLOCKSIZE], *b, *end = blocks + sizeof(blocks);
  111. /*
  112. * Instead of just squeezing 3 bytes at a time, we grab a whole block
  113. * Note that the shake128 blocksize of 168 is divisible by 3.
  114. */
  115. if (!shake_xof(g_ctx, md, seed, seed_len, blocks, sizeof(blocks)))
  116. return 0;
  117. while (1) {
  118. for (b = blocks; b < end; b += 3) {
  119. if (coeff_from_three_bytes(b, &(out->coeff[j]))) {
  120. if (++j >= ML_DSA_NUM_POLY_COEFFICIENTS)
  121. return 1; /* finished */
  122. }
  123. }
  124. if (!EVP_DigestSqueeze(g_ctx, blocks, sizeof(blocks)))
  125. return 0;
  126. }
  127. }
  128. /**
  129. * @brief Use a seed value to generate a polynomial with coefficients in the
  130. * range of ((q-eta)..0..eta) using rejection sampling. eta is either 2 or 4.
  131. * SHAKE256 is used to absorb the seed, and then samples are squeezed.
  132. * See FIPS 204, Algorithm 31, RejBoundedPoly()
  133. *
  134. * @param h_ctx A EVP_MD_CTX object context used to sample the seed.
  135. * @param md A pre-fetched SHAKE256 object.
  136. * @param coef_from_nibble A function that is dependent on eta, which takes a
  137. * nibble and tries to see if it is in the correct range.
  138. * @param seed The seed to use for sampling.
  139. * @param seed_len The size of |seed|
  140. * @param out The returned polynomial with coefficients in the range of
  141. * ((q-eta)..0..eta)
  142. * @returns 1 if the polynomial was successfully generated, or 0 if any of the
  143. * digest operations failed.
  144. */
  145. static int rej_bounded_poly(EVP_MD_CTX *h_ctx, const EVP_MD *md,
  146. COEFF_FROM_NIBBLE_FUNC *coef_from_nibble,
  147. const uint8_t *seed, size_t seed_len, POLY *out)
  148. {
  149. int j = 0;
  150. uint32_t z0, z1;
  151. uint8_t blocks[SHAKE256_BLOCKSIZE], *b, *end = blocks + sizeof(blocks);
  152. /* Instead of just squeezing 1 byte at a time, we grab a whole block */
  153. if (!shake_xof(h_ctx, md, seed, seed_len, blocks, sizeof(blocks)))
  154. return 0;
  155. while (1) {
  156. for (b = blocks; b < end; b++) {
  157. z0 = *b & 0x0F; /* lower nibble of byte */
  158. z1 = *b >> 4; /* high nibble of byte */
  159. if (coef_from_nibble(z0, &out->coeff[j])
  160. && ++j >= ML_DSA_NUM_POLY_COEFFICIENTS)
  161. return 1;
  162. if (coef_from_nibble(z1, &out->coeff[j])
  163. && ++j >= ML_DSA_NUM_POLY_COEFFICIENTS)
  164. return 1;
  165. }
  166. if (!EVP_DigestSqueeze(h_ctx, blocks, sizeof(blocks)))
  167. return 0;
  168. }
  169. }
  170. /**
  171. * @brief Generate a k * l matrix that has uniformly distributed polynomial
  172. * elements using rejection sampling.
  173. * See FIPS 204, Algorithm 32, ExpandA()
  174. *
  175. * @param g_ctx A EVP_MD_CTX context used for rejection sampling
  176. * seed values generated from the seed rho.
  177. * @param md A pre-fetched SHAKE128 object
  178. * @param rho A 32 byte seed to generated the matrix from.
  179. * @param out The generated k * l matrix of polynomials with coefficients
  180. * in the range of 0..q-1.
  181. * @returns 1 if the matrix was generated, or 0 on error.
  182. */
  183. int ossl_ml_dsa_matrix_expand_A(EVP_MD_CTX *g_ctx, const EVP_MD *md,
  184. const uint8_t *rho, MATRIX *out)
  185. {
  186. int ret = 0;
  187. size_t i, j;
  188. uint8_t derived_seed[ML_DSA_RHO_BYTES + 2];
  189. POLY *poly = out->m_poly;
  190. /* The seed used for each matrix element is rho + column_index + row_index */
  191. memcpy(derived_seed, rho, ML_DSA_RHO_BYTES);
  192. for (i = 0; i < out->k; i++) {
  193. for (j = 0; j < out->l; j++) {
  194. derived_seed[ML_DSA_RHO_BYTES + 1] = (uint8_t)i;
  195. derived_seed[ML_DSA_RHO_BYTES] = (uint8_t)j;
  196. /* Generate the polynomial for each matrix element using a unique seed */
  197. if (!rej_ntt_poly(g_ctx, md, derived_seed, sizeof(derived_seed), poly++))
  198. goto err;
  199. }
  200. }
  201. ret = 1;
  202. err:
  203. return ret;
  204. }
  205. /**
  206. * @brief Generates 2 vectors using rejection sampling whose polynomial
  207. * coefficients are in the interval [q-eta..0..eta]
  208. *
  209. * See FIPS 204, Algorithm 33, ExpandS().
  210. * Note that in FIPS 204 the range -eta..eta is used.
  211. *
  212. * @param h_ctx A EVP_MD_CTX context to use to sample the seed.
  213. * @param md A pre-fetched SHAKE256 object.
  214. * @param eta Is either 2 or 4, and determines the range of the coefficients for
  215. * s1 and s2.
  216. * @param seed A 64 byte seed to use for sampling.
  217. * @param s1 A 1 * l column vector containing polynomials with coefficients in
  218. * the range (q-eta)..0..eta
  219. * @param s2 A 1 * k column vector containing polynomials with coefficients in
  220. * the range (q-eta)..0..eta
  221. * @returns 1 if s1 and s2 were successfully generated, or 0 otherwise.
  222. */
  223. int ossl_ml_dsa_vector_expand_S(EVP_MD_CTX *h_ctx, const EVP_MD *md, int eta,
  224. const uint8_t *seed, VECTOR *s1, VECTOR *s2)
  225. {
  226. int ret = 0;
  227. size_t i;
  228. size_t l = s1->num_poly;
  229. size_t k = s2->num_poly;
  230. uint8_t derived_seed[ML_DSA_PRIV_SEED_BYTES + 2];
  231. COEFF_FROM_NIBBLE_FUNC *coef_from_nibble_fn;
  232. coef_from_nibble_fn = (eta == ML_DSA_ETA_4) ? coeff_from_nibble_4 : coeff_from_nibble_2;
  233. /*
  234. * Each polynomial generated uses a unique seed that consists of
  235. * seed + counter (where the counter is 2 bytes starting at 0)
  236. */
  237. memcpy(derived_seed, seed, ML_DSA_PRIV_SEED_BYTES);
  238. derived_seed[ML_DSA_PRIV_SEED_BYTES] = 0;
  239. derived_seed[ML_DSA_PRIV_SEED_BYTES + 1] = 0;
  240. for (i = 0; i < l; i++) {
  241. if (!rej_bounded_poly(h_ctx, md, coef_from_nibble_fn,
  242. derived_seed, sizeof(derived_seed), &s1->poly[i]))
  243. goto err;
  244. ++derived_seed[ML_DSA_PRIV_SEED_BYTES];
  245. }
  246. for (i = 0; i < k; i++) {
  247. if (!rej_bounded_poly(h_ctx, md, coef_from_nibble_fn,
  248. derived_seed, sizeof(derived_seed), &s2->poly[i]))
  249. goto err;
  250. ++derived_seed[ML_DSA_PRIV_SEED_BYTES];
  251. }
  252. ret = 1;
  253. err:
  254. return ret;
  255. }
  256. /* See FIPS 204, Algorithm 34, ExpandMask(), Step 4 & 5 */
  257. int ossl_ml_dsa_poly_expand_mask(POLY *out, const uint8_t *seed, size_t seed_len,
  258. uint32_t gamma1,
  259. EVP_MD_CTX *h_ctx, const EVP_MD *md)
  260. {
  261. uint8_t buf[32 * 20];
  262. size_t buf_len = 32 * (gamma1 == ML_DSA_GAMMA1_TWO_POWER_19 ? 20 : 18);
  263. return shake_xof(h_ctx, md, seed, seed_len, buf, buf_len)
  264. && ossl_ml_dsa_poly_decode_expand_mask(out, buf, buf_len, gamma1);
  265. }
  266. /*
  267. * @brief Sample a polynomial with coefficients in the range {-1..1}.
  268. * The number of non zero values (hamming weight) is given by tau
  269. *
  270. * See FIPS 204, Algorithm 29, SampleInBall()
  271. * This function is assumed to not be constant time.
  272. * The algorithm is based on Durstenfeld's version of the Fisher-Yates shuffle.
  273. *
  274. * Note that the coefficients returned by this implementation are positive
  275. * i.e one of q-1, 0, or 1.
  276. *
  277. * @param tau is the number of +1 or -1's in the polynomial 'out_c' (39, 49 or 60)
  278. * that is less than or equal to 64
  279. */
  280. int ossl_ml_dsa_poly_sample_in_ball(POLY *out_c, const uint8_t *seed, int seed_len,
  281. EVP_MD_CTX *h_ctx, const EVP_MD *md,
  282. uint32_t tau)
  283. {
  284. uint8_t block[SHAKE256_BLOCKSIZE];
  285. uint64_t signs;
  286. int offset = 8;
  287. size_t end;
  288. /*
  289. * Rather than squeeze 8 bytes followed by lots of 1 byte squeezes
  290. * the SHAKE blocksize is squeezed each time and buffered into 'block'.
  291. */
  292. if (!shake_xof(h_ctx, md, seed, seed_len, block, sizeof(block)))
  293. return 0;
  294. /*
  295. * grab the first 64 bits - since tau < 64
  296. * Each bit gives a +1 or -1 value.
  297. */
  298. OPENSSL_load_u64_le(&signs, block);
  299. poly_zero(out_c);
  300. /* Loop tau times */
  301. for (end = 256 - tau; end < 256; end++) {
  302. size_t index; /* index is a random offset to write +1 or -1 */
  303. /* rejection sample in {0..end} to choose an index to place -1 or 1 into */
  304. for (;;) {
  305. if (offset == sizeof(block)) {
  306. /* squeeze another block if the bytes from block have been used */
  307. if (!EVP_DigestSqueeze(h_ctx, block, sizeof(block)))
  308. return 0;
  309. offset = 0;
  310. }
  311. index = block[offset++];
  312. if (index <= end)
  313. break;
  314. }
  315. /*
  316. * In-place swap the coefficient we are about to replace to the end so
  317. * we don't lose any values that have been already written.
  318. */
  319. out_c->coeff[end] = out_c->coeff[index];
  320. /* set the random coefficient value to either 1 or q-1 */
  321. out_c->coeff[index] = mod_sub(1, 2 * (signs & 1));
  322. signs >>= 1; /* grab the next random bit */
  323. }
  324. return 1;
  325. }