ml_dsa_key_compress.c 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. /*
  2. * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License 2.0 (the "License"). You may not use
  5. * this file except in compliance with the License. You can obtain a copy
  6. * in the file LICENSE in the source distribution or at
  7. * https://www.openssl.org/source/license.html
  8. */
  9. #include "ml_dsa_local.h"
  10. /* Key Compression related functions (Rounding & hints) */
  11. /**
  12. * @brief Decompose r into (r1, r0) such that r == r1 * 2^13 + r0 mod q
  13. * See FIPS 204, Algorithm 35, Power2Round()
  14. *
  15. * Note: that this code is more complex than the FIPS 204 spec since it keeps
  16. * r0 as a positive number
  17. *
  18. * r mod +- 2^13 is defined as having a range of -4095..4096
  19. *
  20. * i.e for r = 0..4096 r1 = 0 and r0 = 0..4096
  21. * at r = 4097..8191 r1 = 1 and r0 = -4095..0
  22. * (but since r0 is kept positive it effectively adds q and then reduces by q if needed)
  23. * Similarly for the range r = 8192..8192+4096 r1=1 and r0=0..4096
  24. * & 12289..16383 r1=2 and r0=-4095..0
  25. *
  26. * @param r is in the range 0..q-1
  27. * @param r1 The returned top 10 MSB (i.e it ranges from 0..1023)
  28. * @param r0 The remainder in the range (0..4096 or q-4095..q-1)
  29. * So r0 has an effective range of 8192 (i.e. 13 bits).
  30. */
  31. void ossl_ml_dsa_key_compress_power2_round(uint32_t r, uint32_t *r1, uint32_t *r0)
  32. {
  33. unsigned int mask;
  34. uint32_t r0_adjusted, r1_adjusted;
  35. *r1 = r >> ML_DSA_D_BITS; /* top 13 bits */
  36. *r0 = r - (*r1 << ML_DSA_D_BITS); /* The remainder mod q */
  37. r0_adjusted = mod_sub(*r0, 1 << ML_DSA_D_BITS);
  38. r1_adjusted = *r1 + 1;
  39. /* Mask is set iff r0 > (2^(dropped_bits))/2. */
  40. mask = constant_time_lt((uint32_t)(1 << (ML_DSA_D_BITS - 1)), *r0);
  41. /* r0 = mask ? r0_adjusted : r0 */
  42. *r0 = constant_time_select_int(mask, r0_adjusted, *r0);
  43. /* r1 = mask ? r1_adjusted : r1 */
  44. *r1 = constant_time_select_int(mask, r1_adjusted, *r1);
  45. }
  46. /*
  47. * @brief return the r1 component of Decomposing r into (r1, r0) such that
  48. * r == r1 * (2 * gamma2) + r0 mod q
  49. * See FIPS 204, Algorithm 37, HighBits()
  50. *
  51. * @param r A value to decompose in the range (0..q-1)
  52. * @param gamma2 Depending on the algorithm gamma2 is either (q-1)/32 or (q-1)/88
  53. * @returns r1 (The high order bits)
  54. */
  55. uint32_t ossl_ml_dsa_key_compress_high_bits(uint32_t r, uint32_t gamma2)
  56. {
  57. int32_t r1 = (r + 127) >> 7;
  58. if (gamma2 == ML_DSA_GAMMA2_Q_MINUS1_DIV32) {
  59. r1 = (r1 * 1025 + (1 << 21)) >> 22;
  60. r1 &= 15; /* mod 16 */
  61. return r1;
  62. } else {
  63. r1 = (r1 * 11275 + (1 << 23)) >> 24;
  64. r1 ^= ((43 - r1) >> 31) & r1;
  65. return r1;
  66. }
  67. }
  68. /**
  69. * @brief Decomposes r into (r1, r0) such that r == r1 * (2*gamma2) + r0 mod q.
  70. * See FIPS 204, Algorithm 36, Decompose()
  71. *
  72. * @param r A value to decompose in the range (0..q-1)
  73. * @param gamma2 Depending on the algorithm gamma2 is either (q-1)/32 or (q-1)/88
  74. * @param r1 The returned high order bits
  75. * @param r0 The returned low order bits
  76. */
  77. void ossl_ml_dsa_key_compress_decompose(uint32_t r, uint32_t gamma2,
  78. uint32_t *r1, int32_t *r0)
  79. {
  80. *r1 = ossl_ml_dsa_key_compress_high_bits(r, gamma2);
  81. *r0 = r - *r1 * 2 * (int32_t)gamma2;
  82. *r0 -= (((int32_t)ML_DSA_Q_MINUS1_DIV2 - *r0) >> 31) & (int32_t)ML_DSA_Q;
  83. }
  /**
   * @brief Return only the r0 component of decomposing r into (r1, r0) such
   * that r == r1 * (2 * gamma2) + r0 mod q.
   * See FIPS 204, Algorithm 38, LowBits()
   *
   * @param r A value to decompose in the range (0..q-1)
   * @param gamma2 Depending on the algorithm gamma2 is either (q-1)/32 or (q-1)/88
   * @returns r0, the signed low order bits exactly as produced by
   *          ossl_ml_dsa_key_compress_decompose()
   */
