ml_dsa_local.h 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. /*
  2. * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License 2.0 (the "License"). You may not use
  5. * this file except in compliance with the License. You can obtain a copy
  6. * in the file LICENSE in the source distribution or at
  7. * https://www.openssl.org/source/license.html
  8. */
  9. #ifndef OSSL_CRYPTO_ML_DSA_LOCAL_H
  10. # define OSSL_CRYPTO_ML_DSA_LOCAL_H
  11. # include "crypto/ml_dsa.h"
  12. # include "internal/constant_time.h"
  13. # include "internal/packet.h"
  /* The following constants are shared by ML-DSA-44, ML-DSA-65 & ML-DSA-87 */
  # define ML_DSA_Q 8380417 /* The modulus is 23 bits (2^23 - 2^13 + 1) */
  # define ML_DSA_Q_MINUS1_DIV2 ((ML_DSA_Q - 1) / 2)
  # define ML_DSA_Q_BITS 23
  # define ML_DSA_Q_INV 58728449 /* q^-1 satisfies: q^-1 * q = 1 mod 2^32 */
  /*
   * Inverse of -q modulo 2^32, i.e. 2^32 - ML_DSA_Q_INV.
   * Note: this unsuffixed literal does not fit in an int, so it has a wider
   * signed type (long or long long); cast it to uint32_t wherever 32-bit
   * wrap-around arithmetic is intended.
   */
  # define ML_DSA_Q_NEG_INV 4236238847
  /*
   * (256^-1 * 2^64) mod q, which equals 2^56 mod q. A Montgomery
   * multiplication (which divides by 2^32) by this constant therefore scales
   * its operand by 256^-1 * 2^32, i.e. by the inverse of 256 in Montgomery
   * form.
   */
  # define ML_DSA_DEGREE_INV_MONTGOMERY 41978
  # define ML_DSA_D_BITS 13 /* The number of bits dropped from the public vector t */
  # define ML_DSA_NUM_POLY_COEFFICIENTS 256 /* The number of coefficients in the polynomials */
  # define ML_DSA_RHO_BYTES 32 /* rho = Public random seed */
  # define ML_DSA_PRIV_SEED_BYTES 64 /* rho' = Private random seed */
  # define ML_DSA_K_BYTES 32 /* K = Private random seed for signing */
  # define ML_DSA_TR_BYTES 64 /* Size of the hash of the public key used for signing */
  # define ML_DSA_MU_BYTES 64 /* Size of the hash for the message representative */
  # define ML_DSA_RHO_PRIME_BYTES 64 /* Private random seed size */
  /*
   * There is special-case code related to encoding/decoding that tests for
   * the following values.
   */
  /*
   * The possible values for eta - if a new value is added, then all code
   * that accesses ML_DSA_ETA_4 would need to be modified.
   */
  # define ML_DSA_ETA_4 4
  # define ML_DSA_ETA_2 2
  /*
   * The possible values for gamma1 - if a new value is added, then all code
   * that accesses ML_DSA_GAMMA1_TWO_POWER_19 would need to be modified.
   */
  # define ML_DSA_GAMMA1_TWO_POWER_19 (1 << 19)
  # define ML_DSA_GAMMA1_TWO_POWER_17 (1 << 17)
  /*
   * The possible values for gamma2 - if a new value is added, then all code
   * that accesses ML_DSA_GAMMA2_Q_MINUS1_DIV32 would need to be modified.
   */
  # define ML_DSA_GAMMA2_Q_MINUS1_DIV32 ((ML_DSA_Q - 1) / 32) /* = 261888 */
  # define ML_DSA_GAMMA2_Q_MINUS1_DIV88 ((ML_DSA_Q - 1) / 88) /* = 95232 */
  51. typedef struct poly_st POLY;
  52. typedef struct vector_st VECTOR;
  53. typedef struct matrix_st MATRIX;
  54. typedef struct ml_dsa_sig_st ML_DSA_SIG;
  55. int ossl_ml_dsa_matrix_expand_A(EVP_MD_CTX *g_ctx, const EVP_MD *md,
  56. const uint8_t *rho, MATRIX *out);
  57. int ossl_ml_dsa_vector_expand_S(EVP_MD_CTX *h_ctx, const EVP_MD *md, int eta,
  58. const uint8_t *seed, VECTOR *s1, VECTOR *s2);
  59. void ossl_ml_dsa_matrix_mult_vector(const MATRIX *matrix_kl, const VECTOR *vl,
  60. VECTOR *vk);
  61. int ossl_ml_dsa_poly_expand_mask(POLY *out, const uint8_t *seed, size_t seed_len,
  62. uint32_t gamma1,
  63. EVP_MD_CTX *h_ctx, const EVP_MD *md);
  64. int ossl_ml_dsa_poly_sample_in_ball(POLY *out_c, const uint8_t *seed, int seed_len,
  65. EVP_MD_CTX *h_ctx, const EVP_MD *md,
  66. uint32_t tau);
  67. void ossl_ml_dsa_poly_ntt(POLY *s);
  68. void ossl_ml_dsa_poly_ntt_inverse(POLY *s);
  69. void ossl_ml_dsa_poly_ntt_mult(const POLY *lhs, const POLY *rhs, POLY *out);
  70. void ossl_ml_dsa_key_compress_power2_round(uint32_t r, uint32_t *r1, uint32_t *r0);
  71. uint32_t ossl_ml_dsa_key_compress_high_bits(uint32_t r, uint32_t gamma2);
  72. void ossl_ml_dsa_key_compress_decompose(uint32_t r, uint32_t gamma2,
  73. uint32_t *r1, int32_t *r0);
  74. void ossl_ml_dsa_key_compress_decompose(uint32_t r, uint32_t gamma2,
  75. uint32_t *r1, int32_t *r0);
  76. int32_t ossl_ml_dsa_key_compress_low_bits(uint32_t r, uint32_t gamma2);
  77. int32_t ossl_ml_dsa_key_compress_make_hint(uint32_t ct0, uint32_t cs2,
  78. uint32_t gamma2, uint32_t w);
  79. uint32_t ossl_ml_dsa_key_compress_use_hint(uint32_t hint, uint32_t r,
  80. uint32_t gamma2);
  81. int ossl_ml_dsa_pk_encode(ML_DSA_KEY *key);
  82. int ossl_ml_dsa_sk_encode(ML_DSA_KEY *key);
  83. int ossl_ml_dsa_sig_encode(const ML_DSA_SIG *sig, const ML_DSA_PARAMS *params,
  84. uint8_t *out);
  85. int ossl_ml_dsa_sig_decode(ML_DSA_SIG *sig, const uint8_t *in, size_t in_len,
  86. const ML_DSA_PARAMS *params);
  87. int ossl_ml_dsa_w1_encode(const VECTOR *w1, uint32_t gamma2,
  88. uint8_t *out, size_t out_len);
  89. int ossl_ml_dsa_poly_decode_expand_mask(POLY *out,
  90. const uint8_t *in, size_t in_len,
  91. uint32_t gamma1);
  92. /*
  93. * @brief Reduces x mod q in constant time
  94. * i.e. return x < q ? x : x - q;
  95. *
  96. * @param x Where x is assumed to be in the range 0 <= x < 2*q
  97. * @returns the difference in the range 0..q-1
  98. */
  99. static ossl_inline ossl_unused uint32_t reduce_once(uint32_t x)
  100. {
  101. return constant_time_select_32(constant_time_lt_32(x, ML_DSA_Q), x, x - ML_DSA_Q);
  102. }
  103. /*
  104. * @brief Calculate The positive value of (a-b) mod q in constant time.
  105. *
  106. * a - b mod q gives a value in the range -(q-1)..(q-1)
  107. * By adding q we get a range of 1..(2q-1).
  108. * Reducing this once then gives the range 0..q-1
  109. *
  110. * @param a The minuend assumed to be in the range 0..q-1
  111. * @param b The subtracthend assumed to be in the range 0..q-1.
  112. * @returns The value (q + a - b) mod q
  113. */
  114. static ossl_inline ossl_unused uint32_t mod_sub(uint32_t a, uint32_t b)
  115. {
  116. return reduce_once(ML_DSA_Q + a - b);
  117. }
  118. /*
  119. * @brief Returns the absolute value in constant time.
  120. * i.e. return is_positive(x) ? x : -x;
  121. */
  122. static ossl_inline ossl_unused uint32_t abs_signed(uint32_t x)
  123. {
  124. return constant_time_select_32(constant_time_lt_32(x, 0x80000000), x, 0u - x);
  125. }
  126. /*
  127. * @brief Returns the absolute value modulo q in constant time
  128. * i.e return x > (q - 1) / 2 ? q - x : x;
  129. */
  130. static ossl_inline ossl_unused uint32_t abs_mod_prime(uint32_t x)
  131. {
  132. return constant_time_select_32(constant_time_lt_32(ML_DSA_Q_MINUS1_DIV2, x),
  133. ML_DSA_Q - x, x);
  134. }
  135. /*
  136. * @brief Returns the maximum of two values in constant time.
  137. * i.e return x < y ? y : x;
  138. */
  139. static ossl_inline ossl_unused uint32_t maximum(uint32_t x, uint32_t y)
  140. {
  141. return constant_time_select_int(constant_time_lt(x, y), y, x);
  142. }
  143. #endif /* OSSL_CRYPTO_ML_DSA_LOCAL_H */