/* ml_dsa_poly.h */
  1. /*
  2. * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License 2.0 (the "License"). You may not use
  5. * this file except in compliance with the License. You can obtain a copy
  6. * in the file LICENSE in the source distribution or at
  7. * https://www.openssl.org/source/license.html
  8. */
  9. #include <openssl/crypto.h>
/* Number of coefficients held by every polynomial in this file. */
#define ML_DSA_NUM_POLY_COEFFICIENTS 256
/*
 * Polynomial object with ML_DSA_NUM_POLY_COEFFICIENTS (256) coefficients,
 * each stored as an unsigned 32-bit integer.
 * NOTE(review): the POLY type used by the helpers is presumably a typedef of
 * struct poly_st declared elsewhere; the coefficient range (documented on the
 * helpers as 0..q-1) is a caller convention and is not enforced by this type.
 */
struct poly_st {
uint32_t coeff[ML_DSA_NUM_POLY_COEFFICIENTS];
};
  15. static ossl_inline ossl_unused void
  16. poly_zero(POLY *p)
  17. {
  18. memset(p->coeff, 0, sizeof(*p));
  19. }
  20. /**
  21. * @brief Polynomial addition.
  22. *
  23. * @param lhs A polynomial with coefficients in the range (0..q-1)
  24. * @param rhs A polynomial with coefficients in the range (0..q-1) to add
  25. * to the 'lhs'.
  26. * @param out The returned addition result with the coefficients all in the
  27. * range 0..q-1
  28. */
  29. static ossl_inline ossl_unused void
  30. poly_add(const POLY *lhs, const POLY *rhs, POLY *out)
  31. {
  32. int i;
  33. for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
  34. out->coeff[i] = reduce_once(lhs->coeff[i] + rhs->coeff[i]);
  35. }
  36. /**
  37. * @brief Polynomial subtraction.
  38. *
  39. * @param lhs A polynomial with coefficients in the range (0..q-1)
  40. * @param rhs A polynomial with coefficients in the range (0..q-1) to subtract
  41. * from the 'lhs'.
  42. * @param out The returned subtraction result with the coefficients all in the
  43. * range 0..q-1
  44. */
  45. static ossl_inline ossl_unused void
  46. poly_sub(const POLY *lhs, const POLY *rhs, POLY *out)
  47. {
  48. int i;
  49. for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
  50. out->coeff[i] = mod_sub(lhs->coeff[i], rhs->coeff[i]);
  51. }
  52. /* @returns 1 if the polynomials are equal, or 0 otherwise */
  53. static ossl_inline ossl_unused int
  54. poly_equal(const POLY *a, const POLY *b)
  55. {
  56. return CRYPTO_memcmp(a, b, sizeof(*a)) == 0;
  57. }
/*
 * Thin wrapper that forwards |p| to ossl_ml_dsa_poly_ntt(). The callee only
 * receives |p| and returns nothing, so its result is written back into |p|.
 * NOTE(review): presumably the forward number-theoretic transform; the
 * implementation lives outside this header — confirm there.
 */
static ossl_inline ossl_unused void
poly_ntt(POLY *p)
{
ossl_ml_dsa_poly_ntt(p);
}
  63. static ossl_inline ossl_unused int
  64. poly_sample_in_ball_ntt(POLY *out, const uint8_t *seed, int seed_len,
  65. EVP_MD_CTX *h_ctx, const EVP_MD *md, uint32_t tau)
  66. {
  67. if (!ossl_ml_dsa_poly_sample_in_ball(out, seed, seed_len, h_ctx, md, tau))
  68. return 0;
  69. poly_ntt(out);
  70. return 1;
  71. }
  72. static ossl_inline ossl_unused int
  73. poly_expand_mask(POLY *out, const uint8_t *seed, size_t seed_len,
  74. uint32_t gamma1, EVP_MD_CTX *h_ctx, const EVP_MD *md)
  75. {
  76. return ossl_ml_dsa_poly_expand_mask(out, seed, seed_len, gamma1, h_ctx, md);
  77. }
  78. /**
  79. * @brief Decompose the coefficients of a polynomial into (r1, r0) such that
  80. * coeff[i] == t1[i] * 2^13 + t0[i] mod q
  81. * See FIPS 204, Algorithm 35, Power2Round()
  82. *
  83. * @param t A polynomial containing coefficients in the range 0..q-1
  84. * @param t1 The returned polynomial containing coefficients that represent
  85. * the top 10 MSB of each coefficient in t (i.e each ranging from 0..1023)
  86. * @param t0 The remainder coefficients of t in the range (0..4096 or q-4095..q-1)
  87. * Each t0 coefficient has an effective range of 8192 (i.e. 13 bits).
  88. */
  89. static ossl_inline ossl_unused void
  90. poly_power2_round(const POLY *t, POLY *t1, POLY *t0)
  91. {
  92. int i;
  93. for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
  94. ossl_ml_dsa_key_compress_power2_round(t->coeff[i],
  95. t1->coeff + i, t0->coeff + i);
  96. }
  97. static ossl_inline ossl_unused void
  98. poly_scale_power2_round(POLY *in, POLY *out)
  99. {
  100. int i;
  101. for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
  102. out->coeff[i] = (in->coeff[i] << ML_DSA_D_BITS);
  103. }
  104. static ossl_inline ossl_unused void
  105. poly_high_bits(const POLY *in, uint32_t gamma2, POLY *out)
  106. {
  107. int i;
  108. for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
  109. out->coeff[i] = ossl_ml_dsa_key_compress_high_bits(in->coeff[i], gamma2);
  110. }
  111. static ossl_inline ossl_unused void
  112. poly_low_bits(const POLY *in, uint32_t gamma2, POLY *out)
  113. {
  114. int i;
  115. for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
  116. out->coeff[i] = ossl_ml_dsa_key_compress_low_bits(in->coeff[i], gamma2);
  117. }
  118. static ossl_inline ossl_unused void
  119. poly_make_hint(const POLY *ct0, const POLY *cs2, const POLY *w, uint32_t gamma2,
  120. POLY *out)
  121. {
  122. int i;
  123. for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
  124. out->coeff[i] = ossl_ml_dsa_key_compress_make_hint(ct0->coeff[i],
  125. cs2->coeff[i],
  126. gamma2, w->coeff[i]);
  127. }
  128. static ossl_inline ossl_unused void
  129. poly_use_hint(const POLY *h, const POLY *r, uint32_t gamma2, POLY *out)
  130. {
  131. int i;
  132. for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++)
  133. out->coeff[i] = ossl_ml_dsa_key_compress_use_hint(h->coeff[i],
  134. r->coeff[i], gamma2);
  135. }
  136. static ossl_inline ossl_unused void
  137. poly_max(const POLY *p, uint32_t *mx)
  138. {
  139. int i;
  140. for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++) {
  141. uint32_t c = p->coeff[i];
  142. uint32_t abs = abs_mod_prime(c);
  143. *mx = maximum(*mx, abs);
  144. }
  145. }
  146. static ossl_inline ossl_unused void
  147. poly_max_signed(const POLY *p, uint32_t *mx)
  148. {
  149. int i;
  150. for (i = 0; i < ML_DSA_NUM_POLY_COEFFICIENTS; i++) {
  151. uint32_t c = p->coeff[i];
  152. uint32_t abs = abs_signed(c);
  153. *mx = maximum(*mx, abs);
  154. }
  155. }