int32_t ossl_ml_dsa_key_compress_low_bits(uint32_t r, uint32_t gamma2)
{
    uint32_t high_ignored;  /* decompose() always writes both halves */
    int32_t low;

    ossl_ml_dsa_key_compress_decompose(r, gamma2, &high_ignored, &low);
    return low;
}
  /*
   * @brief Computes the hint bit indicating whether adding z to r changes the
   * high bits of r.
   * See FIPS 204, Algorithm 39, MakeHint().
   *
   * The spec's MakeHint(z, r) compares HighBits(r) with HighBits(r + z), and
   * it is called with
   *   z = -ct0
   *   r = w - cs2 + ct0
   * so r + z is simply w - cs2. Taking w, cs2 and ct0 separately lets this
   * function form r + z directly and obtain r by adding ct0 back.
   *
   * @param ct0 A coefficient of c * t0: the polynomial c (coefficients in
   *            {-1, 0, 1}) multiplied by the polynomial vector t0 (the least
   *            significant bits of each coefficient of the uncompressed
   *            public-key polynomial t)
   * @param cs2 A coefficient of c * s2 (s2 is a secret polynomial)
   * @param gamma2 Depending on the algorithm gamma2 is either (q-1)/32 or (q-1)/88
   * @param w A coefficient of A * y
   * @returns 1 if HighBits(r) != HighBits(r + z), otherwise 0.
   */
int32_t ossl_ml_dsa_key_compress_make_hint(uint32_t ct0, uint32_t cs2,
                                           uint32_t gamma2, uint32_t w)
{
    /* r + z from the spec: w - cs2 (mod q) */
    const uint32_t w_minus_cs2 = mod_sub(w, cs2);
    /* r from the spec: (w - cs2) + ct0, reduced with reduce_once() */
    const uint32_t r = reduce_once(w_minus_cs2 + ct0);
    const uint32_t high_r = ossl_ml_dsa_key_compress_high_bits(r, gamma2);
    const uint32_t high_r_plus_z =
        ossl_ml_dsa_key_compress_high_bits(w_minus_cs2, gamma2);

    return high_r != high_r_plus_z;
}
  128. /*
  129. * @brief Returns the high bits of |r| adjusted according to hint |h|.
  130. * FIPS 204, Algorithm 40, UseHint().
  131. * This is not constant time.
  132. *
  133. * @param hint The hint bit which is either 0 or 1
  134. * @param r A value to decompose in the range (0..q-1)
  135. * @param gamma2 Depending on the algorithm gamma2 is either (q-1)/32 or (q-1)/88
  136. *
  137. * @returns The adjusted high bits or r.
  138. */
  139. uint32_t ossl_ml_dsa_key_compress_use_hint(uint32_t hint, uint32_t r,
  140. uint32_t gamma2)
  141. {
  142. uint32_t r1;
  143. int32_t r0;
  144. ossl_ml_dsa_key_compress_decompose(r, gamma2, &r1, &r0);
  145. if (hint == 0)
  146. return r1;
  147. if (gamma2 == ((ML_DSA_Q - 1) / 32)) {
  148. /* m = 16, thus |mod m| in the spec turns into |& 15| */
  149. return r0 > 0 ? (r1 + 1) & 15 : (r1 - 1) & 15;
  150. } else {
  151. /* m = 44 if gamma2 = ((q - 1) / 88) */
  152. if (r0 > 0)
  153. return (r1 == 43) ? 0 : r1 + 1;
  154. else
  155. return (r1 == 0) ? 43 : r1 - 1;
  156. }
  157. }