ml_dsa_sign.c 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. /*
  2. * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License 2.0 (the "License"). You may not use
  5. * this file except in compliance with the License. You can obtain a copy
  6. * in the file LICENSE in the source distribution or at
  7. * https://www.openssl.org/source/license.html
  8. */
  9. #include <openssl/core_dispatch.h>
  10. #include <openssl/core_names.h>
  11. #include <openssl/params.h>
  12. #include <openssl/rand.h>
  13. #include "ml_dsa_local.h"
  14. #include "ml_dsa_key.h"
  15. #include "ml_dsa_matrix.h"
  16. #include "ml_dsa_sign.h"
  17. #include "ml_dsa_hash.h"
  18. #define ML_DSA_MAX_LAMBDA 256 /* bit strength for ML-DSA-87 */
  19. /*
  20. * @brief Initialize a Signature object by pointing all of its objects to
  21. * preallocated blocks. The values passed for hint, z and
  22. * c_tilde values are not owned/freed by the |sig| object.
  23. *
  24. * @param sig The ML_DSA_SIG to initialize.
  25. * @param hint A preallocated array of |k| polynomial blocks
  26. * @param k The number of |hint| polynomials
  27. * @param z A preallocated array of |l| polynomial blocks
  28. * @param l The number of |z| polynomials
  29. * @param c_tilde A preallocated buffer
  30. * @param c_tilde_len The size of |c_tilde|
  31. */
  32. static void signature_init(ML_DSA_SIG *sig,
  33. POLY *hint, uint32_t k, POLY *z, uint32_t l,
  34. uint8_t *c_tilde, size_t c_tilde_len)
  35. {
  36. vector_init(&sig->z, z, l);
  37. vector_init(&sig->hint, hint, k);
  38. sig->c_tilde = c_tilde;
  39. sig->c_tilde_len = c_tilde_len;
  40. }
  41. /*
  42. * FIPS 204, Algorithm 7, ML-DSA.Sign_internal()
  43. * @returns 1 on success and 0 on failure.
  44. */
  45. static int ml_dsa_sign_internal(const ML_DSA_KEY *priv, int msg_is_mu,
  46. const uint8_t *encoded_msg,
  47. size_t encoded_msg_len,
  48. const uint8_t *rnd, size_t rnd_len,
  49. uint8_t *out_sig)
  50. {
  51. int ret = 0;
  52. const ML_DSA_PARAMS *params = priv->params;
  53. EVP_MD_CTX *md_ctx = NULL;
  54. uint32_t k = params->k, l = params->l;
  55. uint32_t gamma1 = params->gamma1, gamma2 = params->gamma2;
  56. uint8_t *alloc = NULL, *w1_encoded;
  57. size_t alloc_len, w1_encoded_len;
  58. size_t num_polys_sig_k = 2 * k;
  59. size_t num_polys_k = 5 * k;
  60. size_t num_polys_l = 3 * l;
  61. size_t num_polys_k_by_l = k * l;
  62. POLY *polys = NULL, *p, *c_ntt;
  63. VECTOR s1_ntt, s2_ntt, t0_ntt, w, w1, cs1, cs2, y;
  64. MATRIX a_ntt;
  65. ML_DSA_SIG sig;
  66. uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
  67. const size_t mu_len = sizeof(mu);
  68. uint8_t rho_prime[ML_DSA_RHO_PRIME_BYTES];
  69. uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
  70. size_t c_tilde_len = params->bit_strength >> 2;
  71. size_t kappa;
  72. /*
  73. * Allocate a single blob for most of the variable size temporary variables.
  74. * Mostly used for VECTOR POLYNOMIALS (every POLY is 1K).
  75. */
  76. w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
  77. alloc_len = w1_encoded_len
  78. + sizeof(*polys) * (1 + num_polys_k + num_polys_l
  79. + num_polys_k_by_l + num_polys_sig_k);
  80. alloc = OPENSSL_malloc(alloc_len);
  81. if (alloc == NULL)
  82. return 0;
  83. md_ctx = EVP_MD_CTX_new();
  84. if (md_ctx == NULL)
  85. goto err;
  86. w1_encoded = alloc;
  87. /* Init the temp vectors to point to the allocated polys blob */
  88. p = (POLY *)(w1_encoded + w1_encoded_len);
  89. c_ntt = p++;
  90. matrix_init(&a_ntt, p, k, l);
  91. p += num_polys_k_by_l;
  92. vector_init(&s2_ntt, p, k);
  93. vector_init(&t0_ntt, s2_ntt.poly + k, k);
  94. vector_init(&w, t0_ntt.poly + k, k);
  95. vector_init(&w1, w.poly + k, k);
  96. vector_init(&cs2, w1.poly + k, k);
  97. p += num_polys_k;
  98. vector_init(&s1_ntt, p, l);
  99. vector_init(&y, p + l, l);
  100. vector_init(&cs1, p + 2 * l, l);
  101. p += num_polys_l;
  102. signature_init(&sig, p, k, p + k, l, c_tilde, c_tilde_len);
  103. /* End of the allocated blob setup */
  104. if (!matrix_expand_A(md_ctx, priv->shake128_md, priv->rho, &a_ntt))
  105. goto err;
  106. if (msg_is_mu) {
  107. if (encoded_msg_len != mu_len)
  108. goto err;
  109. mu_ptr = (uint8_t *)encoded_msg;
  110. } else {
  111. if (!shake_xof_2(md_ctx, priv->shake256_md, priv->tr, sizeof(priv->tr),
  112. encoded_msg, encoded_msg_len, mu_ptr, mu_len))
  113. goto err;
  114. }
  115. if (!shake_xof_3(md_ctx, priv->shake256_md, priv->K, sizeof(priv->K),
  116. rnd, rnd_len, mu_ptr, mu_len,
  117. rho_prime, sizeof(rho_prime)))
  118. goto err;
  119. vector_copy(&s1_ntt, &priv->s1);
  120. vector_ntt(&s1_ntt);
  121. vector_copy(&s2_ntt, &priv->s2);
  122. vector_ntt(&s2_ntt);
  123. vector_copy(&t0_ntt, &priv->t0);
  124. vector_ntt(&t0_ntt);
  125. /*
  126. * kappa must not exceed 2^16. But the probability of it
  127. * exceeding even 1000 iterations is vanishingly small.
  128. */
  129. for (kappa = 0; ; kappa += l) {
  130. VECTOR *y_ntt = &cs1;
  131. VECTOR *r0 = &w1;
  132. VECTOR *ct0 = &w1;
  133. uint32_t z_max, r0_max, ct0_max, h_ones;
  134. vector_expand_mask(&y, rho_prime, sizeof(rho_prime), kappa,
  135. gamma1, md_ctx, priv->shake256_md);
  136. vector_copy(y_ntt, &y);
  137. vector_ntt(y_ntt);
  138. matrix_mult_vector(&a_ntt, y_ntt, &w);
  139. vector_ntt_inverse(&w);
  140. vector_high_bits(&w, gamma2, &w1);
  141. ossl_ml_dsa_w1_encode(&w1, gamma2, w1_encoded, w1_encoded_len);
  142. if (!shake_xof_2(md_ctx, priv->shake256_md, mu_ptr, mu_len,
  143. w1_encoded, w1_encoded_len, c_tilde, c_tilde_len))
  144. break;
  145. if (!poly_sample_in_ball_ntt(c_ntt, c_tilde, c_tilde_len,
  146. md_ctx, priv->shake256_md, params->tau))
  147. break;
  148. vector_mult_scalar(&s1_ntt, c_ntt, &cs1);
  149. vector_ntt_inverse(&cs1);
  150. vector_mult_scalar(&s2_ntt, c_ntt, &cs2);
  151. vector_ntt_inverse(&cs2);
  152. vector_add(&y, &cs1, &sig.z);
  153. /* r0 = lowbits(w - cs2) */
  154. vector_sub(&w, &cs2, r0);
  155. vector_low_bits(r0, gamma2, r0);
  156. /*
  157. * Leaking that the signature is rejected is fine as the next attempt at a
  158. * signature will be (indistinguishable from) independent of this one.
  159. */
  160. z_max = vector_max(&sig.z);
  161. r0_max = vector_max_signed(r0);
  162. if (value_barrier_32(constant_time_ge(z_max, gamma1 - params->beta)
  163. | constant_time_ge(r0_max, gamma2 - params->beta)))
  164. continue;
  165. vector_mult_scalar(&t0_ntt, c_ntt, ct0);
  166. vector_ntt_inverse(ct0);
  167. vector_make_hint(ct0, &cs2, &w, gamma2, &sig.hint);
  168. ct0_max = vector_max(ct0);
  169. h_ones = vector_count_ones(&sig.hint);
  170. /* Same reasoning applies to the leak as above */
  171. if (value_barrier_32(constant_time_ge(ct0_max, gamma2)
  172. | constant_time_lt(params->omega, h_ones)))
  173. continue;
  174. ret = ossl_ml_dsa_sig_encode(&sig, params, out_sig);
  175. break;
  176. }
  177. err:
  178. EVP_MD_CTX_free(md_ctx);
  179. OPENSSL_clear_free(alloc, alloc_len);
  180. OPENSSL_cleanse(rho_prime, sizeof(rho_prime));
  181. return ret;
  182. }
  /*
   * See FIPS 204, Algorithm 8, ML-DSA.Verify_internal().
   *
   * Checks the encoded signature |sig_enc| against |msg_enc| under the public
   * key |pub|.
   *
   * @param pub The public key (supplies params, rho, tr and t1).
   * @param msg_is_mu If non-zero, |msg_enc| is the precomputed message
   *                  representative mu and must be exactly ML_DSA_MU_BYTES
   *                  long. Otherwise mu is derived from tr and |msg_enc|.
   * @param msg_enc The encoded message M' (or mu, see |msg_is_mu|).
   * @param msg_enc_len The size of |msg_enc|.
   * @param sig_enc The encoded signature.
   * @param sig_enc_len The size of |sig_enc|.
   * @returns 1 if the signature is valid. Returns 0 if it is invalid, cannot be
   *          decoded, or an allocation/digest operation fails.
   */
  static int ml_dsa_verify_internal(const ML_DSA_KEY *pub, int msg_is_mu,
                                    const uint8_t *msg_enc, size_t msg_enc_len,
                                    const uint8_t *sig_enc, size_t sig_enc_len)
  {
      int ret = 0;
      uint8_t *alloc = NULL, *w1_encoded;
      POLY *polys = NULL, *p, *c_ntt;
      MATRIX a_ntt;
      VECTOR az_ntt, ct1_ntt, *z_ntt, *w1, *w_approx;
      ML_DSA_SIG sig;
      const ML_DSA_PARAMS *params = pub->params;
      uint32_t k = pub->params->k;
      uint32_t l = pub->params->l;
      uint32_t gamma2 = params->gamma2;
      size_t w1_encoded_len;
      /* The signature holds k hint polynomials followed by l z polynomials */
      size_t num_polys_sig = k + l;
      size_t num_polys_k = 2 * k;
      /* NOTE(review): these l polys are allocated but never used below */
      size_t num_polys_l = 1 * l;
      size_t num_polys_k_by_l = k * l;
      uint8_t mu[ML_DSA_MU_BYTES], *mu_ptr = mu;
      const size_t mu_len = sizeof(mu);
      /* Recomputed challenge hash, compared with the one in the signature */
      uint8_t c_tilde[ML_DSA_MAX_LAMBDA / 4];
      /* Backing storage for the c_tilde decoded from |sig_enc| */
      uint8_t c_tilde_sig[ML_DSA_MAX_LAMBDA / 4];
      EVP_MD_CTX *md_ctx = NULL;
      size_t c_tilde_len = params->bit_strength >> 2;
      uint32_t z_max;
      /* Allocate space for all the POLYNOMIALS used by temporary VECTORS */
      w1_encoded_len = k * (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV88 ? 192 : 128);
      alloc = OPENSSL_malloc(w1_encoded_len
                             + sizeof(*polys) * (1 + num_polys_k
                                                 + num_polys_l
                                                 + num_polys_k_by_l
                                                 + num_polys_sig));
      if (alloc == NULL)
          return 0;
      md_ctx = EVP_MD_CTX_new();
      if (md_ctx == NULL)
          goto err;
      w1_encoded = alloc;
      /* Init the temp vectors to point to the allocated polys blob */
      p = (POLY *)(w1_encoded + w1_encoded_len);
      c_ntt = p++;
      matrix_init(&a_ntt, p, k, l);
      p += num_polys_k_by_l;
      signature_init(&sig, p, k, p + k, l, c_tilde_sig, c_tilde_len);
      p += num_polys_sig;
      vector_init(&az_ntt, p, k);
      vector_init(&ct1_ntt, p + k, k);
      /*
       * A decode failure rejects the signature. Presumably this is also where
       * malformed hints (more than omega ones) are rejected - TODO confirm in
       * ossl_ml_dsa_sig_decode().
       */
      if (!ossl_ml_dsa_sig_decode(&sig, sig_enc, sig_enc_len, pub->params)
          || !matrix_expand_A(md_ctx, pub->shake128_md, pub->rho, &a_ntt))
          goto err;
      if (msg_is_mu) {
          if (msg_enc_len != mu_len)
              goto err;
          mu_ptr = (uint8_t *)msg_enc;
      } else {
          /* mu = H(tr || M') */
          if (!shake_xof_2(md_ctx, pub->shake256_md, pub->tr, sizeof(pub->tr),
                           msg_enc, msg_enc_len, mu_ptr, mu_len))
              goto err;
      }
      /* Compute the verifier's challenge c_ntt = NTT(SampleInBall(c_tilde)) */
      if (!poly_sample_in_ball_ntt(c_ntt, c_tilde_sig, c_tilde_len,
                                   md_ctx, pub->shake256_md, params->tau))
          goto err;
      /* ct1_ntt = NTT(c) * NTT(t1 * 2^d) */
      vector_scale_power2_round_ntt(&pub->t1, &ct1_ntt);
      vector_mult_scalar(&ct1_ntt, c_ntt, &ct1_ntt);
      /*
       * Compute z_max early in order to reuse sig.z: the NTT below transforms
       * sig.z in place, and the norm check needs the untransformed z.
       */
      z_max = vector_max(&sig.z);
      /* w_approx = NTT_inverse(A * NTT(z) - ct1_ntt) */
      z_ntt = &sig.z;
      vector_ntt(z_ntt);
      matrix_mult_vector(&a_ntt, z_ntt, &az_ntt);
      /* The subtraction and inverse NTT run in place over az_ntt */
      w_approx = &az_ntt;
      vector_sub(&az_ntt, &ct1_ntt, w_approx);
      vector_ntt_inverse(w_approx);
      /* w1 = UseHint(h, w_approx), also computed in place, then w1Encode */
      w1 = w_approx;
      vector_use_hint(&sig.hint, w_approx, gamma2, w1);
      ossl_ml_dsa_w1_encode(w1, gamma2, w1_encoded, w1_encoded_len);
      /* c_tilde = H(mu || w1Encode(w1)); the third input is empty */
      if (!shake_xof_3(md_ctx, pub->shake256_md, mu_ptr, mu_len,
                       w1_encoded, w1_encoded_len, NULL, 0, c_tilde, c_tilde_len))
          goto err;
      /*
       * Accept iff ||z||inf < gamma1 - beta and the recomputed challenge hash
       * matches the signature's. All inputs here are public, so a plain
       * memcmp() is used.
       */
      ret = (z_max < (uint32_t)(params->gamma1 - params->beta))
          && memcmp(c_tilde, sig.c_tilde, c_tilde_len) == 0;
  err:
      OPENSSL_free(alloc);
      EVP_MD_CTX_free(md_ctx);
      return ret;
  }
  276. /**
  277. * @brief Encode a message
  278. * See FIPS 204 Algorithm 2 Step 10 (and algorithm 3 Step 5).
  279. *
  280. * ML_DSA pure signatures are encoded as M' = 00 || ctx_len || ctx || msg
  281. * Where ctx is the empty string by default and ctx_len <= 255.
  282. *
  283. * Note this code could be shared with SLH_DSA
  284. *
  285. * @param msg A message to encode
  286. * @param msg_len The size of |msg|
  287. * @param ctx An optional context to add to the message encoding.
  288. * @param ctx_len The size of |ctx|. It must be in the range 0..255
  289. * @param encode Use the Pure signature encoding if this is 1, and dont encode
  290. * if this value is 0.
  291. * @param tmp A small buffer that may be used if the message is small.
  292. * @param tmp_len The size of |tmp|
  293. * @param out_len The size of the returned encoded buffer.
  294. * @returns A buffer containing the encoded message. If the passed in
  295. * |tmp| buffer is big enough to hold the encoded message then it returns |tmp|
  296. * otherwise it allocates memory which must be freed by the caller. If |encode|
  297. * is 0 then it returns |msg|. NULL is returned if there is a failure.
  298. */
  299. static uint8_t *msg_encode(const uint8_t *msg, size_t msg_len,
  300. const uint8_t *ctx, size_t ctx_len, int encode,
  301. uint8_t *tmp, size_t tmp_len, size_t *out_len)
  302. {
  303. uint8_t *encoded = NULL;
  304. size_t encoded_len;
  305. if (encode == 0) {
  306. /* Raw message */
  307. *out_len = msg_len;
  308. return (uint8_t *)msg;
  309. }
  310. if (ctx_len > ML_DSA_MAX_CONTEXT_STRING_LEN)
  311. return NULL;
  312. /* Pure encoding */
  313. encoded_len = 1 + 1 + ctx_len + msg_len;
  314. *out_len = encoded_len;
  315. if (encoded_len <= tmp_len) {
  316. encoded = tmp;
  317. } else {
  318. encoded = OPENSSL_malloc(encoded_len);
  319. if (encoded == NULL)
  320. return NULL;
  321. }
  322. encoded[0] = 0;
  323. encoded[1] = (uint8_t)ctx_len;
  324. memcpy(&encoded[2], ctx, ctx_len);
  325. memcpy(&encoded[2 + ctx_len], msg, msg_len);
  326. return encoded;
  327. }
  328. /**
  329. * See FIPS 204 Section 5.2 Algorithm 2 ML-DSA.Sign()
  330. *
  331. * @returns 1 on success, or 0 on error.
  332. */
  333. int ossl_ml_dsa_sign(const ML_DSA_KEY *priv, int msg_is_mu,
  334. const uint8_t *msg, size_t msg_len,
  335. const uint8_t *context, size_t context_len,
  336. const uint8_t *rand, size_t rand_len, int encode,
  337. unsigned char *sig, size_t *sig_len, size_t sig_size)
  338. {
  339. int ret = 1;
  340. uint8_t m_tmp[1024], *m = m_tmp, *alloced_m = NULL;
  341. size_t m_len = 0;
  342. if (ossl_ml_dsa_key_get_priv(priv) == NULL)
  343. return 0;
  344. if (sig != NULL) {
  345. if (sig_size < priv->params->sig_len)
  346. return 0;
  347. if (msg_is_mu) {
  348. m = (uint8_t *)msg;
  349. m_len = msg_len;
  350. } else {
  351. m = msg_encode(msg, msg_len, context, context_len, encode,
  352. m_tmp, sizeof(m_tmp), &m_len);
  353. if (m == NULL)
  354. return 0;
  355. if (m != msg && m != m_tmp)
  356. alloced_m = m;
  357. }
  358. ret = ml_dsa_sign_internal(priv, msg_is_mu, m, m_len, rand, rand_len, sig);
  359. OPENSSL_free(alloced_m);
  360. }
  361. if (sig_len != NULL)
  362. *sig_len = priv->params->sig_len;
  363. return ret;
  364. }
  365. /**
  366. * See FIPS 203 Section 5.3 Algorithm 3 ML-DSA.Verify()
  367. * @returns 1 on success, or 0 on error.
  368. */
  369. int ossl_ml_dsa_verify(const ML_DSA_KEY *pub, int msg_is_mu,
  370. const uint8_t *msg, size_t msg_len,
  371. const uint8_t *context, size_t context_len, int encode,
  372. const uint8_t *sig, size_t sig_len)
  373. {
  374. uint8_t *m, *alloced_m = NULL;
  375. size_t m_len;
  376. uint8_t m_tmp[1024];
  377. int ret = 0;
  378. if (ossl_ml_dsa_key_get_pub(pub) == NULL)
  379. return 0;
  380. if (msg_is_mu) {
  381. m = (uint8_t *)msg;
  382. m_len = msg_len;
  383. } else {
  384. m = msg_encode(msg, msg_len, context, context_len, encode,
  385. m_tmp, sizeof(m_tmp), &m_len);
  386. if (m == NULL)
  387. return 0;
  388. if (m != msg && m != m_tmp)
  389. alloced_m = m;
  390. }
  391. ret = ml_dsa_verify_internal(pub, msg_is_mu, m, m_len, sig, sig_len);
  392. OPENSSL_free(alloced_m);
  393. return ret;
  394. }