ml_dsa_vector.h 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. /*
  2. * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License 2.0 (the "License"). You may not use
  5. * this file except in compliance with the License. You can obtain a copy
  6. * in the file LICENSE in the source distribution or at
  7. * https://www.openssl.org/source/license.html
  8. */
  9. #include <assert.h>
  10. #include "ml_dsa_poly.h"
  11. struct vector_st {
  12. POLY *poly;
  13. size_t num_poly;
  14. };
  15. /**
  16. * @brief Initialize a Vector object.
  17. *
  18. * @param v The vector to initialize.
  19. * @param polys Preallocated storage for an array of Polynomials blocks. |v|
  20. * does not own/free this.
  21. * @param num_polys The number of |polys| blocks (k or l)
  22. */
  23. static ossl_inline ossl_unused
  24. void vector_init(VECTOR *v, POLY *polys, size_t num_polys)
  25. {
  26. v->poly = polys;
  27. v->num_poly = num_polys;
  28. }
  29. static ossl_inline ossl_unused
  30. int vector_alloc(VECTOR *v, size_t num_polys)
  31. {
  32. v->poly = OPENSSL_malloc(num_polys * sizeof(POLY));
  33. if (v->poly == NULL)
  34. return 0;
  35. v->num_poly = num_polys;
  36. return 1;
  37. }
  38. static ossl_inline ossl_unused
  39. void vector_free(VECTOR *v)
  40. {
  41. OPENSSL_free(v->poly);
  42. v->poly = NULL;
  43. v->num_poly = 0;
  44. }
  45. /* @brief zeroize a vectors polynomial coefficients */
  46. static ossl_inline ossl_unused
  47. void vector_zero(VECTOR *va)
  48. {
  49. if (va->poly != NULL)
  50. memset(va->poly, 0, va->num_poly * sizeof(va->poly[0]));
  51. }
  52. /*
  53. * @brief copy a vector
  54. * The assumption is that |dst| has already been initialized
  55. */
  56. static ossl_inline ossl_unused void
  57. vector_copy(VECTOR *dst, const VECTOR *src)
  58. {
  59. assert(dst->num_poly == src->num_poly);
  60. memcpy(dst->poly, src->poly, src->num_poly * sizeof(src->poly[0]));
  61. }
  62. /* @brief return 1 if 2 vectors are equal, or 0 otherwise */
  63. static ossl_inline ossl_unused int
  64. vector_equal(const VECTOR *a, const VECTOR *b)
  65. {
  66. size_t i;
  67. if (a->num_poly != b->num_poly)
  68. return 0;
  69. for (i = 0; i < a->num_poly; ++i) {
  70. if (!poly_equal(a->poly + i, b->poly + i))
  71. return 0;
  72. }
  73. return 1;
  74. }
  75. /* @brief add 2 vectors */
  76. static ossl_inline ossl_unused void
  77. vector_add(const VECTOR *lhs, const VECTOR *rhs, VECTOR *out)
  78. {
  79. size_t i;
  80. for (i = 0; i < lhs->num_poly; i++)
  81. poly_add(lhs->poly + i, rhs->poly + i, out->poly + i);
  82. }
  83. /* @brief subtract 2 vectors */
  84. static ossl_inline ossl_unused void
  85. vector_sub(const VECTOR *lhs, const VECTOR *rhs, VECTOR *out)
  86. {
  87. size_t i;
  88. for (i = 0; i < lhs->num_poly; i++)
  89. poly_sub(lhs->poly + i, rhs->poly + i, out->poly + i);
  90. }
  91. /* @brief convert a vector in place into NTT form */
  92. static ossl_inline ossl_unused void
  93. vector_ntt(VECTOR *va)
  94. {
  95. size_t i;
  96. for (i = 0; i < va->num_poly; i++)
  97. ossl_ml_dsa_poly_ntt(va->poly + i);
  98. }
  99. /* @brief convert a vector in place into inverse NTT form */
  100. static ossl_inline ossl_unused void
  101. vector_ntt_inverse(VECTOR *va)
  102. {
  103. size_t i;
  104. for (i = 0; i < va->num_poly; i++)
  105. ossl_ml_dsa_poly_ntt_inverse(va->poly + i);
  106. }
  107. /* @brief multiply a vector by a SCALAR polynomial */
  108. static ossl_inline ossl_unused void
  109. vector_mult_scalar(const VECTOR *lhs, const POLY *rhs, VECTOR *out)
  110. {
  111. size_t i;
  112. for (i = 0; i < lhs->num_poly; i++)
  113. ossl_ml_dsa_poly_ntt_mult(lhs->poly + i, rhs, out->poly + i);
  114. }
  115. static ossl_inline ossl_unused int
  116. vector_expand_S(EVP_MD_CTX *h_ctx, const EVP_MD *md, int eta,
  117. const uint8_t *seed, VECTOR *s1, VECTOR *s2)
  118. {
  119. return ossl_ml_dsa_vector_expand_S(h_ctx, md, eta, seed, s1, s2);
  120. }
  121. static ossl_inline ossl_unused void
  122. vector_expand_mask(VECTOR *out, const uint8_t *rho_prime, size_t rho_prime_len,
  123. uint32_t kappa, uint32_t gamma1,
  124. EVP_MD_CTX *h_ctx, const EVP_MD *md)
  125. {
  126. size_t i;
  127. uint8_t derived_seed[ML_DSA_RHO_PRIME_BYTES + 2];
  128. memcpy(derived_seed, rho_prime, ML_DSA_RHO_PRIME_BYTES);
  129. for (i = 0; i < out->num_poly; i++) {
  130. size_t index = kappa + i;
  131. derived_seed[ML_DSA_RHO_PRIME_BYTES] = index & 0xFF;
  132. derived_seed[ML_DSA_RHO_PRIME_BYTES + 1] = (index >> 8) & 0xFF;
  133. poly_expand_mask(out->poly + i, derived_seed, sizeof(derived_seed),
  134. gamma1, h_ctx, md);
  135. }
  136. }
  137. /* Scale back previously rounded value */
  138. static ossl_inline ossl_unused void
  139. vector_scale_power2_round_ntt(const VECTOR *in, VECTOR *out)
  140. {
  141. size_t i;
  142. for (i = 0; i < in->num_poly; i++)
  143. poly_scale_power2_round(in->poly + i, out->poly + i);
  144. vector_ntt(out);
  145. }
  146. /*
  147. * @brief Decompose all polynomial coefficients of a vector into (t1, t0) such
  148. * that coeff[i] == t1[i] * 2^13 + t0[i] mod q.
  149. * See FIPS 204, Algorithm 35, Power2Round()
  150. */
  151. static ossl_inline ossl_unused void
  152. vector_power2_round(const VECTOR *t, VECTOR *t1, VECTOR *t0)
  153. {
  154. size_t i;
  155. for (i = 0; i < t->num_poly; i++)
  156. poly_power2_round(t->poly + i, t1->poly + i, t0->poly + i);
  157. }
  158. static ossl_inline ossl_unused void
  159. vector_high_bits(const VECTOR *in, uint32_t gamma2, VECTOR *out)
  160. {
  161. size_t i;
  162. for (i = 0; i < out->num_poly; i++)
  163. poly_high_bits(in->poly + i, gamma2, out->poly + i);
  164. }
  165. static ossl_inline ossl_unused void
  166. vector_low_bits(const VECTOR *in, uint32_t gamma2, VECTOR *out)
  167. {
  168. size_t i;
  169. for (i = 0; i < out->num_poly; i++)
  170. poly_low_bits(in->poly + i, gamma2, out->poly + i);
  171. }
  172. static ossl_inline ossl_unused uint32_t
  173. vector_max(const VECTOR *v)
  174. {
  175. size_t i;
  176. uint32_t mx = 0;
  177. for (i = 0; i < v->num_poly; i++)
  178. poly_max(v->poly + i, &mx);
  179. return mx;
  180. }
  181. static ossl_inline ossl_unused uint32_t
  182. vector_max_signed(const VECTOR *v)
  183. {
  184. size_t i;
  185. uint32_t mx = 0;
  186. for (i = 0; i < v->num_poly; i++)
  187. poly_max_signed(v->poly + i, &mx);
  188. return mx;
  189. }
  190. static ossl_inline ossl_unused size_t
  191. vector_count_ones(const VECTOR *v)
  192. {
  193. int j;
  194. size_t i, count = 0;
  195. for (i = 0; i < v->num_poly; i++)
  196. for (j = 0; j < ML_DSA_NUM_POLY_COEFFICIENTS; j++)
  197. count += v->poly[i].coeff[j];
  198. return count;
  199. }
  200. static ossl_inline ossl_unused void
  201. vector_make_hint(const VECTOR *ct0, const VECTOR *cs2, const VECTOR *w,
  202. uint32_t gamma2, VECTOR *out)
  203. {
  204. size_t i;
  205. for (i = 0; i < out->num_poly; i++)
  206. poly_make_hint(ct0->poly + i, cs2->poly + i, w->poly + i, gamma2,
  207. out->poly + i);
  208. }
  209. static ossl_inline ossl_unused void
  210. vector_use_hint(const VECTOR *h, const VECTOR *r, uint32_t gamma2, VECTOR *out)
  211. {
  212. size_t i;
  213. for (i = 0; i < out->num_poly; i++)
  214. poly_use_hint(h->poly + i, r->poly + i, gamma2, out->poly + i);
  215. }