|  | @@ -0,0 +1,1889 @@
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Implementation of OpenSSH 9.x's hybrid key exchange protocol
 | 
	
		
			
				|  |  | + * [email protected] .
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * This consists of the 'Streamlined NTRU Prime' quantum-resistant
 | 
	
		
			
				|  |  | + * cryptosystem, run in parallel with ordinary Curve25519 to generate
 | 
	
		
			
				|  |  | + * a shared secret combining the output of both systems.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * (Hence, even if you don't trust this newfangled NTRU Prime thing at
 | 
	
		
			
				|  |  | + * all, it's at least no _less_ secure than the kex you were using
 | 
	
		
			
				|  |  | + * already.)
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * References for the NTRU Prime cryptosystem, up to and including
 | 
	
		
			
				|  |  | + * binary encodings of public and private keys and the exact preimages
 | 
	
		
			
				|  |  | + * of the hashes used in key exchange:
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * https://ntruprime.cr.yp.to/
 | 
	
		
			
				|  |  | + * https://ntruprime.cr.yp.to/nist/ntruprime-20201007.pdf
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * The SSH protocol layer is not documented anywhere I could find (as
 | 
	
		
			
				|  |  | + * of 2022-04-15, not even in OpenSSH's PROTOCOL.* files). I had to
 | 
	
		
			
				|  |  | + * read OpenSSH's source code to find out how it worked, and the
 | 
	
		
			
				|  |  | + * answer is as follows:
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * This hybrid kex method is treated for SSH purposes as a form of
 | 
	
		
			
				|  |  | + * elliptic-curve Diffie-Hellman, and shares the same SSH message
 | 
	
		
			
				|  |  | + * sequence: client sends SSH2_MSG_KEX_ECDH_INIT containing its public
 | 
	
		
			
				|  |  | + * half, server responds with SSH2_MSG_KEX_ECDH_REPLY containing _its_
 | 
	
		
			
				|  |  | + * public half plus the host key and signature on the shared secret.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * (This is a bit of a fudge, because unlike actual ECDH, this kex
 | 
	
		
			
				|  |  | + * method is asymmetric: one side sends a public key, and the other
 | 
	
		
			
				|  |  | + * side encrypts something with it and sends the ciphertext back. So
 | 
	
		
			
				|  |  | + * while the normal ECDH implementations can compute the two sides
 | 
	
		
			
				|  |  | + * independently in parallel, this system reusing the same messages
 | 
	
		
			
				|  |  | + * has to be serial. But the order of the messages _is_ firmly
 | 
	
		
			
				|  |  | + * specified in SSH ECDH, so it works anyway.)
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * For this kex method, SSH2_MSG_KEX_ECDH_INIT still contains a single
 | 
	
		
			
				|  |  | + * SSH 'string', which consists of the concatenation of a Streamlined
 | 
	
		
			
				|  |  | + * NTRU Prime public key with the Curve25519 public value. (Both of
 | 
	
		
			
				|  |  | + * these have fixed length in bytes, so there's no ambiguity in the
 | 
	
		
			
				|  |  | + * concatenation.)
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * SSH2_MSG_KEX_ECDH_REPLY is mostly the same as usual. The only
 | 
	
		
			
				|  |  | + * string in the packet that varies is the second one, which would
 | 
	
		
			
				|  |  | + * normally contain the server's public elliptic curve point. Instead,
 | 
	
		
			
				|  |  | + * it now contains the concatenation of
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + *  - a Streamlined NTRU Prime ciphertext
 | 
	
		
			
				|  |  | + *  - the 'confirmation hash' specified in ntruprime-20201007.pdf,
 | 
	
		
			
				|  |  | + *    hashing the plaintext of that ciphertext together with the
 | 
	
		
			
				|  |  | + *    public key
 | 
	
		
			
				|  |  | + *  - the Curve25519 public point as usual.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * Again, all three of those elements have fixed lengths.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * The client decrypts the ciphertext, checks the confirmation hash,
 | 
	
		
			
				|  |  | + * and if successful, generates the 'session hash' specified in
 | 
	
		
			
				|  |  | + * ntruprime-20201007.pdf, which is 32 bytes long and is the ultimate
 | 
	
		
			
				|  |  | + * output of the Streamlined NTRU Prime key exchange.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * The output of the hybrid kex method as a whole is an SSH 'string'
 | 
	
		
			
				|  |  | + * of length 64 containing the SHA-512 hash of the concatenatio of
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + *  - the Streamlined NTRU Prime session hash (32 bytes)
 | 
	
		
			
				|  |  | + *  - the Curve25519 shared secret (32 bytes).
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * That string is included directly into the SSH exchange hash and key
 | 
	
		
			
				|  |  | + * derivation hashes, in place of the mpint that comes out of most
 | 
	
		
			
				|  |  | + * other kex methods.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +#include <stdio.h>
 | 
	
		
			
				|  |  | +#include <stdlib.h>
 | 
	
		
			
				|  |  | +#include <assert.h>
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +#include "putty.h"
 | 
	
		
			
				|  |  | +#include "ssh.h"
 | 
	
		
			
				|  |  | +#include "mpint.h"
 | 
	
		
			
				|  |  | +#include "ntru.h"
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* ----------------------------------------------------------------------
 | 
	
		
			
				|  |  | + * Preliminaries: we're going to need to do modular arithmetic on
 | 
	
		
			
				|  |  | + * small values (considerably smaller than 2^16), and we need to do it
 | 
	
		
			
				|  |  | + * without using integer division which might not be time-safe.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * The strategy for this is the same as I used in
 | 
	
		
			
				|  |  | + * mp_mod_known_integer: see there for the proofs. The basic idea is
 | 
	
		
			
				|  |  | + * that we precompute the reciprocal of our modulus as a fixed-point
 | 
	
		
			
				|  |  | + * number, and use that to get an approximate quotient which we
 | 
	
		
			
				|  |  | + * subtract off. For these integer sizes, precomputing a fixed-point
 | 
	
		
			
				|  |  | + * reciprocal of the form (2^48 / modulus) leaves us at most off by 1
 | 
	
		
			
				|  |  | + * in the quotient, so there's a single (time-safe) trial subtraction
 | 
	
		
			
				|  |  | + * at the end.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * (It's possible that some speed could be gained by not reducing
 | 
	
		
			
				|  |  | + * fully at every step. But then you'd have to carefully identify all
 | 
	
		
			
				|  |  | + * the places in the algorithm where things are compared to zero. This
 | 
	
		
			
				|  |  | + * was the easiest way to get it all working in the first place.)
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* Precompute the reciprocal */
 | 
	
		
			
				|  |  | +static uint64_t reciprocal_for_reduction(uint16_t q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    return ((uint64_t)1 << 48) / q;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* Reduce x mod q, assuming qrecip == reciprocal_for_reduction(q) */
 | 
	
		
			
				|  |  | +static uint16_t reduce(uint32_t x, uint16_t q, uint64_t qrecip)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    uint64_t unshifted_quot = x * qrecip;
 | 
	
		
			
				|  |  | +    uint64_t quot = unshifted_quot >> 48;
 | 
	
		
			
				|  |  | +    uint16_t reduced = x - quot * q;
 | 
	
		
			
				|  |  | +    reduced -= q * (1 & ((q-1 - reduced) >> 15));
 | 
	
		
			
				|  |  | +    return reduced;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* Reduce x mod q as above, but also return the quotient */
 | 
	
		
			
				|  |  | +static uint16_t reduce_with_quot(uint32_t x, uint32_t *quot_out,
 | 
	
		
			
				|  |  | +                                 uint16_t q, uint64_t qrecip)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    uint64_t unshifted_quot = x * qrecip;
 | 
	
		
			
				|  |  | +    uint64_t quot = unshifted_quot >> 48;
 | 
	
		
			
				|  |  | +    uint16_t reduced = x - quot * q;
 | 
	
		
			
				|  |  | +    uint64_t extraquot = (1 & ((q-1 - reduced) >> 15));
 | 
	
		
			
				|  |  | +    reduced -= extraquot * q;
 | 
	
		
			
				|  |  | +    *quot_out = quot + extraquot;
 | 
	
		
			
				|  |  | +    return reduced;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* Invert x mod q, assuming it's nonzero. (For time-safety, no check
 | 
	
		
			
				|  |  | + * is made for zero; it just returns 0.) */
 | 
	
		
			
				|  |  | +static uint16_t invert(uint16_t x, uint16_t q, uint64_t qrecip)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    /* Fermat inversion: compute x^(q-2), since x^(q-1) == 1. */
 | 
	
		
			
				|  |  | +    uint32_t sq = x, bit = 1, acc = 1, exp = q-2;
 | 
	
		
			
				|  |  | +    while (1) {
 | 
	
		
			
				|  |  | +        if (exp & bit) {
 | 
	
		
			
				|  |  | +            acc = reduce(acc * sq, q, qrecip);
 | 
	
		
			
				|  |  | +            exp &= ~bit;
 | 
	
		
			
				|  |  | +            if (!exp)
 | 
	
		
			
				|  |  | +                return acc;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        sq = reduce(sq * sq, q, qrecip);
 | 
	
		
			
				|  |  | +        bit <<= 1;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* Check whether x == 0, time-safely, and return 1 if it is or 0 otherwise. */
 | 
	
		
			
				|  |  | +static unsigned iszero(uint16_t x)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    return 1 & ~((x + 0xFFFF) >> 16);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Handy macros to cut down on all those extra function parameters. In
 | 
	
		
			
				|  |  | + * the common case where a function is working mod the same modulus
 | 
	
		
			
				|  |  | + * throughout (and has called it q), you can just write 'SETUP;' at
 | 
	
		
			
				|  |  | + * the top and then call REDUCE(...) and INVERT(...) without having to
 | 
	
		
			
				|  |  | + * write out q and qrecip every time.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +#define SETUP uint64_t qrecip = reciprocal_for_reduction(q)
 | 
	
		
			
				|  |  | +#define REDUCE(x) reduce(x, q, qrecip)
 | 
	
		
			
				|  |  | +#define INVERT(x) invert(x, q, qrecip)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* ----------------------------------------------------------------------
 | 
	
		
			
				|  |  | + * Quotient-ring functions.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * NTRU Prime works with two similar but different quotient rings:
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + *   Z_q[x] / <x^p-x-1>      where p,q are the prime parameters of the system
 | 
	
		
			
				|  |  | + *   Z_3[x] / <x^p-x-1>      with the same p, but coefficients mod 3.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * The former is a field (every nonzero element is invertible),
 | 
	
		
			
				|  |  | + * because the system parameters are chosen such that x^p-x-1 is
 | 
	
		
			
				|  |  | + * invertible over Z_q. The latter is not a field (or not necessarily,
 | 
	
		
			
				|  |  | + * and in particular, not for the value of p we use here).
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * In these core functions, you pass in the modulus you want as the
 | 
	
		
			
				|  |  | + * parameter q, which is either the 'real' q specified in the system
 | 
	
		
			
				|  |  | + * parameters, or 3 if you're doing one of the mod-3 parts of the
 | 
	
		
			
				|  |  | + * algorithm.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Multiply two elements of a quotient ring.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * 'a' and 'b' are arrays of exactly p coefficients, with constant
 | 
	
		
			
				|  |  | + * term first. 'out' is an array the same size to write the inverse
 | 
	
		
			
				|  |  | + * into.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_ring_multiply(uint16_t *out, const uint16_t *a, const uint16_t *b,
 | 
	
		
			
				|  |  | +                        unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    SETUP;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Strategy: just compute the full product with 2p coefficients,
 | 
	
		
			
				|  |  | +     * and then reduce it mod x^p-x-1 by working downwards from the
 | 
	
		
			
				|  |  | +     * top coefficient replacing x^{p+k} with (x+1)x^k for k = ...,1,0.
 | 
	
		
			
				|  |  | +     *
 | 
	
		
			
				|  |  | +     * Possibly some speed could be gained here by doing the recursive
 | 
	
		
			
				|  |  | +     * Karatsuba optimisation for the initial multiplication? But I
 | 
	
		
			
				|  |  | +     * haven't tried it.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint32_t *unreduced = snewn(2*p, uint32_t);
 | 
	
		
			
				|  |  | +    for (unsigned i = 0; i < 2*p; i++)
 | 
	
		
			
				|  |  | +        unreduced[i] = 0;
 | 
	
		
			
				|  |  | +    for (unsigned i = 0; i < p; i++)
 | 
	
		
			
				|  |  | +        for (unsigned j = 0; j < p; j++)
 | 
	
		
			
				|  |  | +            unreduced[i+j] = REDUCE(unreduced[i+j] + a[i] * b[j]);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    for (unsigned i = 2*p - 1; i >= p; i--) {
 | 
	
		
			
				|  |  | +        unreduced[i-p] += unreduced[i];
 | 
	
		
			
				|  |  | +        unreduced[i-p+1] += unreduced[i];
 | 
	
		
			
				|  |  | +        unreduced[i] = 0;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    for (unsigned i = 0; i < p; i++)
 | 
	
		
			
				|  |  | +        out[i] = REDUCE(unreduced[i]);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    smemclr(unreduced, 2*p * sizeof(*unreduced));
 | 
	
		
			
				|  |  | +    sfree(unreduced);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Invert an element of the quotient ring.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * 'in' is an array of exactly p coefficients, with constant term
 | 
	
		
			
				|  |  | + * first. 'out' is an array the same size to write the inverse into.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * Method: essentially Stein's gcd algorithm, taking the gcd of the
 | 
	
		
			
				|  |  | + * input (regarded as an element of Z_q[x] proper) and x^p-x-1. Given
 | 
	
		
			
				|  |  | + * two polynomials over a field which are not both divisible by x, you
 | 
	
		
			
				|  |  | + * can find their gcd by iterating the following procedure:
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + *  - if one is divisible by x, divide off x
 | 
	
		
			
				|  |  | + *  - otherwise, subtract from the higher-degree one whatever scalar
 | 
	
		
			
				|  |  | + *    multiple of the lower-degree one will make it divisible by x,
 | 
	
		
			
				|  |  | + *    and _then_ divide off x
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * Neither of these types of step changes the gcd of the two
 | 
	
		
			
				|  |  | + * polynomials.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * Each step reduces the sum of the two polynomials' degree by at
 | 
	
		
			
				|  |  | + * least one, as long as at least one of the degrees is positive.
 | 
	
		
			
				|  |  | + * (Maybe more than one if all the stars align in the second case, if
 | 
	
		
			
				|  |  | + * the subtraction cancels the leading term as well as the constant
 | 
	
		
			
				|  |  | + * term.) So in at most deg A + deg B steps, we must have reached the
 | 
	
		
			
				|  |  | + * situation where both polys are constants; in one more step after
 | 
	
		
			
				|  |  | + * that, one of them will be zero; and in one step after _that_, the
 | 
	
		
			
				|  |  | + * zero one will reliably be the one we're dividing by x. Or rather,
 | 
	
		
			
				|  |  | + * that's what happens in the case where A,B are coprime; if not, then
 | 
	
		
			
				|  |  | + * one hits zero while the other is still nonzero.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * In a normal gcd algorithm, you'd track a linear combination of the
 | 
	
		
			
				|  |  | + * two original polynomials that yields each working value, and end up
 | 
	
		
			
				|  |  | + * with a linear combination of the inputs that yields the gcd. In
 | 
	
		
			
				|  |  | + * this algorithm, the 'divide off x' step makes that awkward - but we
 | 
	
		
			
				|  |  | + * can solve that by instead multiplying by the inverse of x in the
 | 
	
		
			
				|  |  | + * ring that we want our answer to be valid in! And since the modulus
 | 
	
		
			
				|  |  | + * polynomial of the ring is x^p-x-1, the inverse of x is easy to
 | 
	
		
			
				|  |  | + * calculate, because it's always just x^{p-1} - 1, which is also very
 | 
	
		
			
				|  |  | + * easy to multiply by.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +unsigned ntru_ring_invert(uint16_t *out, const uint16_t *in,
 | 
	
		
			
				|  |  | +                          unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    SETUP;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Size of the polynomial arrays we'll work with */
 | 
	
		
			
				|  |  | +    const size_t SIZE = p+1;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Number of steps of the algorithm is the max possible value of
 | 
	
		
			
				|  |  | +     * deg A + deg B + 2, where deg A <= p-1 and deg B = p */
 | 
	
		
			
				|  |  | +    const size_t STEPS = 2*p + 1;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Our two working polynomials */
 | 
	
		
			
				|  |  | +    uint16_t *A = snewn(SIZE, uint16_t);
 | 
	
		
			
				|  |  | +    uint16_t *B = snewn(SIZE, uint16_t);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Coefficient of the input value in each one */
 | 
	
		
			
				|  |  | +    uint16_t *Ac = snewn(SIZE, uint16_t);
 | 
	
		
			
				|  |  | +    uint16_t *Bc = snewn(SIZE, uint16_t);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Initialise A to the input, and Ac correspondingly to 1 */
 | 
	
		
			
				|  |  | +    memcpy(A, in, p*sizeof(uint16_t));
 | 
	
		
			
				|  |  | +    A[p] = 0;
 | 
	
		
			
				|  |  | +    Ac[0] = 1;
 | 
	
		
			
				|  |  | +    for (size_t i = 1; i < SIZE; i++)
 | 
	
		
			
				|  |  | +        Ac[i] = 0;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Initialise B to the quotient polynomial of the ring, x^p-x-1
 | 
	
		
			
				|  |  | +     * And Bc = 0 */
 | 
	
		
			
				|  |  | +    B[0] = B[1] = q-1;
 | 
	
		
			
				|  |  | +    for (size_t i = 2; i < p; i++)
 | 
	
		
			
				|  |  | +        B[i] = 0;
 | 
	
		
			
				|  |  | +    B[p] = 1;
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < SIZE; i++)
 | 
	
		
			
				|  |  | +        Bc[i] = 0;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Run the gcd-finding algorithm. */
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < STEPS; i++) {
 | 
	
		
			
				|  |  | +        /*
 | 
	
		
			
				|  |  | +         * First swap round so that A is the one we'll be dividing by x.
 | 
	
		
			
				|  |  | +         *
 | 
	
		
			
				|  |  | +         * In the case where one of the two polys has a zero constant
 | 
	
		
			
				|  |  | +         * term, it's that one. In the other case, it's the one of
 | 
	
		
			
				|  |  | +         * smaller degree. We must compute both, and choose between
 | 
	
		
			
				|  |  | +         * them in a side-channel-safe way.
 | 
	
		
			
				|  |  | +         */
 | 
	
		
			
				|  |  | +        unsigned x_divides_A = iszero(A[0]);
 | 
	
		
			
				|  |  | +        unsigned x_divides_B = iszero(B[0]);
 | 
	
		
			
				|  |  | +        unsigned B_is_bigger = 0;
 | 
	
		
			
				|  |  | +        {
 | 
	
		
			
				|  |  | +            unsigned not_seen_top_term_of_A = 1, not_seen_top_term_of_B = 1;
 | 
	
		
			
				|  |  | +            for (size_t j = SIZE; j-- > 0 ;) {
 | 
	
		
			
				|  |  | +                not_seen_top_term_of_A &= iszero(A[j]);
 | 
	
		
			
				|  |  | +                not_seen_top_term_of_B &= iszero(B[j]);
 | 
	
		
			
				|  |  | +                B_is_bigger |= (~not_seen_top_term_of_B &
 | 
	
		
			
				|  |  | +                                not_seen_top_term_of_A);
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        unsigned need_swap = x_divides_B | (~x_divides_A & B_is_bigger);
 | 
	
		
			
				|  |  | +        uint16_t swap_mask = -need_swap;
 | 
	
		
			
				|  |  | +        for (size_t j = 0; j < SIZE; j++) {
 | 
	
		
			
				|  |  | +            uint16_t diff = (A[j] ^ B[j]) & swap_mask;
 | 
	
		
			
				|  |  | +            A[j] ^= diff;
 | 
	
		
			
				|  |  | +            B[j] ^= diff;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        for (size_t j = 0; j < SIZE; j++) {
 | 
	
		
			
				|  |  | +            uint16_t diff = (Ac[j] ^ Bc[j]) & swap_mask;
 | 
	
		
			
				|  |  | +            Ac[j] ^= diff;
 | 
	
		
			
				|  |  | +            Bc[j] ^= diff;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /*
 | 
	
		
			
				|  |  | +         * Replace A with a linear combination of both A and B that
 | 
	
		
			
				|  |  | +         * has constant term zero, which we do by calculating
 | 
	
		
			
				|  |  | +         *
 | 
	
		
			
				|  |  | +         *   (constant term of B) * A - (constant term of A) * B
 | 
	
		
			
				|  |  | +         *
 | 
	
		
			
				|  |  | +         * In one of the two cases, A's constant term is already zero,
 | 
	
		
			
				|  |  | +         * so the coefficient of B will be zero too; hence, this will
 | 
	
		
			
				|  |  | +         * do nothing useful (it will merely scale A by some scalar
 | 
	
		
			
				|  |  | +         * value), but it will take the same length of time as doing
 | 
	
		
			
				|  |  | +         * something, which is just what we want.
 | 
	
		
			
				|  |  | +         */
 | 
	
		
			
				|  |  | +        uint16_t Amult = B[0], Bmult = q - A[0];
 | 
	
		
			
				|  |  | +        for (size_t j = 0; j < SIZE; j++)
 | 
	
		
			
				|  |  | +            A[j] = REDUCE(Amult * A[j] + Bmult * B[j]);
 | 
	
		
			
				|  |  | +        /* And do the same transformation to Ac */
 | 
	
		
			
				|  |  | +        for (size_t j = 0; j < SIZE; j++)
 | 
	
		
			
				|  |  | +            Ac[j] = REDUCE(Amult * Ac[j] + Bmult * Bc[j]);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /*
 | 
	
		
			
				|  |  | +         * Now divide A by x, and compensate by multiplying Ac by
 | 
	
		
			
				|  |  | +         * x^{p-1}-1 mod x^p-x-1.
 | 
	
		
			
				|  |  | +         *
 | 
	
		
			
				|  |  | +         * That multiplication is particularly easy, precisely because
 | 
	
		
			
				|  |  | +         * x^{p-1}-1 is the multiplicative inverse of x! Each x^n term
 | 
	
		
			
				|  |  | +         * for n>0 just moves down to the x^{n-1} term, and only the
 | 
	
		
			
				|  |  | +         * constant term has to be dealt with in an interesting way.
 | 
	
		
			
				|  |  | +         */
 | 
	
		
			
				|  |  | +        for (size_t j = 1; j < SIZE; j++)
 | 
	
		
			
				|  |  | +            A[j-1] = A[j];
 | 
	
		
			
				|  |  | +        A[SIZE-1] = 0;
 | 
	
		
			
				|  |  | +        uint16_t Ac0 = Ac[0];
 | 
	
		
			
				|  |  | +        for (size_t j = 1; j < p; j++)
 | 
	
		
			
				|  |  | +            Ac[j-1] = Ac[j];
 | 
	
		
			
				|  |  | +        Ac[p-1] = Ac0;
 | 
	
		
			
				|  |  | +        Ac[0] = REDUCE(Ac[0] + q - Ac0);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Now we expect that A is 0, and B is a constant. If so, then
 | 
	
		
			
				|  |  | +     * they are coprime, and we're going to return success. If not,
 | 
	
		
			
				|  |  | +     * they have a common factor.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    unsigned success = iszero(A[0]) & (1 ^ iszero(B[0]));
 | 
	
		
			
				|  |  | +    for (size_t j = 1; j < SIZE; j++)
 | 
	
		
			
				|  |  | +        success &= iszero(A[j]) & iszero(B[j]);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * So we're going to return Bc, but first, scale it by the
 | 
	
		
			
				|  |  | +     * multiplicative inverse of the constant we ended up with in
 | 
	
		
			
				|  |  | +     * B[0].
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint16_t scale = INVERT(B[0]);
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < p; i++)
 | 
	
		
			
				|  |  | +        out[i] = REDUCE(scale * Bc[i]);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    smemclr(A, SIZE * sizeof(*A));
 | 
	
		
			
				|  |  | +    sfree(A);
 | 
	
		
			
				|  |  | +    smemclr(B, SIZE * sizeof(*B));
 | 
	
		
			
				|  |  | +    sfree(B);
 | 
	
		
			
				|  |  | +    smemclr(Ac, SIZE * sizeof(*Ac));
 | 
	
		
			
				|  |  | +    sfree(Ac);
 | 
	
		
			
				|  |  | +    smemclr(Bc, SIZE * sizeof(*Bc));
 | 
	
		
			
				|  |  | +    sfree(Bc);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    return success;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Given an array of values mod q, convert each one to its
 | 
	
		
			
				|  |  | + * minimum-absolute-value representative, and then reduce mod 3.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * Output values are 0, 1 and 0xFFFF, representing -1.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * (Normally our arrays of uint16_t are in 'minimal non-negative
 | 
	
		
			
				|  |  | + * residue' form, so the output of this function is unusual. But it's
 | 
	
		
			
				|  |  | + * useful to have it in this form so that it can be reused by
 | 
	
		
			
				|  |  | + * ntru_round3. You can put it back to the usual representation using
 | 
	
		
			
				|  |  | + * ntru_normalise, below.)
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_mod3(uint16_t *out, const uint16_t *in, unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    uint64_t qrecip = reciprocal_for_reduction(q);
 | 
	
		
			
				|  |  | +    uint64_t recip3 = reciprocal_for_reduction(3);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    unsigned bias = q/2;
 | 
	
		
			
				|  |  | +    uint16_t adjust = 3 - reduce(bias-1, 3, recip3);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    for (unsigned i = 0; i < p; i++) {
 | 
	
		
			
				|  |  | +        uint16_t val = reduce(in[i] + bias, q, qrecip);
 | 
	
		
			
				|  |  | +        uint16_t residue = reduce(val + adjust, 3, recip3);
 | 
	
		
			
				|  |  | +        out[i] = residue - 1;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Given an array of values mod q, round each one to the nearest
 | 
	
		
			
				|  |  | + * multiple of 3 to its minimum-absolute-value representative.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * Output values are signed integers coerced to uint16_t, so again,
 | 
	
		
			
				|  |  | + * use ntru_normalise afterwards to put them back to normal.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_round3(uint16_t *out, const uint16_t *in, unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    SETUP;
 | 
	
		
			
				|  |  | +    unsigned bias = q/2;
 | 
	
		
			
				|  |  | +    ntru_mod3(out, in, p, q);
 | 
	
		
			
				|  |  | +    for (unsigned i = 0; i < p; i++)
 | 
	
		
			
				|  |  | +        out[i] = REDUCE(in[i] + bias) - bias - out[i];
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Given an array of signed integers coerced to uint16_t in the range
 | 
	
		
			
				|  |  | + * [-q/2,+q/2], normalise them back to mod q values.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +static void ntru_normalise(uint16_t *out, const uint16_t *in,
 | 
	
		
			
				|  |  | +                           unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    for (unsigned i = 0; i < p; i++)
 | 
	
		
			
				|  |  | +        out[i] = in[i] + q * (in[i] >> 15);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Given an array of values mod q, add a constant to each one.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_bias(uint16_t *out, const uint16_t *in, unsigned bias,
 | 
	
		
			
				|  |  | +               unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    SETUP;
 | 
	
		
			
				|  |  | +    for (unsigned i = 0; i < p; i++)
 | 
	
		
			
				|  |  | +        out[i] = REDUCE(in[i] + bias);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Given an array of values mod q, multiply each one by a constant.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_scale(uint16_t *out, const uint16_t *in, uint16_t scale,
 | 
	
		
			
				|  |  | +                unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    SETUP;
 | 
	
		
			
				|  |  | +    for (unsigned i = 0; i < p; i++)
 | 
	
		
			
				|  |  | +        out[i] = REDUCE(in[i] * scale);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Given an array of values mod 3, convert them to values mod q in a
 | 
	
		
			
				|  |  | + * way that maps -1,0,+1 to -1,0,+1.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +static void ntru_expand(
 | 
	
		
			
				|  |  | +    uint16_t *out, const uint16_t *in, unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < p; i++) {
 | 
	
		
			
				|  |  | +        uint16_t v = in[i];
 | 
	
		
			
				|  |  | +        /* Map 2 to q-1, and leave 0 and 1 unchanged */
 | 
	
		
			
				|  |  | +        v += (v >> 1) * (q-3);
 | 
	
		
			
				|  |  | +        out[i] = v;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* ----------------------------------------------------------------------
 | 
	
		
			
				|  |  | + * Implement the binary encoding from ntruprime-20201007.pdf, which is
 | 
	
		
			
				|  |  | + * used to encode public keys and ciphertexts (though not plaintexts,
 | 
	
		
			
				|  |  | + * which are done in a much simpler way).
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * The general idea is that your encoder takes as input a list of
 | 
	
		
			
				|  |  | + * small non-negative integers (r_i), and a sequence of limits (m_i)
 | 
	
		
			
				|  |  | + * such that 0 <= r_i < m_i, and emits a sequence of bytes that encode
 | 
	
		
			
				|  |  | + * all of these as tightly as reasonably possible.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * That's more general than is really needed, because in both the
 | 
	
		
			
				|  |  | + * actual uses of this encoding, the input m_i are all the same! But
 | 
	
		
			
				|  |  | + * the array of (r_i,m_i) pairs evolves during encoding, so they don't
 | 
	
		
			
				|  |  | + * _stay_ all the same, so you still have to have all the generality.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * The encoding process makes a number of passes along the list of
 | 
	
		
			
				|  |  | + * inputs. In each step, pairs of adjacent numbers are combined into
 | 
	
		
			
				|  |  | + * one larger one by turning (r_i,m_i) and (r_{i+1},m_{i+1}) into the
 | 
	
		
			
				|  |  | + * pair (r_i + m_i r_{i+1}, m_i m_{i+1}), i.e. so that the original
 | 
	
		
			
				|  |  | + * numbers could be recovered by taking the quotient and remaiinder of
 | 
	
		
			
				|  |  | + * the new r value by m_i. Then, if the new m_i is at least 2^14, we
 | 
	
		
			
				|  |  | + * emit the low 8 bits of r_i to the output stream and reduce r_i and
 | 
	
		
			
				|  |  | + * its limit correspondingly. So at the end of the pass, we've got
 | 
	
		
			
				|  |  | + * half as many numbers still to encode, they're all still not too
 | 
	
		
			
				|  |  | + * big, and we've emitted some amount of data into the output. Then do
 | 
	
		
			
				|  |  | + * another pass, keep going until there's only one number left, and
 | 
	
		
			
				|  |  | + * emit it little-endian.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * That's all very well, but how do you decode it again? DJB exhibits
 | 
	
		
			
				|  |  | + * a pair of recursive functions that are supposed to be mutually
 | 
	
		
			
				|  |  | + * inverse, but I didn't have any confidence that I'd be able to debug
 | 
	
		
			
				|  |  | + * them sensibly if they turned out not to be (or rather, if I
 | 
	
		
			
				|  |  | + * implemented one of them wrong). So I came up with my own strategy
 | 
	
		
			
				|  |  | + * instead.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * In my strategy, we start by processing just the (m_i) into an
 | 
	
		
			
				|  |  | + * 'encoding schedule' consisting of a sequence of simple
 | 
	
		
			
				|  |  | + * instructions. The instructions operate on a FIFO queue of numbers,
 | 
	
		
			
				|  |  | + * initialised to the original (r_i). The three instruction types are:
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + *  - 'COMBINE': consume two numbers a,b from the head of the queue,
 | 
	
		
			
				|  |  | + *    combine them by calculating a + m*b for some specified m, and
 | 
	
		
			
				|  |  | + *    push the result on the tail of the queue.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + *  - 'BYTE': divide the tail element of the queue by 2^8 and emit the
 | 
	
		
			
				|  |  | + *    low bits into the output stream.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + *  - 'COPY': pop a number from the head of the queue and push it
 | 
	
		
			
				|  |  | + *    straight back on the tail. (Used for handling the leftover
 | 
	
		
			
				|  |  | + *    element at the end of a pass if the input to the pass was a list
 | 
	
		
			
				|  |  | + *    of odd length.)
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * So we effectively implement DJB's encoding process in simulation,
 | 
	
		
			
				|  |  | + * and instead of actually processing a set of (r_i), we 'compile' the
 | 
	
		
			
				|  |  | + * process into a sequence of instructions that can be handed just the
 | 
	
		
			
				|  |  | + * (r_i) later and encode them in the right way. At the end of the
 | 
	
		
			
				|  |  | + * instructions, the queue is expected to have been reduced to length
 | 
	
		
			
				|  |  | + * 1 and contain the single integer 0.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * The nice thing about this system is that each of those three
 | 
	
		
			
				|  |  | + * instructions is easy to reverse. So you can also use the same
 | 
	
		
			
				|  |  | + * instructions for decoding: start with a queue containing 0, and
 | 
	
		
			
				|  |  | + * process the instructions in reverse order and reverse sense. So
 | 
	
		
			
				|  |  | + * BYTE means to _consume_ a byte from the encoded data (starting from
 | 
	
		
			
				|  |  | + * the rightmost end) and use it to make a queue element bigger; and
 | 
	
		
			
				|  |  | + * COMBINE run in reverse pops a single element from one end of the
 | 
	
		
			
				|  |  | + * queue, divides it by m, and pushes the quotient and remainder on
 | 
	
		
			
				|  |  | + * the other end.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * (So it's easy to debug, because the queue passes through the exact
 | 
	
		
			
				|  |  | + * same sequence of states during decoding that it did during
 | 
	
		
			
				|  |  | + * encoding, just in reverse order.)
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * Also, the encoding schedule comes with information about the
 | 
	
		
			
				|  |  | + * expected size of the encoded data, because you can find that out
 | 
	
		
			
				|  |  | + * easily by just counting the BYTE commands.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +enum {
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Command values appearing in the 'ops' array. ENC_COPY and
 | 
	
		
			
				|  |  | +     * ENC_BYTE are single values; values of the form
 | 
	
		
			
				|  |  | +     * (ENC_COMBINE_BASE + m) represent a COMBINE command with
 | 
	
		
			
				|  |  | +     * parameter m.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    ENC_COPY, ENC_BYTE, ENC_COMBINE_BASE
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +struct NTRUEncodeSchedule {
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Object representing a compiled set of encoding instructions.
 | 
	
		
			
				|  |  | +     *
 | 
	
		
			
				|  |  | +     * 'nvals' is the number of r_i we expect to encode. 'nops' is the
 | 
	
		
			
				|  |  | +     * number of encoding commands in the 'ops' list; 'opsize' is the
 | 
	
		
			
				|  |  | +     * physical size of the array, used during construction.
 | 
	
		
			
				|  |  | +     *
 | 
	
		
			
				|  |  | +     * 'endpos' is used to avoid a last-minute faff during decoding.
 | 
	
		
			
				|  |  | +     * We implement our FIFO of integers as a ring buffer of size
 | 
	
		
			
				|  |  | +     * 'nvals'. Encoding cycles round it some number of times, and the
 | 
	
		
			
				|  |  | +     * final 0 element ends up at some random location in the array.
 | 
	
		
			
				|  |  | +     * If we know _where_ the 0 ends up during encoding, we can put
 | 
	
		
			
				|  |  | +     * the initial 0 there at the start of decoding, and then when we
 | 
	
		
			
				|  |  | +     * finish reversing all the instructions, we'll end up with the
 | 
	
		
			
				|  |  | +     * output numbers already arranged at their correct positions, so
 | 
	
		
			
				|  |  | +     * that there's no need to rotate the array at the last minute.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    size_t nvals, endpos, nops, opsize;
 | 
	
		
			
				|  |  | +    uint32_t *ops;
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +static inline void sched_append(NTRUEncodeSchedule *sched, uint16_t op)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    /* Helper function to append an operation to the schedule, and
 | 
	
		
			
				|  |  | +     * update endpos. */
 | 
	
		
			
				|  |  | +    sgrowarray(sched->ops, sched->opsize, sched->nops);
 | 
	
		
			
				|  |  | +    sched->ops[sched->nops++] = op;
 | 
	
		
			
				|  |  | +    if (op != ENC_BYTE)
 | 
	
		
			
				|  |  | +        sched->endpos = (sched->endpos + 1) % sched->nvals;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Take in the list of limit values (m_i) and compute the encoding
 | 
	
		
			
				|  |  | + * schedule.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +NTRUEncodeSchedule *ntru_encode_schedule(const uint16_t *ms_in, size_t n)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    NTRUEncodeSchedule *sched = snew(NTRUEncodeSchedule);
 | 
	
		
			
				|  |  | +    sched->nvals = n;
 | 
	
		
			
				|  |  | +    sched->endpos = n-1;
 | 
	
		
			
				|  |  | +    sched->nops = sched->opsize = 0;
 | 
	
		
			
				|  |  | +    sched->ops = NULL;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    assert(n != 0);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * 'ms' is the list of (m_i) on input to the current pass.
 | 
	
		
			
				|  |  | +     * 'ms_new' is the list output from the current pass. After each
 | 
	
		
			
				|  |  | +     * pass we swap the arrays round.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint32_t *ms = snewn(n, uint32_t);
 | 
	
		
			
				|  |  | +    uint32_t *msnew = snewn(n, uint32_t);
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < n; i++)
 | 
	
		
			
				|  |  | +        ms[i] = ms_in[i];
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    while (n > 1) {
 | 
	
		
			
				|  |  | +        size_t nnew = 0;
 | 
	
		
			
				|  |  | +        for (size_t i = 0; i < n; i += 2) {
 | 
	
		
			
				|  |  | +            if (i+1 == n) {
 | 
	
		
			
				|  |  | +                /*
 | 
	
		
			
				|  |  | +                 * Odd element at the end of the input list: just copy
 | 
	
		
			
				|  |  | +                 * it unchanged to the output.
 | 
	
		
			
				|  |  | +                 */
 | 
	
		
			
				|  |  | +                sched_append(sched, ENC_COPY);
 | 
	
		
			
				|  |  | +                msnew[nnew++] = ms[i];
 | 
	
		
			
				|  |  | +                break;
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            /*
 | 
	
		
			
				|  |  | +             * Normal case: consume two elements from the input list
 | 
	
		
			
				|  |  | +             * and combine them.
 | 
	
		
			
				|  |  | +             */
 | 
	
		
			
				|  |  | +            uint32_t m1 = ms[i], m2 = ms[i+1], m = m1*m2;
 | 
	
		
			
				|  |  | +            sched_append(sched, ENC_COMBINE_BASE + m1);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            /*
 | 
	
		
			
				|  |  | +             * And then, as long as the combined limit is big enough,
 | 
	
		
			
				|  |  | +             * emit an output byte from the bottom of it.
 | 
	
		
			
				|  |  | +             */
 | 
	
		
			
				|  |  | +            while (m >= (1<<14)) {
 | 
	
		
			
				|  |  | +                sched_append(sched, ENC_BYTE);
 | 
	
		
			
				|  |  | +                m = (m + 0xFF) >> 8;
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            /*
 | 
	
		
			
				|  |  | +             * Whatever is left after that, we emit into the output
 | 
	
		
			
				|  |  | +             * list and append to the fifo.
 | 
	
		
			
				|  |  | +             */
 | 
	
		
			
				|  |  | +            msnew[nnew++] = m;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /*
 | 
	
		
			
				|  |  | +         * End of pass. The output list of (m_i) now becomes the input
 | 
	
		
			
				|  |  | +         * list.
 | 
	
		
			
				|  |  | +         */
 | 
	
		
			
				|  |  | +        uint32_t *tmp = ms;
 | 
	
		
			
				|  |  | +        ms = msnew;
 | 
	
		
			
				|  |  | +        n = nnew;
 | 
	
		
			
				|  |  | +        msnew = tmp;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * When that loop terminates, it's because there's exactly one
 | 
	
		
			
				|  |  | +     * number left to encode. (Or, technically, _at most_ one - but we
 | 
	
		
			
				|  |  | +     * don't support encoding a completely empty list in this
 | 
	
		
			
				|  |  | +     * implementation, because what would be the point?) That number
 | 
	
		
			
				|  |  | +     * is just emitted little-endian until its limit is 1 (meaning its
 | 
	
		
			
				|  |  | +     * only possible actual value is 0).
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    assert(n == 1);
 | 
	
		
			
				|  |  | +    uint32_t m = ms[0];
 | 
	
		
			
				|  |  | +    while (m > 1) {
 | 
	
		
			
				|  |  | +        sched_append(sched, ENC_BYTE);
 | 
	
		
			
				|  |  | +        m = (m + 0xFF) >> 8;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    sfree(ms);
 | 
	
		
			
				|  |  | +    sfree(msnew);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    return sched;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +void ntru_encode_schedule_free(NTRUEncodeSchedule *sched)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    sfree(sched->ops);
 | 
	
		
			
				|  |  | +    sfree(sched);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Calculate the output length of the encoded data in bytes.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +size_t ntru_encode_schedule_length(NTRUEncodeSchedule *sched)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    size_t len = 0;
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < sched->nops; i++)
 | 
	
		
			
				|  |  | +        if (sched->ops[i] == ENC_BYTE)
 | 
	
		
			
				|  |  | +            len++;
 | 
	
		
			
				|  |  | +    return len;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Retrieve the number of items encoded. (Used by testcrypt.)
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +size_t ntru_encode_schedule_nvals(NTRUEncodeSchedule *sched)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    return sched->nvals;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Actually encode a sequence of (r_i), emitting the output bytes to
 | 
	
		
			
				|  |  | + * an arbitrary BinarySink.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_encode(NTRUEncodeSchedule *sched, const uint16_t *rs_in,
 | 
	
		
			
				|  |  | +                 BinarySink *bs)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    size_t n = sched->nvals;
 | 
	
		
			
				|  |  | +    uint32_t *rs = snewn(n, uint32_t);
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < n; i++)
 | 
	
		
			
				|  |  | +        rs[i] = rs_in[i];
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * The head and tail pointers of the queue are both 'full'. That
 | 
	
		
			
				|  |  | +     * is, rs[head] is the first element actually in the queue, and
 | 
	
		
			
				|  |  | +     * rs[tail] is the last element.
 | 
	
		
			
				|  |  | +     *
 | 
	
		
			
				|  |  | +     * So you append to the queue by first advancing 'tail' and then
 | 
	
		
			
				|  |  | +     * writing to rs[tail], whereas you consume from the queue by
 | 
	
		
			
				|  |  | +     * first reading rs[head] and _then_ advancing 'head'.
 | 
	
		
			
				|  |  | +     *
 | 
	
		
			
				|  |  | +     * The more normal thing would be to make 'tail' point to the
 | 
	
		
			
				|  |  | +     * first empty slot instead of the last full one. But then you'd
 | 
	
		
			
				|  |  | +     * have to faff about with modular arithmetic to find the last
 | 
	
		
			
				|  |  | +     * full slot for the BYTE command, so in this case, it's easier to
 | 
	
		
			
				|  |  | +     * do it the less usual way.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    size_t head = 0, tail = n-1;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < sched->nops; i++) {
 | 
	
		
			
				|  |  | +        uint16_t op = sched->ops[i];
 | 
	
		
			
				|  |  | +        switch (op) {
 | 
	
		
			
				|  |  | +          case ENC_BYTE:
 | 
	
		
			
				|  |  | +            put_byte(bs, rs[tail] & 0xFF);
 | 
	
		
			
				|  |  | +            rs[tail] >>= 8;
 | 
	
		
			
				|  |  | +            break;
 | 
	
		
			
				|  |  | +          case ENC_COPY: {
 | 
	
		
			
				|  |  | +            uint32_t r = rs[head];
 | 
	
		
			
				|  |  | +            head = (head + 1) % n;
 | 
	
		
			
				|  |  | +            tail = (tail + 1) % n;
 | 
	
		
			
				|  |  | +            rs[tail] = r;
 | 
	
		
			
				|  |  | +            break;
 | 
	
		
			
				|  |  | +          }
 | 
	
		
			
				|  |  | +          default: {
 | 
	
		
			
				|  |  | +            uint32_t r1 = rs[head];
 | 
	
		
			
				|  |  | +            head = (head + 1) % n;
 | 
	
		
			
				|  |  | +            uint32_t r2 = rs[head];
 | 
	
		
			
				|  |  | +            head = (head + 1) % n;
 | 
	
		
			
				|  |  | +            tail = (tail + 1) % n;
 | 
	
		
			
				|  |  | +            rs[tail] = r1 + (op - ENC_COMBINE_BASE) * r2;
 | 
	
		
			
				|  |  | +            break;
 | 
	
		
			
				|  |  | +          }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Expect that we've ended up with a single zero in the queue, at
 | 
	
		
			
				|  |  | +     * exactly the position that the setup-time analysis predicted it.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    assert(head == sched->endpos);
 | 
	
		
			
				|  |  | +    assert(tail == sched->endpos);
 | 
	
		
			
				|  |  | +    assert(rs[head] == 0);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    smemclr(rs, n * sizeof(*rs));
 | 
	
		
			
				|  |  | +    sfree(rs);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Decode a ptrlen of binary data into a sequence of (r_i). The data
 | 
	
		
			
				|  |  | + * is expected to be of exactly the right length (on pain of assertion
 | 
	
		
			
				|  |  | + * failure).
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_decode(NTRUEncodeSchedule *sched, uint16_t *rs_out, ptrlen data)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    size_t n = sched->nvals;
 | 
	
		
			
				|  |  | +    const uint8_t *base = (const uint8_t *)data.ptr;
 | 
	
		
			
				|  |  | +    const uint8_t *pos = base + data.len;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Initialise the queue to a single zero, at the 'endpos' position
 | 
	
		
			
				|  |  | +     * that will mean the final output is correctly aligned.
 | 
	
		
			
				|  |  | +     *
 | 
	
		
			
				|  |  | +     * 'head' and 'tail' have the same meanings as in encoding. So
 | 
	
		
			
				|  |  | +     * 'tail' is the location that BYTE modifies and COPY and COMBINE
 | 
	
		
			
				|  |  | +     * consume from, and 'head' is the location that COPY and COMBINE
 | 
	
		
			
				|  |  | +     * push on to. As in encoding, they both point at the extremal
 | 
	
		
			
				|  |  | +     * full slots in the array.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint32_t *rs = snewn(n, uint32_t);
 | 
	
		
			
				|  |  | +    size_t head = sched->endpos, tail = head;
 | 
	
		
			
				|  |  | +    rs[tail] = 0;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    for (size_t i = sched->nops; i-- > 0 ;) {
 | 
	
		
			
				|  |  | +        uint16_t op = sched->ops[i];
 | 
	
		
			
				|  |  | +        switch (op) {
 | 
	
		
			
				|  |  | +          case ENC_BYTE: {
 | 
	
		
			
				|  |  | +            assert(pos > base);
 | 
	
		
			
				|  |  | +            uint8_t byte = *--pos;
 | 
	
		
			
				|  |  | +            rs[tail] = (rs[tail] << 8) | byte;
 | 
	
		
			
				|  |  | +            break;
 | 
	
		
			
				|  |  | +          }
 | 
	
		
			
				|  |  | +          case ENC_COPY: {
 | 
	
		
			
				|  |  | +            uint32_t r = rs[tail];
 | 
	
		
			
				|  |  | +            tail = (tail + n - 1) % n;
 | 
	
		
			
				|  |  | +            head = (head + n - 1) % n;
 | 
	
		
			
				|  |  | +            rs[head] = r;
 | 
	
		
			
				|  |  | +            break;
 | 
	
		
			
				|  |  | +          }
 | 
	
		
			
				|  |  | +          default: {
 | 
	
		
			
				|  |  | +            uint32_t r = rs[tail];
 | 
	
		
			
				|  |  | +            tail = (tail + n - 1) % n;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            uint32_t m = op - ENC_COMBINE_BASE;
 | 
	
		
			
				|  |  | +            uint64_t mrecip = reciprocal_for_reduction(m);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            uint32_t r1, r2;
 | 
	
		
			
				|  |  | +            r1 = reduce_with_quot(r, &r2, m, mrecip);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            head = (head + n - 1) % n;
 | 
	
		
			
				|  |  | +            rs[head] = r2;
 | 
	
		
			
				|  |  | +            head = (head + n - 1) % n;
 | 
	
		
			
				|  |  | +            rs[head] = r1;
 | 
	
		
			
				|  |  | +            break;
 | 
	
		
			
				|  |  | +          }
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    assert(pos == base);
 | 
	
		
			
				|  |  | +    assert(head == 0);
 | 
	
		
			
				|  |  | +    assert(tail == n-1);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < n; i++)
 | 
	
		
			
				|  |  | +        rs_out[i] = rs[i];
 | 
	
		
			
				|  |  | +    smemclr(rs, n * sizeof(*rs));
 | 
	
		
			
				|  |  | +    sfree(rs);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* ----------------------------------------------------------------------
 | 
	
		
			
				|  |  | + * The actual public-key cryptosystem.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +struct NTRUKeyPair {
 | 
	
		
			
				|  |  | +    unsigned p, q, w;
 | 
	
		
			
				|  |  | +    uint16_t *h;                       /* public key */
 | 
	
		
			
				|  |  | +    uint16_t *f3, *ginv;               /* private key */
 | 
	
		
			
				|  |  | +    uint16_t *rho;                     /* for implicit rejection */
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* Helper function to free an array of uint16_t containing a ring
 | 
	
		
			
				|  |  | + * element, clearing it on the way since some of them are sensitive. */
 | 
	
		
			
				|  |  | +static void ring_free(uint16_t *val, unsigned p)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    smemclr(val, p*sizeof(*val));
 | 
	
		
			
				|  |  | +    sfree(val);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +void ntru_keypair_free(NTRUKeyPair *keypair)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    ring_free(keypair->h, keypair->p);
 | 
	
		
			
				|  |  | +    ring_free(keypair->f3, keypair->p);
 | 
	
		
			
				|  |  | +    ring_free(keypair->ginv, keypair->p);
 | 
	
		
			
				|  |  | +    ring_free(keypair->rho, keypair->p);
 | 
	
		
			
				|  |  | +    sfree(keypair);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* Trivial accessors used by test programs. */
 | 
	
		
			
				|  |  | +unsigned ntru_keypair_p(NTRUKeyPair *keypair) { return keypair->p; }
 | 
	
		
			
				|  |  | +const uint16_t *ntru_pubkey(NTRUKeyPair *keypair) { return keypair->h; }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Generate a value of the class DJB describes as 'Short': it consists
 | 
	
		
			
				|  |  | + * of p terms that are all either 0 or +1 or -1, and exactly w of them
 | 
	
		
			
				|  |  | + * are not zero.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * Values of this kind are used for several purposes: part of the
 | 
	
		
			
				|  |  | + * private key, a plaintext, and the 'rho' fake-plaintext value used
 | 
	
		
			
				|  |  | + * for deliberately returning a duff but non-revealing session hash if
 | 
	
		
			
				|  |  | + * things go wrong.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * -1 is represented as 2 in the output array. So if you want these
 | 
	
		
			
				|  |  | + * numbers mod 3, then they come out already in the right form.
 | 
	
		
			
				|  |  | + * Otherwise, use ntru_expand.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_gen_short(uint16_t *v, unsigned p, unsigned w)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Get enough random data to generate a polynomial all of whose p
 | 
	
		
			
				|  |  | +     * terms are in {0,+1,-1}, and exactly w of them are nonzero.
 | 
	
		
			
				|  |  | +     * We'll do this by making up a completely random sequence of
 | 
	
		
			
				|  |  | +     * {+1,-1} and then setting a random subset of them to 0.
 | 
	
		
			
				|  |  | +     *
 | 
	
		
			
				|  |  | +     * So we'll need p random bits to choose the nonzero values, and
 | 
	
		
			
				|  |  | +     * then (doing it the simplest way) log2(p!) bits to shuffle them,
 | 
	
		
			
				|  |  | +     * plus say 128 bits to ensure any fluctuations in uniformity are
 | 
	
		
			
				|  |  | +     * negligible.
 | 
	
		
			
				|  |  | +     *
 | 
	
		
			
				|  |  | +     * log2(p!) is a pain to calculate, so we'll bound it above by
 | 
	
		
			
				|  |  | +     * p*log2(p), which we bound in turn by p*16.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    size_t randbitpos = 17 * p + 128;
 | 
	
		
			
				|  |  | +    mp_int *randdata = mp_resize(mp_random_bits(randbitpos), randbitpos + 32);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Initial value before zeroing out some terms: p randomly chosen
 | 
	
		
			
				|  |  | +     * values in {1,2}.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < p; i++)
 | 
	
		
			
				|  |  | +        v[i] = 1 + mp_get_bit(randdata, --randbitpos);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Hereafter we're going to extract random bits by multiplication,
 | 
	
		
			
				|  |  | +     * treating randdata as a large fixed-point number.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    mp_reduce_mod_2to(randdata, randbitpos);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Zero out some terms, leaving a randomly selected w of them
 | 
	
		
			
				|  |  | +     * nonzero.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint32_t nonzeros_left = w;
 | 
	
		
			
				|  |  | +    mp_int *x = mp_new(64);
 | 
	
		
			
				|  |  | +    for (size_t i = p; i-- > 0 ;) {
 | 
	
		
			
				|  |  | +        /*
 | 
	
		
			
				|  |  | +         * Pick a random number out of the number of terms remaning.
 | 
	
		
			
				|  |  | +         */
 | 
	
		
			
				|  |  | +        mp_mul_integer_into(randdata, randdata, i+1);
 | 
	
		
			
				|  |  | +        mp_rshift_fixed_into(x, randdata, randbitpos);
 | 
	
		
			
				|  |  | +        mp_reduce_mod_2to(randdata, randbitpos);
 | 
	
		
			
				|  |  | +        size_t j = mp_get_integer(x);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /*
 | 
	
		
			
				|  |  | +         * If that's less than nonzeros_left, then we're leaving this
 | 
	
		
			
				|  |  | +         * number nonzero. Otherwise we're zeroing it out.
 | 
	
		
			
				|  |  | +         */
 | 
	
		
			
				|  |  | +        uint32_t keep = (uint32_t)(j - nonzeros_left) >> 31;
 | 
	
		
			
				|  |  | +        v[i] &= -keep;            /* clear this field if keep == 0 */
 | 
	
		
			
				|  |  | +        nonzeros_left -= keep;    /* decrement counter if keep == 1 */
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    mp_free(x);
 | 
	
		
			
				|  |  | +    mp_free(randdata);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Make a single attempt at generating a key pair. This involves
 | 
	
		
			
				|  |  | + * inventing random elements of both our quotient rings and hoping
 | 
	
		
			
				|  |  | + * they're both invertible.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * They may not be, if you're unlucky. The element of Z_q/<x^p-x-1>
 | 
	
		
			
				|  |  | + * will _almost_ certainly be invertible, because that is a field, so
 | 
	
		
			
				|  |  | + * invertibility can only fail if you were so unlucky as to choose the
 | 
	
		
			
				|  |  | + * all-0s element. But the element of Z_3/<x^p-x-1> may fail to be
 | 
	
		
			
				|  |  | + * invertible because it has a common factor with x^p-x-1 (which, over
 | 
	
		
			
				|  |  | + * Z_3, is not irreducible).
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * So we can't guarantee to generate a key pair in constant time,
 | 
	
		
			
				|  |  | + * because there's no predicting how many retries we'll need. However,
 | 
	
		
			
				|  |  | + * this isn't a failure of side-channel safety, because we completely
 | 
	
		
			
				|  |  | + * discard all the random numbers and state from each failed attempt.
 | 
	
		
			
				|  |  | + * So if there were a side-channel leakage from a failure, the only
 | 
	
		
			
				|  |  | + * thing it would give away would be a bunch of random numbers that
 | 
	
		
			
				|  |  | + * turned out not to be used anyway.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * But a _successful_ call to this function should execute in a
 | 
	
		
			
				|  |  | + * secret-independent manner, and this 'make a single attempt'
 | 
	
		
			
				|  |  | + * function is exposed in the API so that 'testsc' can check that.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +NTRUKeyPair *ntru_keygen_attempt(unsigned p, unsigned q, unsigned w)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * First invent g, which is the one more likely to fail to invert.
 | 
	
		
			
				|  |  | +     * This is simply a uniformly random polynomial with p terms over
 | 
	
		
			
				|  |  | +     * Z_3. So we need p*log2(3) random bits for it, plus 128 for
 | 
	
		
			
				|  |  | +     * uniformity. It's easiest to bound log2(3) above by 2.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    size_t randbitpos = 2 * p + 128;
 | 
	
		
			
				|  |  | +    mp_int *randdata = mp_resize(mp_random_bits(randbitpos), randbitpos + 32);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Select p random values from {0,1,2}.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint16_t *g = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    mp_int *x = mp_new(64);
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < p; i++) {
 | 
	
		
			
				|  |  | +        mp_mul_integer_into(randdata, randdata, 3);
 | 
	
		
			
				|  |  | +        mp_rshift_fixed_into(x, randdata, randbitpos);
 | 
	
		
			
				|  |  | +        mp_reduce_mod_2to(randdata, randbitpos);
 | 
	
		
			
				|  |  | +        g[i] = mp_get_integer(x);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    mp_free(x);
 | 
	
		
			
				|  |  | +    mp_free(randdata);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Try to invert g over Z_3, and fail if it isn't invertible.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint16_t *ginv = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    if (!ntru_ring_invert(ginv, g, p, 3)) {
 | 
	
		
			
				|  |  | +        ring_free(g, p);
 | 
	
		
			
				|  |  | +        ring_free(ginv, p);
 | 
	
		
			
				|  |  | +        return NULL;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Fine; we have g. Now make up an f, and convert it to a
 | 
	
		
			
				|  |  | +     * polynomial over q.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint16_t *f = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    ntru_gen_short(f, p, w);
 | 
	
		
			
				|  |  | +    ntru_expand(f, f, p, q);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Multiply f by 3.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint16_t *f3 = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    ntru_scale(f3, f, 3, p, q);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Try to invert 3*f over Z_q. This should be _almost_ guaranteed
 | 
	
		
			
				|  |  | +     * to succeed, since Z_q/<x^p-x-1> is a field, so the only
 | 
	
		
			
				|  |  | +     * non-invertible value is 0. Even so, there _is_ one, so check
 | 
	
		
			
				|  |  | +     * the return value!
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint16_t *f3inv = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    if (!ntru_ring_invert(f3inv, f3, p, q)) {
 | 
	
		
			
				|  |  | +        ring_free(f, p);
 | 
	
		
			
				|  |  | +        ring_free(f3, p);
 | 
	
		
			
				|  |  | +        ring_free(f3inv, p);
 | 
	
		
			
				|  |  | +        ring_free(g, p);
 | 
	
		
			
				|  |  | +        ring_free(ginv, p);
 | 
	
		
			
				|  |  | +        return NULL;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Make the public key, by converting g to a polynomial over q and
 | 
	
		
			
				|  |  | +     * then multiplying by f3inv.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint16_t *g_q = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    ntru_expand(g_q, g, p, q);
 | 
	
		
			
				|  |  | +    uint16_t *h = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    ntru_ring_multiply(h, g_q, f3inv, p, q);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Make up rho, used to substitute for the plaintext in the
 | 
	
		
			
				|  |  | +     * session hash in case of confirmation failure.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint16_t *rho = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    ntru_gen_short(rho, p, w);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * And we're done! Free everything except the pieces we're
 | 
	
		
			
				|  |  | +     * returning.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    NTRUKeyPair *keypair = snew(NTRUKeyPair);
 | 
	
		
			
				|  |  | +    keypair->p = p;
 | 
	
		
			
				|  |  | +    keypair->q = q;
 | 
	
		
			
				|  |  | +    keypair->w = w;
 | 
	
		
			
				|  |  | +    keypair->h = h;
 | 
	
		
			
				|  |  | +    keypair->f3 = f3;
 | 
	
		
			
				|  |  | +    keypair->ginv = ginv;
 | 
	
		
			
				|  |  | +    keypair->rho = rho;
 | 
	
		
			
				|  |  | +    ring_free(f, p);
 | 
	
		
			
				|  |  | +    ring_free(f3inv, p);
 | 
	
		
			
				|  |  | +    ring_free(g, p);
 | 
	
		
			
				|  |  | +    ring_free(g_q, p);
 | 
	
		
			
				|  |  | +    return keypair;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * The top-level key generation function for real use (as opposed to
 | 
	
		
			
				|  |  | + * testsc): keep trying to make a key until you succeed.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +NTRUKeyPair *ntru_keygen(unsigned p, unsigned q, unsigned w)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    while (1) {
 | 
	
		
			
				|  |  | +        NTRUKeyPair *keypair = ntru_keygen_attempt(p, q, w);
 | 
	
		
			
				|  |  | +        if (keypair)
 | 
	
		
			
				|  |  | +            return keypair;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Public-key encryption.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_encrypt(uint16_t *ciphertext, const uint16_t *plaintext,
 | 
	
		
			
				|  |  | +                  uint16_t *pubkey, unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    uint16_t *r_q = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    ntru_expand(r_q, plaintext, p, q);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    uint16_t *unrounded = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    ntru_ring_multiply(unrounded, r_q, pubkey, p, q);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    ntru_round3(ciphertext, unrounded, p, q);
 | 
	
		
			
				|  |  | +    ntru_normalise(ciphertext, ciphertext, p, q);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    ring_free(r_q, p);
 | 
	
		
			
				|  |  | +    ring_free(unrounded, p);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Public-key decryption.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_decrypt(uint16_t *plaintext, const uint16_t *ciphertext,
 | 
	
		
			
				|  |  | +                  NTRUKeyPair *keypair)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    unsigned p = keypair->p, q = keypair->q, w = keypair->w;
 | 
	
		
			
				|  |  | +    uint16_t *tmp = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    ntru_ring_multiply(tmp, ciphertext, keypair->f3, p, q);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    ntru_mod3(tmp, tmp, p, q);
 | 
	
		
			
				|  |  | +    ntru_normalise(tmp, tmp, p, 3);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    ntru_ring_multiply(plaintext, tmp, keypair->ginv, p, 3);
 | 
	
		
			
				|  |  | +    ring_free(tmp, p);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * With luck, this should have recovered exactly the original
 | 
	
		
			
				|  |  | +     * plaintext. But, as per the spec, we check whether it has
 | 
	
		
			
				|  |  | +     * exactly w nonzero coefficients, and if not, then something has
 | 
	
		
			
				|  |  | +     * gone wrong - and in that situation we time-safely substitute a
 | 
	
		
			
				|  |  | +     * different output.
 | 
	
		
			
				|  |  | +     *
 | 
	
		
			
				|  |  | +     * (I don't know exactly why we do this, but I assume it's because
 | 
	
		
			
				|  |  | +     * otherwise the mis-decoded output could be made to disgorge a
 | 
	
		
			
				|  |  | +     * secret about the private key in some way.)
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    unsigned weight = p;
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < p; i++)
 | 
	
		
			
				|  |  | +        weight -= iszero(plaintext[i]);
 | 
	
		
			
				|  |  | +    unsigned ok = iszero(weight ^ w);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * The default failure return value consists of w 1s followed by
 | 
	
		
			
				|  |  | +     * 0s.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    unsigned mask = ok - 1;
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < w; i++) {
 | 
	
		
			
				|  |  | +        uint16_t diff = (1 ^ plaintext[i]) & mask;
 | 
	
		
			
				|  |  | +        plaintext[i] ^= diff;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +    for (size_t i = w; i < p; i++) {
 | 
	
		
			
				|  |  | +        uint16_t diff = (0 ^ plaintext[i]) & mask;
 | 
	
		
			
				|  |  | +        plaintext[i] ^= diff;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* ----------------------------------------------------------------------
 | 
	
		
			
				|  |  | + * Encode and decode public keys, ciphertexts and plaintexts.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * Public keys and ciphertexts use the complicated binary encoding
 | 
	
		
			
				|  |  | + * system implemented above. In both cases, the inputs are regarded as
 | 
	
		
			
				|  |  | + * symmetric about zero, and are first biased to map their most
 | 
	
		
			
				|  |  | + * negative permitted value to 0, so that they become non-negative and
 | 
	
		
			
				|  |  | + * hence suitable as inputs to the encoding system. In the case of a
 | 
	
		
			
				|  |  | + * ciphertext, where the input coefficients have also been coerced to
 | 
	
		
			
				|  |  | + * be multiples of 3, we divide by 3 as well, saving space by reducing
 | 
	
		
			
				|  |  | + * the upper bounds (m_i) on all the encoded numbers.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Compute the encoding schedule for a public key.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +static NTRUEncodeSchedule *ntru_encode_pubkey_schedule(unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    uint16_t *ms = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < p; i++)
 | 
	
		
			
				|  |  | +        ms[i] = q;
 | 
	
		
			
				|  |  | +    NTRUEncodeSchedule *sched = ntru_encode_schedule(ms, p);
 | 
	
		
			
				|  |  | +    sfree(ms);
 | 
	
		
			
				|  |  | +    return sched;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Encode a public key.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_encode_pubkey(const uint16_t *pubkey, unsigned p, unsigned q,
 | 
	
		
			
				|  |  | +                        BinarySink *bs)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    /* Compute the biased version for encoding */
 | 
	
		
			
				|  |  | +    uint16_t *biased_pubkey = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    ntru_bias(biased_pubkey, pubkey, q / 2, p, q);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Encode it */
 | 
	
		
			
				|  |  | +    NTRUEncodeSchedule *sched = ntru_encode_pubkey_schedule(p, q);
 | 
	
		
			
				|  |  | +    ntru_encode(sched, biased_pubkey, bs);
 | 
	
		
			
				|  |  | +    ntru_encode_schedule_free(sched);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    ring_free(biased_pubkey, p);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Decode a public key and write it into 'pubkey'. We also return a
 | 
	
		
			
				|  |  | + * ptrlen pointing at the chunk of data we removed from the
 | 
	
		
			
				|  |  | + * BinarySource.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +ptrlen ntru_decode_pubkey(uint16_t *pubkey, unsigned p, unsigned q,
 | 
	
		
			
				|  |  | +                          BinarySource *src)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    NTRUEncodeSchedule *sched = ntru_encode_pubkey_schedule(p, q);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Retrieve the right number of bytes from the source */
 | 
	
		
			
				|  |  | +    size_t len = ntru_encode_schedule_length(sched);
 | 
	
		
			
				|  |  | +    ptrlen encoded = get_data(src, len);
 | 
	
		
			
				|  |  | +    if (get_err(src)) {
 | 
	
		
			
				|  |  | +        /* If there wasn't enough data, give up and return all-zeroes
 | 
	
		
			
				|  |  | +         * purely for determinism. But that value should never be
 | 
	
		
			
				|  |  | +         * used, because the caller will also check get_err(src). */
 | 
	
		
			
				|  |  | +        memset(pubkey, 0, p*sizeof(*pubkey));
 | 
	
		
			
				|  |  | +    } else {
 | 
	
		
			
				|  |  | +        /* Do the decoding */
 | 
	
		
			
				|  |  | +        ntru_decode(sched, pubkey, encoded);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* Unbias the coefficients */
 | 
	
		
			
				|  |  | +        ntru_bias(pubkey, pubkey, q-q/2, p, q);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    ntru_encode_schedule_free(sched);
 | 
	
		
			
				|  |  | +    return encoded;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * For ciphertext biasing: work out the largest absolute value a
 | 
	
		
			
				|  |  | + * ciphertext element can take, which is given by taking q/2 and
 | 
	
		
			
				|  |  | + * rounding it to the nearest multiple of 3.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +static inline unsigned ciphertext_bias(unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    return (q/2+1) / 3;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * The number of possible values of a ciphertext coefficient (for use
 | 
	
		
			
				|  |  | + * as the m_i in encoding) ranges from +ciphertext_bias(q) to
 | 
	
		
			
				|  |  | + * -ciphertext_bias(q) inclusive.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +static inline unsigned ciphertext_m(unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    return 1 + 2 * ciphertext_bias(q);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Compute the encoding schedule for a ciphertext.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +static NTRUEncodeSchedule *ntru_encode_ciphertext_schedule(
 | 
	
		
			
				|  |  | +    unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    unsigned m = ciphertext_m(q);
 | 
	
		
			
				|  |  | +    uint16_t *ms = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < p; i++)
 | 
	
		
			
				|  |  | +        ms[i] = m;
 | 
	
		
			
				|  |  | +    NTRUEncodeSchedule *sched = ntru_encode_schedule(ms, p);
 | 
	
		
			
				|  |  | +    sfree(ms);
 | 
	
		
			
				|  |  | +    return sched;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Encode a ciphertext.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_encode_ciphertext(const uint16_t *ciphertext, unsigned p, unsigned q,
 | 
	
		
			
				|  |  | +                            BinarySink *bs)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    SETUP;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Bias the ciphertext, and scale down by 1/3, which we do by
 | 
	
		
			
				|  |  | +     * modular multiplication by the inverse of 3 mod q. (That only
 | 
	
		
			
				|  |  | +     * works if we know the inputs are all _exact_ multiples of 3
 | 
	
		
			
				|  |  | +     * - but we do!)
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    uint16_t *biased_ciphertext = snewn(p, uint16_t);
 | 
	
		
			
				|  |  | +    ntru_bias(biased_ciphertext, ciphertext, 3 * ciphertext_bias(q), p, q);
 | 
	
		
			
				|  |  | +    ntru_scale(biased_ciphertext, biased_ciphertext, INVERT(3), p, q);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Encode. */
 | 
	
		
			
				|  |  | +    NTRUEncodeSchedule *sched = ntru_encode_ciphertext_schedule(p, q);
 | 
	
		
			
				|  |  | +    ntru_encode(sched, biased_ciphertext, bs);
 | 
	
		
			
				|  |  | +    ntru_encode_schedule_free(sched);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    ring_free(biased_ciphertext, p);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +ptrlen ntru_decode_ciphertext(uint16_t *ct, NTRUKeyPair *keypair,
 | 
	
		
			
				|  |  | +                              BinarySource *src)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    unsigned p = keypair->p, q = keypair->q;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    NTRUEncodeSchedule *sched = ntru_encode_ciphertext_schedule(p, q);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Retrieve the right number of bytes from the source */
 | 
	
		
			
				|  |  | +    size_t len = ntru_encode_schedule_length(sched);
 | 
	
		
			
				|  |  | +    ptrlen encoded = get_data(src, len);
 | 
	
		
			
				|  |  | +    if (get_err(src)) {
 | 
	
		
			
				|  |  | +        /* As above, return deterministic nonsense on failure */
 | 
	
		
			
				|  |  | +        memset(ct, 0, p*sizeof(*ct));
 | 
	
		
			
				|  |  | +    } else {
 | 
	
		
			
				|  |  | +        /* Do the decoding */
 | 
	
		
			
				|  |  | +        ntru_decode(sched, ct, encoded);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* Undo the scaling and bias */
 | 
	
		
			
				|  |  | +        ntru_scale(ct, ct, 3, p, q);
 | 
	
		
			
				|  |  | +        ntru_bias(ct, ct, q - 3 * ciphertext_bias(q), p, q);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    ntru_encode_schedule_free(sched);
 | 
	
		
			
				|  |  | +    return encoded;        /* also useful to the caller, optionally */
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Encode a plaintext.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * This is a much simpler encoding than the NTRUEncodeSchedule system:
 | 
	
		
			
				|  |  | + * since elements of a plaintext are mod 3, we just encode each one in
 | 
	
		
			
				|  |  | + * 2 bits, applying the usual bias so that {-1,0,+1} map to {0,1,2}
 | 
	
		
			
				|  |  | + * respectively.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * There's no corresponding decode function, because plaintexts are
 | 
	
		
			
				|  |  | + * never transmitted on the wire (the whole point is that they're too
 | 
	
		
			
				|  |  | + * secret!). Plaintexts are only encoded in order to put them into
 | 
	
		
			
				|  |  | + * hash preimages.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +void ntru_encode_plaintext(const uint16_t *plaintext, unsigned p,
 | 
	
		
			
				|  |  | +                           BinarySink *bs)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    unsigned byte = 0, bitpos = 0;
 | 
	
		
			
				|  |  | +    for (size_t i = 0; i < p; i++) {
 | 
	
		
			
				|  |  | +        unsigned encoding = (plaintext[i] + 1) * iszero(plaintext[i] >> 1);
 | 
	
		
			
				|  |  | +        byte |= encoding << bitpos;
 | 
	
		
			
				|  |  | +        bitpos += 2;
 | 
	
		
			
				|  |  | +        if (bitpos == 8 || i+1 == p) {
 | 
	
		
			
				|  |  | +            put_byte(bs, byte);
 | 
	
		
			
				|  |  | +            byte = 0;
 | 
	
		
			
				|  |  | +            bitpos = 0;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* ----------------------------------------------------------------------
 | 
	
		
			
				|  |  | + * Compute the hashes required by the key exchange layer of NTRU Prime.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * There are two of these. The 'confirmation hash' is sent by the
 | 
	
		
			
				|  |  | + * server along with the ciphertext, and the client can recalculate it
 | 
	
		
			
				|  |  | + * to check whether the ciphertext was decrypted correctly. Then, the
 | 
	
		
			
				|  |  | + * 'session hash' is the actual output of key exchange, and if the
 | 
	
		
			
				|  |  | + * confirmation hash doesn't match, it gets deliberately corrupted.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Make the confirmation hash, whose inputs are the plaintext and the
 | 
	
		
			
				|  |  | + * public key.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * This is defined as H(2 || H(3 || r) || H(4 || K)), where r is the
 | 
	
		
			
				|  |  | + * plaintext and K is the public key (as encoded by the above
 | 
	
		
			
				|  |  | + * functions), and the constants 2,3,4 are single bytes. The choice of
 | 
	
		
			
				|  |  | + * hash function (H itself) is SHA-512 truncated to 256 bits.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * (To be clear: that is _not_ the thing that FIPS 180-4 6.7 defines
 | 
	
		
			
				|  |  | + * as "SHA-512/256", which varies the initialisation vector of the
 | 
	
		
			
				|  |  | + * SHA-512 algorithm as well as truncating the output. _This_
 | 
	
		
			
				|  |  | + * algorithm uses the standard SHA-512 IV, and _just_ truncates the
 | 
	
		
			
				|  |  | + * output, in the manner suggested by FIPS 180-4 section 7.)
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * 'out' should therefore expect to receive 32 bytes of data.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +static void ntru_confirmation_hash(
 | 
	
		
			
				|  |  | +    uint8_t *out, const uint16_t *plaintext,
 | 
	
		
			
				|  |  | +    const uint16_t *pubkey, unsigned p, unsigned q)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    /* The outer hash object */
 | 
	
		
			
				|  |  | +    ssh_hash *hconfirm = ssh_hash_new(&ssh_sha512);
 | 
	
		
			
				|  |  | +    put_byte(hconfirm, 2);             /* initial byte 2 */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    uint8_t hashdata[64];
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Compute H(3 || r) and add it to the main hash */
 | 
	
		
			
				|  |  | +    ssh_hash *h3r = ssh_hash_new(&ssh_sha512);
 | 
	
		
			
				|  |  | +    put_byte(h3r, 3);
 | 
	
		
			
				|  |  | +    ntru_encode_plaintext(plaintext, p, BinarySink_UPCAST(h3r));
 | 
	
		
			
				|  |  | +    ssh_hash_final(h3r, hashdata);
 | 
	
		
			
				|  |  | +    put_data(hconfirm, hashdata, 32);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Compute H(4 || K) and add it to the main hash */
 | 
	
		
			
				|  |  | +    ssh_hash *h4K = ssh_hash_new(&ssh_sha512);
 | 
	
		
			
				|  |  | +    put_byte(h4K, 4);
 | 
	
		
			
				|  |  | +    ntru_encode_pubkey(pubkey, p, q, BinarySink_UPCAST(h4K));
 | 
	
		
			
				|  |  | +    ssh_hash_final(h4K, hashdata);
 | 
	
		
			
				|  |  | +    put_data(hconfirm, hashdata, 32);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Compute the full output of the main SHA-512 hash */
 | 
	
		
			
				|  |  | +    ssh_hash_final(hconfirm, hashdata);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* And copy the first 32 bytes into the caller's output array */
 | 
	
		
			
				|  |  | +    memcpy(out, hashdata, 32);
 | 
	
		
			
				|  |  | +    smemclr(hashdata, sizeof(hashdata));
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * Make the session hash, whose inputs are the plaintext, the
 | 
	
		
			
				|  |  | + * ciphertext, and the confirmation hash (hence, transitively, a
 | 
	
		
			
				|  |  | + * dependence on the public key as well).
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * As computed by the server, and by the client if the confirmation
 | 
	
		
			
				|  |  | + * hash matched, this is defined as
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + *   H(1 || H(3 || r) || ciphertext || confirmation hash)
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * but if the confirmation hash _didn't_ match, then the plaintext r
 | 
	
		
			
				|  |  | + * is replaced with the dummy plaintext-shaped value 'rho' we invented
 | 
	
		
			
				|  |  | + * during key generation (presumably to avoid leaking any information
 | 
	
		
			
				|  |  | + * about our secrets), and the initial byte 1 is replaced with 0 (to
 | 
	
		
			
				|  |  | + * ensure that the resulting hash preimage can't match any legitimate
 | 
	
		
			
				|  |  | + * preimage). So in that case, you instead get
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + *   H(0 || H(3 || rho) || ciphertext || confirmation hash)
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * The inputs to this function include 'ok', which is the value to use
 | 
	
		
			
				|  |  | + * as the initial byte (1 on success, 0 on failure), and 'plaintext'
 | 
	
		
			
				|  |  | + * which should already have been substituted with rho in case of
 | 
	
		
			
				|  |  | + * failure.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * The ciphertext is provided in already-encoded form.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +static void ntru_session_hash(
 | 
	
		
			
				|  |  | +    uint8_t *out, unsigned ok, const uint16_t *plaintext,
 | 
	
		
			
				|  |  | +    unsigned p, ptrlen ciphertext, ptrlen confirmation_hash)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    /* The outer hash object */
 | 
	
		
			
				|  |  | +    ssh_hash *hsession = ssh_hash_new(&ssh_sha512);
 | 
	
		
			
				|  |  | +    put_byte(hsession, ok);            /* initial byte 1 or 0 */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    uint8_t hashdata[64];
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Compute H(3 || r), or maybe H(3 || rho), and add it to the main hash */
 | 
	
		
			
				|  |  | +    ssh_hash *h3r = ssh_hash_new(&ssh_sha512);
 | 
	
		
			
				|  |  | +    put_byte(h3r, 3);
 | 
	
		
			
				|  |  | +    ntru_encode_plaintext(plaintext, p, BinarySink_UPCAST(h3r));
 | 
	
		
			
				|  |  | +    ssh_hash_final(h3r, hashdata);
 | 
	
		
			
				|  |  | +    put_data(hsession, hashdata, 32);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Put the ciphertext and confirmation hash in */
 | 
	
		
			
				|  |  | +    put_datapl(hsession, ciphertext);
 | 
	
		
			
				|  |  | +    put_datapl(hsession, confirmation_hash);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Compute the full output of the main SHA-512 hash */
 | 
	
		
			
				|  |  | +    ssh_hash_final(hsession, hashdata);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* And copy the first 32 bytes into the caller's output array */
 | 
	
		
			
				|  |  | +    memcpy(out, hashdata, 32);
 | 
	
		
			
				|  |  | +    smemclr(hashdata, sizeof(hashdata));
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* ----------------------------------------------------------------------
 | 
	
		
			
				|  |  | + * Top-level key exchange and SSH integration.
 | 
	
		
			
				|  |  | + *
 | 
	
		
			
				|  |  | + * Although this system borrows the ECDH packet structure, it's unlike
 | 
	
		
			
				|  |  | + * true ECDH in that it is completely asymmetric between client and
 | 
	
		
			
				|  |  | + * server. So we have two separate vtables of methods for the two
 | 
	
		
			
				|  |  | + * sides of the system, and a third vtable containing only the class
 | 
	
		
			
				|  |  | + * methods, in particular a constructor which chooses which one to
 | 
	
		
			
				|  |  | + * instantiate.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * The parameters p,q,w for the system. There are other choices of
 | 
	
		
			
				|  |  | + * these, but OpenSSH only specifies this set. (If that ever changes,
 | 
	
		
			
				|  |  | + * we'll need to turn these into elements of the state structures.)
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +#define p_LIVE 761
 | 
	
		
			
				|  |  | +#define q_LIVE 4591
 | 
	
		
			
				|  |  | +#define w_LIVE 286
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static char *ssh_ntru_description(const ssh_kex *kex)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    return dupprintf("NTRU Prime / Curve25519 hybrid key exchange");
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * State structure for the client, which takes the role of inventing a
 | 
	
		
			
				|  |  | + * key pair and decrypting a secret plaintext sent to it by the server.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +typedef struct ntru_client_key {
 | 
	
		
			
				|  |  | +    NTRUKeyPair *keypair;
 | 
	
		
			
				|  |  | +    ecdh_key *curve25519;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    ecdh_key ek;
 | 
	
		
			
				|  |  | +} ntru_client_key;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static void ssh_ntru_client_free(ecdh_key *dh);
 | 
	
		
			
				|  |  | +static void ssh_ntru_client_getpublic(ecdh_key *dh, BinarySink *bs);
 | 
	
		
			
				|  |  | +static bool ssh_ntru_client_getkey(ecdh_key *dh, ptrlen remoteKey,
 | 
	
		
			
				|  |  | +                                   BinarySink *bs);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static const ecdh_keyalg ssh_ntru_client_vt = {
 | 
	
		
			
				|  |  | +    /* This vtable has no 'new' method, because it's constructed via
 | 
	
		
			
				|  |  | +     * the selector vt below */
 | 
	
		
			
				|  |  | +    .free = ssh_ntru_client_free,
 | 
	
		
			
				|  |  | +    .getpublic = ssh_ntru_client_getpublic,
 | 
	
		
			
				|  |  | +    .getkey = ssh_ntru_client_getkey,
 | 
	
		
			
				|  |  | +    .description = ssh_ntru_description,
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static ecdh_key *ssh_ntru_client_new(void)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    ntru_client_key *nk = snew(ntru_client_key);
 | 
	
		
			
				|  |  | +    nk->ek.vt = &ssh_ntru_client_vt;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    nk->keypair = ntru_keygen(p_LIVE, q_LIVE, w_LIVE);
 | 
	
		
			
				|  |  | +    nk->curve25519 = ecdh_key_new(&ssh_ec_kex_curve25519, false);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    return &nk->ek;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static void ssh_ntru_client_free(ecdh_key *dh)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    ntru_client_key *nk = container_of(dh, ntru_client_key, ek);
 | 
	
		
			
				|  |  | +    ntru_keypair_free(nk->keypair);
 | 
	
		
			
				|  |  | +    ecdh_key_free(nk->curve25519);
 | 
	
		
			
				|  |  | +    sfree(nk);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static void ssh_ntru_client_getpublic(ecdh_key *dh, BinarySink *bs)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    ntru_client_key *nk = container_of(dh, ntru_client_key, ek);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * The client's public information is a single SSH string
 | 
	
		
			
				|  |  | +     * containing the NTRU public key and the Curve25519 public point
 | 
	
		
			
				|  |  | +     * concatenated. So write both of those into the output
 | 
	
		
			
				|  |  | +     * BinarySink.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    ntru_encode_pubkey(nk->keypair->h, p_LIVE, q_LIVE, bs);
 | 
	
		
			
				|  |  | +    ecdh_key_getpublic(nk->curve25519, bs);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static bool ssh_ntru_client_getkey(ecdh_key *dh, ptrlen remoteKey,
 | 
	
		
			
				|  |  | +                                   BinarySink *bs)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    ntru_client_key *nk = container_of(dh, ntru_client_key, ek);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * We expect the server to have sent us a string containing a
 | 
	
		
			
				|  |  | +     * ciphertext, a confirmation hash, and a Curve25519 public point.
 | 
	
		
			
				|  |  | +     * Extract all three.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    BinarySource src[1];
 | 
	
		
			
				|  |  | +    BinarySource_BARE_INIT_PL(src, remoteKey);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    uint16_t *ciphertext = snewn(p_LIVE, uint16_t);
 | 
	
		
			
				|  |  | +    ptrlen ciphertext_encoded = ntru_decode_ciphertext(
 | 
	
		
			
				|  |  | +        ciphertext, nk->keypair, src);
 | 
	
		
			
				|  |  | +    ptrlen confirmation_hash = get_data(src, 32);
 | 
	
		
			
				|  |  | +    ptrlen curve25519_remoteKey = get_data(src, 32);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    if (get_err(src) || get_avail(src)) {
 | 
	
		
			
				|  |  | +        /* Hard-fail if the input wasn't exactly the right length */
 | 
	
		
			
				|  |  | +        ring_free(ciphertext, p_LIVE);
 | 
	
		
			
				|  |  | +        return false;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Main hash object which will combine the NTRU and Curve25519
 | 
	
		
			
				|  |  | +     * outputs.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    ssh_hash *h = ssh_hash_new(&ssh_sha512);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Reusable buffer for storing various hash outputs. */
 | 
	
		
			
				|  |  | +    uint8_t hashdata[64];
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * NTRU side.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    {
 | 
	
		
			
				|  |  | +        /* Decrypt the ciphertext to recover the server's plaintext */
 | 
	
		
			
				|  |  | +        uint16_t *plaintext = snewn(p_LIVE, uint16_t);
 | 
	
		
			
				|  |  | +        ntru_decrypt(plaintext, ciphertext, nk->keypair);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* Make the confirmation hash */
 | 
	
		
			
				|  |  | +        ntru_confirmation_hash(hashdata, plaintext, nk->keypair->h,
 | 
	
		
			
				|  |  | +                               p_LIVE, q_LIVE);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* Check it matches the one the server sent */
 | 
	
		
			
				|  |  | +        unsigned ok = smemeq(hashdata, confirmation_hash.ptr, 32);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* If not, substitute in rho for the plaintext in the session hash */
 | 
	
		
			
				|  |  | +        unsigned mask = ok-1;
 | 
	
		
			
				|  |  | +        for (size_t i = 0; i < p_LIVE; i++)
 | 
	
		
			
				|  |  | +            plaintext[i] ^= mask & (plaintext[i] ^ nk->keypair->rho[i]);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* Compute the session hash, whether or not we did that */
 | 
	
		
			
				|  |  | +        ntru_session_hash(hashdata, ok, plaintext, p_LIVE, ciphertext_encoded,
 | 
	
		
			
				|  |  | +                          confirmation_hash);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* Free temporary values */
 | 
	
		
			
				|  |  | +        ring_free(plaintext, p_LIVE);
 | 
	
		
			
				|  |  | +        ring_free(ciphertext, p_LIVE);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* And put the NTRU session hash into the main hash object. */
 | 
	
		
			
				|  |  | +        put_data(h, hashdata, 32);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Curve25519 side.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    {
 | 
	
		
			
				|  |  | +        strbuf *otherkey = strbuf_new_nm();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* Call out to Curve25519 to compute the shared secret from that
 | 
	
		
			
				|  |  | +         * kex method */
 | 
	
		
			
				|  |  | +        bool ok = ecdh_key_getkey(nk->curve25519, curve25519_remoteKey,
 | 
	
		
			
				|  |  | +                                  BinarySink_UPCAST(otherkey));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* If that failed (which only happens if the other end does
 | 
	
		
			
				|  |  | +         * something wrong, like sending a low-order curve point
 | 
	
		
			
				|  |  | +         * outside the subgroup it's supposed to), we might as well
 | 
	
		
			
				|  |  | +         * just abort and return failure. That's what we'd have done
 | 
	
		
			
				|  |  | +         * in standalone Curve25519. */
 | 
	
		
			
				|  |  | +        if (!ok) {
 | 
	
		
			
				|  |  | +            ssh_hash_free(h);
 | 
	
		
			
				|  |  | +            smemclr(hashdata, sizeof(hashdata));
 | 
	
		
			
				|  |  | +            strbuf_free(otherkey);
 | 
	
		
			
				|  |  | +            return false;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /*
 | 
	
		
			
				|  |  | +         * ecdh_key_getkey will have returned us a chunk of data
 | 
	
		
			
				|  |  | +         * containing an encoded mpint, which is how the Curve25519
 | 
	
		
			
				|  |  | +         * output normally goes into the exchange hash. But in this
 | 
	
		
			
				|  |  | +         * context we want to treat it as a fixed big-endian 32 bytes,
 | 
	
		
			
				|  |  | +         * so extract it from its encoding and put it into the main
 | 
	
		
			
				|  |  | +         * hash object in the new format.
 | 
	
		
			
				|  |  | +         */
 | 
	
		
			
				|  |  | +        BinarySource src[1];
 | 
	
		
			
				|  |  | +        BinarySource_BARE_INIT_PL(src, ptrlen_from_strbuf(otherkey));
 | 
	
		
			
				|  |  | +        mp_int *curvekey = get_mp_ssh2(src);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        for (unsigned i = 32; i-- > 0 ;)
 | 
	
		
			
				|  |  | +            put_byte(h, mp_get_byte(curvekey, i));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        mp_free(curvekey);
 | 
	
		
			
				|  |  | +        strbuf_free(otherkey);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Finish up: compute the final output hash (full 64 bytes of
 | 
	
		
			
				|  |  | +     * SHA-512 this time), and return it encoded as a string.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    ssh_hash_final(h, hashdata);
 | 
	
		
			
				|  |  | +    put_stringpl(bs, make_ptrlen(hashdata, sizeof(hashdata)));
 | 
	
		
			
				|  |  | +    smemclr(hashdata, sizeof(hashdata));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    return true;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/*
 | 
	
		
			
				|  |  | + * State structure for the server, which takes the role of inventing a
 | 
	
		
			
				|  |  | + * secret plaintext and sending it to the client encrypted with the
 | 
	
		
			
				|  |  | + * public key the client sent.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +typedef struct ntru_server_key {
 | 
	
		
			
				|  |  | +    uint16_t *plaintext;
 | 
	
		
			
				|  |  | +    strbuf *ciphertext_encoded, *confirmation_hash;
 | 
	
		
			
				|  |  | +    ecdh_key *curve25519;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    ecdh_key ek;
 | 
	
		
			
				|  |  | +} ntru_server_key;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static void ssh_ntru_server_free(ecdh_key *dh);
 | 
	
		
			
				|  |  | +static void ssh_ntru_server_getpublic(ecdh_key *dh, BinarySink *bs);
 | 
	
		
			
				|  |  | +static bool ssh_ntru_server_getkey(ecdh_key *dh, ptrlen remoteKey,
 | 
	
		
			
				|  |  | +                                   BinarySink *bs);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static const ecdh_keyalg ssh_ntru_server_vt = {
 | 
	
		
			
				|  |  | +    /* This vtable has no 'new' method, because it's constructed via
 | 
	
		
			
				|  |  | +     * the selector vt below */
 | 
	
		
			
				|  |  | +    .free = ssh_ntru_server_free,
 | 
	
		
			
				|  |  | +    .getpublic = ssh_ntru_server_getpublic,
 | 
	
		
			
				|  |  | +    .getkey = ssh_ntru_server_getkey,
 | 
	
		
			
				|  |  | +    .description = ssh_ntru_description,
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static ecdh_key *ssh_ntru_server_new(void)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    ntru_server_key *nk = snew(ntru_server_key);
 | 
	
		
			
				|  |  | +    nk->ek.vt = &ssh_ntru_server_vt;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    nk->plaintext = snewn(p_LIVE, uint16_t);
 | 
	
		
			
				|  |  | +    nk->ciphertext_encoded = strbuf_new_nm();
 | 
	
		
			
				|  |  | +    nk->confirmation_hash = strbuf_new_nm();
 | 
	
		
			
				|  |  | +    ntru_gen_short(nk->plaintext, p_LIVE, w_LIVE);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    nk->curve25519 = ecdh_key_new(&ssh_ec_kex_curve25519, false);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    return &nk->ek;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static void ssh_ntru_server_free(ecdh_key *dh)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    ntru_server_key *nk = container_of(dh, ntru_server_key, ek);
 | 
	
		
			
				|  |  | +    ring_free(nk->plaintext, p_LIVE);
 | 
	
		
			
				|  |  | +    strbuf_free(nk->ciphertext_encoded);
 | 
	
		
			
				|  |  | +    strbuf_free(nk->confirmation_hash);
 | 
	
		
			
				|  |  | +    ecdh_key_free(nk->curve25519);
 | 
	
		
			
				|  |  | +    sfree(nk);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static bool ssh_ntru_server_getkey(ecdh_key *dh, ptrlen remoteKey,
 | 
	
		
			
				|  |  | +                                   BinarySink *bs)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    ntru_server_key *nk = container_of(dh, ntru_server_key, ek);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * In the server, getkey is called first, with the public
 | 
	
		
			
				|  |  | +     * information received from the client. We expect the client to
 | 
	
		
			
				|  |  | +     * have sent us a string containing a public key and a Curve25519
 | 
	
		
			
				|  |  | +     * public point.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    BinarySource src[1];
 | 
	
		
			
				|  |  | +    BinarySource_BARE_INIT_PL(src, remoteKey);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    uint16_t *pubkey = snewn(p_LIVE, uint16_t);
 | 
	
		
			
				|  |  | +    ntru_decode_pubkey(pubkey, p_LIVE, q_LIVE, src);
 | 
	
		
			
				|  |  | +    ptrlen curve25519_remoteKey = get_data(src, 32);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    if (get_err(src) || get_avail(src)) {
 | 
	
		
			
				|  |  | +        /* Hard-fail if the input wasn't exactly the right length */ 
 | 
	
		
			
				|  |  | +        ring_free(pubkey, p_LIVE);
 | 
	
		
			
				|  |  | +        return false;
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Main hash object which will combine the NTRU and Curve25519
 | 
	
		
			
				|  |  | +     * outputs.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    ssh_hash *h = ssh_hash_new(&ssh_sha512);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /* Reusable buffer for storing various hash outputs. */
 | 
	
		
			
				|  |  | +    uint8_t hashdata[64];
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * NTRU side.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    {
 | 
	
		
			
				|  |  | +        /* Encrypt the plaintext we generated at construction time,
 | 
	
		
			
				|  |  | +         * and encode the ciphertext into a strbuf so we can reuse it
 | 
	
		
			
				|  |  | +         * for both the session hash and sending to the client. */
 | 
	
		
			
				|  |  | +        uint16_t *ciphertext = snewn(p_LIVE, uint16_t);
 | 
	
		
			
				|  |  | +        ntru_encrypt(ciphertext, nk->plaintext, pubkey, p_LIVE, q_LIVE);
 | 
	
		
			
				|  |  | +        ntru_encode_ciphertext(ciphertext, p_LIVE, q_LIVE,
 | 
	
		
			
				|  |  | +                               BinarySink_UPCAST(nk->ciphertext_encoded));
 | 
	
		
			
				|  |  | +        ring_free(ciphertext, p_LIVE);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* Compute the confirmation hash, and write it into another
 | 
	
		
			
				|  |  | +         * strbuf. */
 | 
	
		
			
				|  |  | +        ntru_confirmation_hash(hashdata, nk->plaintext, pubkey,
 | 
	
		
			
				|  |  | +                               p_LIVE, q_LIVE);
 | 
	
		
			
				|  |  | +        put_data(nk->confirmation_hash, hashdata, 32);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* Compute the session hash (which is easy on the server side,
 | 
	
		
			
				|  |  | +         * requiring no conditional substitution). */
 | 
	
		
			
				|  |  | +        ntru_session_hash(hashdata, 1, nk->plaintext, p_LIVE,
 | 
	
		
			
				|  |  | +                          ptrlen_from_strbuf(nk->ciphertext_encoded),
 | 
	
		
			
				|  |  | +                          ptrlen_from_strbuf(nk->confirmation_hash));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* And put the NTRU session hash into the main hash object. */
 | 
	
		
			
				|  |  | +        put_data(h, hashdata, 32);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* Now we can free the public key */
 | 
	
		
			
				|  |  | +        ring_free(pubkey, p_LIVE);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Curve25519 side.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    {
 | 
	
		
			
				|  |  | +        strbuf *otherkey = strbuf_new_nm();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* Call out to Curve25519 to compute the shared secret from that
 | 
	
		
			
				|  |  | +         * kex method */
 | 
	
		
			
				|  |  | +        bool ok = ecdh_key_getkey(nk->curve25519, curve25519_remoteKey,
 | 
	
		
			
				|  |  | +                                  BinarySink_UPCAST(otherkey));
 | 
	
		
			
				|  |  | +        /* As on the client side, abort if Curve25519 reported failure */
 | 
	
		
			
				|  |  | +        if (!ok) {
 | 
	
		
			
				|  |  | +            ssh_hash_free(h);
 | 
	
		
			
				|  |  | +            smemclr(hashdata, sizeof(hashdata));
 | 
	
		
			
				|  |  | +            strbuf_free(otherkey);
 | 
	
		
			
				|  |  | +            return false;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        /* As on the client side, decode Curve25519's mpint so we can
 | 
	
		
			
				|  |  | +         * re-encode it appropriately for our hash preimage */
 | 
	
		
			
				|  |  | +        BinarySource src[1];
 | 
	
		
			
				|  |  | +        BinarySource_BARE_INIT_PL(src, ptrlen_from_strbuf(otherkey));
 | 
	
		
			
				|  |  | +        mp_int *curvekey = get_mp_ssh2(src);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        for (unsigned i = 32; i-- > 0 ;)
 | 
	
		
			
				|  |  | +            put_byte(h, mp_get_byte(curvekey, i));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        mp_free(curvekey);
 | 
	
		
			
				|  |  | +        strbuf_free(otherkey);
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * Finish up: compute the final output hash (full 64 bytes of
 | 
	
		
			
				|  |  | +     * SHA-512 this time), and return it encoded as a string.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    ssh_hash_final(h, hashdata);
 | 
	
		
			
				|  |  | +    put_stringpl(bs, make_ptrlen(hashdata, sizeof(hashdata)));
 | 
	
		
			
				|  |  | +    smemclr(hashdata, sizeof(hashdata));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    return true;
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static void ssh_ntru_server_getpublic(ecdh_key *dh, BinarySink *bs)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    ntru_server_key *nk = container_of(dh, ntru_server_key, ek);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    /*
 | 
	
		
			
				|  |  | +     * In the server, this function is called after getkey, so we
 | 
	
		
			
				|  |  | +     * already have all our pieces prepared. Just concatenate them all
 | 
	
		
			
				|  |  | +     * into the 'server's public data' string to go in ECDH_REPLY.
 | 
	
		
			
				|  |  | +     */
 | 
	
		
			
				|  |  | +    put_datapl(bs, ptrlen_from_strbuf(nk->ciphertext_encoded));
 | 
	
		
			
				|  |  | +    put_datapl(bs, ptrlen_from_strbuf(nk->confirmation_hash));
 | 
	
		
			
				|  |  | +    ecdh_key_getpublic(nk->curve25519, bs);
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +/* ----------------------------------------------------------------------
 | 
	
		
			
				|  |  | + * Selector vtable that instantiates the appropriate one of the above,
 | 
	
		
			
				|  |  | + * depending on is_server.
 | 
	
		
			
				|  |  | + */
 | 
	
		
			
				|  |  | +static ecdh_key *ssh_ntru_new(const ssh_kex *kex, bool is_server)
 | 
	
		
			
				|  |  | +{
 | 
	
		
			
				|  |  | +    if (is_server)
 | 
	
		
			
				|  |  | +        return ssh_ntru_server_new();
 | 
	
		
			
				|  |  | +    else
 | 
	
		
			
				|  |  | +        return ssh_ntru_client_new();
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static const ecdh_keyalg ssh_ntru_selector_vt = {
 | 
	
		
			
				|  |  | +    /* This is a never-instantiated vtable which only implements the
 | 
	
		
			
				|  |  | +     * functions that don't require an instance. */
 | 
	
		
			
				|  |  | +    .new = ssh_ntru_new,
 | 
	
		
			
				|  |  | +    .description = ssh_ntru_description,
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static const ssh_kex ssh_ntru_curve25519 = {
 | 
	
		
			
				|  |  | +    .name = "[email protected]",
 | 
	
		
			
				|  |  | +    .main_type = KEXTYPE_ECDH,
 | 
	
		
			
				|  |  | +    .hash = &ssh_sha512,
 | 
	
		
			
				|  |  | +    .ecdh_vt = &ssh_ntru_selector_vt,
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +static const ssh_kex *const hybrid_list[] = {
 | 
	
		
			
				|  |  | +    &ssh_ntru_curve25519,
 | 
	
		
			
				|  |  | +};
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +const ssh_kexes ssh_ntru_hybrid_kex = { lenof(hybrid_list), hybrid_list };
 |