Browse Source

PuTTY snapshot 25b034ee (Complete rewrite of PuTTY's bignum library - 2018-12-31)

Source commit: c3541c9a36d3ca255052b5a32a6876b67dd60d8d
Martin Prikryl 6 years ago
parent
commit
3f4093ad87

+ 10 - 0
source/putty/defs.h

@@ -63,6 +63,16 @@ typedef struct TermWinVtable TermWinVtable;
 
 typedef struct Ssh Ssh;
 
+typedef struct mp_int mp_int;             /* integer type used for the arithmetic in ecc.c */
+typedef struct MontyContext MontyContext; /* context for Montgomery-form arithmetic mod p */
+
+typedef struct WeierstrassCurve WeierstrassCurve; /* struct body is in ecc.c */
+typedef struct WeierstrassPoint WeierstrassPoint; /* struct body is in ecc.c */
+typedef struct MontgomeryCurve MontgomeryCurve;   /* struct body is in ecc.c */
+typedef struct MontgomeryPoint MontgomeryPoint;   /* struct body is in ecc.c */
+typedef struct EdwardsCurve EdwardsCurve;         /* struct body is in ecc.c */
+typedef struct EdwardsPoint EdwardsPoint;         /* struct body is in ecc.c */
+
 typedef struct SftpServer SftpServer;
 typedef struct SftpServerVtable SftpServerVtable;
 

+ 1112 - 0
source/putty/ecc.c

@@ -0,0 +1,1112 @@
+#include <assert.h>
+
+#include "ssh.h"
+#include "mpint.h"
+#include "ecc.h"
+
+/* ----------------------------------------------------------------------
+ * Weierstrass curves.
+ */
+
+struct WeierstrassPoint {
+    /*
+     * A point is held in 'Jacobian coordinates': a triple (X, Y, Z)
+     * related to the affine pair (x, y) by x = X/Z^2 and y = Y/Z^3.
+     *
+     * The benefit is that no inverse mod p is needed during the bulk
+     * of a computation. Wherever the affine formulae would divide by
+     * something, that divisor is instead folded into the
+     * 'denominator' coordinate Z, and Z is inverted just once, at the
+     * point where the affine coordinates are finally wanted.
+     *
+     * The point at infinity has all three coordinates equal to zero.
+     *
+     * All three values are kept in the Montgomery-multiplication
+     * transformed representation.
+     */
+    mp_int *X;
+    mp_int *Y;
+    mp_int *Z;
+
+    /* The curve that this point was created on. */
+    WeierstrassCurve *wc;
+};
+
+struct WeierstrassCurve {
+    /* The prime p defining the finite field the curve lives in. */
+    mp_int *p;
+
+    /* Long-lived Montgomery context used for all arithmetic mod p. */
+    MontyContext *mc;
+
+    /* Context for modular square roots, needed to decompress points.
+     * Left NULL when the constructor was not given nonsquare_mod_p. */
+    ModsqrtContext *sc;
+
+    /* Curve coefficients a and b, both stored in
+     * Montgomery-multiplication transformed form. */
+    mp_int *a;
+    mp_int *b;
+};
+
+/*
+ * Build a Weierstrass curve object from the field modulus p and the
+ * coefficients a and b. p is copied; a and b are imported into
+ * Montgomery form. nonsquare_mod_p may be NULL, in which case no
+ * modsqrt context is created and wc->sc is left NULL.
+ */
+WeierstrassCurve *ecc_weierstrass_curve(
+    mp_int *p, mp_int *a, mp_int *b, mp_int *nonsquare_mod_p)
+{
+    WeierstrassCurve *curve = snew(WeierstrassCurve);
+
+    curve->p = mp_copy(p);
+    curve->mc = monty_new(p);
+    curve->a = monty_import(curve->mc, a);
+    curve->b = monty_import(curve->mc, b);
+    curve->sc = nonsquare_mod_p ? modsqrt_new(p, nonsquare_mod_p) : NULL;
+
+    return curve;
+}
+
+/*
+ * Release a curve made by ecc_weierstrass_curve(), together with
+ * everything it owns. The modsqrt context is only freed if present.
+ */
+void ecc_weierstrass_curve_free(WeierstrassCurve *wc)
+{
+    if (wc->sc) {
+        modsqrt_free(wc->sc);
+    }
+    monty_free(wc->mc);
+    mp_free(wc->b);
+    mp_free(wc->a);
+    mp_free(wc->p);
+    sfree(wc);
+}
+
+/*
+ * Allocate a point attached to wc with all three coordinate pointers
+ * set to NULL; the caller is expected to fill them in.
+ */
+static WeierstrassPoint *ecc_weierstrass_point_new_empty(WeierstrassCurve *wc)
+{
+    WeierstrassPoint *pt = snew(WeierstrassPoint);
+
+    pt->X = NULL;
+    pt->Y = NULL;
+    pt->Z = NULL;
+    pt->wc = wc;
+
+    return pt;
+}
+
+/*
+ * Make a point from affine coordinates that are already in Montgomery
+ * form. The two mp_ints are stored directly in the new point (not
+ * copied), and Z is set to a copy of the Montgomery identity.
+ */
+static WeierstrassPoint *ecc_weierstrass_point_new_imported(
+    WeierstrassCurve *wc, mp_int *monty_x, mp_int *monty_y)
+{
+    WeierstrassPoint *pt = ecc_weierstrass_point_new_empty(wc);
+
+    pt->Z = mp_copy(monty_identity(wc->mc));
+    pt->Y = monty_y;
+    pt->X = monty_x;
+
+    return pt;
+}
+
+/*
+ * Make a point from ordinary (non-Montgomery) affine coordinates x, y.
+ * Both are imported into Montgomery form; the inputs themselves are
+ * not stored in the point.
+ */
+WeierstrassPoint *ecc_weierstrass_point_new(
+    WeierstrassCurve *wc, mp_int *x, mp_int *y)
+{
+    mp_int *mx = monty_import(wc->mc, x);
+    mp_int *my = monty_import(wc->mc, y);
+    return ecc_weierstrass_point_new_imported(wc, mx, my);
+}
+
+/*
+ * Make the identity point on wc. Each coordinate is a fresh mp_int
+ * from mp_new(), sized from the modulus, with no value written to it
+ * afterwards.
+ * NOTE(review): presumably mp_new() returns zero, giving the all-zero
+ * triple that represents the point at infinity - confirm.
+ */
+WeierstrassPoint *ecc_weierstrass_point_new_identity(WeierstrassCurve *wc)
+{
+    const size_t nbits = mp_max_bits(wc->p);
+    WeierstrassPoint *id = ecc_weierstrass_point_new_empty(wc);
+
+    id->X = mp_new(nbits);
+    id->Y = mp_new(nbits);
+    id->Z = mp_new(nbits);
+
+    return id;
+}
+
+/*
+ * Duplicate a point: the result is on the same curve as orig and has
+ * its own copies of X, Y and Z.
+ */
+WeierstrassPoint *ecc_weierstrass_point_copy(WeierstrassPoint *orig)
+{
+    WeierstrassPoint *dup = ecc_weierstrass_point_new_empty(orig->wc);
+
+    dup->Z = mp_copy(orig->Z);
+    dup->Y = mp_copy(orig->Y);
+    dup->X = mp_copy(orig->X);
+
+    return dup;
+}
+
+/*
+ * Free a point and its three coordinates. The struct itself is wiped
+ * with smemclr() before being released.
+ */
+void ecc_weierstrass_point_free(WeierstrassPoint *wp)
+{
+    mp_free(wp->Z);
+    mp_free(wp->Y);
+    mp_free(wp->X);
+
+    smemclr(wp, sizeof(*wp));
+    sfree(wp);
+}
+
+/*
+ * Evaluate the right-hand side x^3 + ax + b of the curve equation at
+ * monty_x, computed as ((x^2 + a) * x) + b. Input and output are both
+ * in Montgomery form; the result is newly allocated.
+ */
+static mp_int *ecc_weierstrass_equation_rhs(
+    WeierstrassCurve *wc, mp_int *monty_x)
+{
+    MontyContext *monty = wc->mc;
+
+    mp_int *sq = monty_mul(monty, monty_x, monty_x);
+    mp_int *sq_plus_a = monty_add(monty, sq, wc->a);
+    mp_free(sq);
+
+    mp_int *cubic = monty_mul(monty, sq_plus_a, monty_x);
+    mp_free(sq_plus_a);
+
+    mp_int *result = monty_add(monty, cubic, wc->b);
+    mp_free(cubic);
+
+    return result;
+}
+
+/*
+ * Decompress a point: given an affine x-coordinate and the wanted
+ * parity of y, reconstruct the full point. Requires the curve to have
+ * a modsqrt context. Returns NULL if x^3+ax+b has no square root.
+ */
+WeierstrassPoint *ecc_weierstrass_point_new_from_x(
+    WeierstrassCurve *wc, mp_int *xorig, unsigned desired_y_parity)
+{
+    assert(wc->sc);
+
+    /*
+     * Since the curve equation reads y^2 = x^3 + ax + b, y is simply
+     * a square root of the right-hand side evaluated at x.
+     */
+    unsigned sqrt_ok;
+    mp_int *mx = monty_import(wc->mc, xorig);
+    mp_int *y_squared = ecc_weierstrass_equation_rhs(wc, mx);
+    mp_int *my = monty_modsqrt(wc->sc, y_squared, &sqrt_ok);
+    mp_free(y_squared);
+
+    if (!sqrt_ok) {
+        /* The right-hand side turned out to be a non-residue mod p.
+         * There is no need to stay time-constant on this path: the
+         * protocol will diverge anyway once we report the bogus
+         * value to whoever supplied it. */
+        mp_free(mx);
+        mp_free(my);
+        return NULL;
+    }
+
+    /*
+     * Of the two roots y and p-y, keep the one whose lowest positive
+     * residue mod p has the requested parity.
+     */
+    mp_int *scratch = monty_export(wc->mc, my);
+    unsigned negate = (mp_get_bit(scratch, 0) ^ desired_y_parity) & 1;
+    mp_sub_into(scratch, wc->p, my);
+    mp_select_into(my, my, scratch, negate);
+    mp_free(scratch);
+
+    return ecc_weierstrass_point_new_imported(wc, mx, my);
+}
+
+/*
+ * Coordinate-wise mp_select_into() between dest's current value and
+ * src's, with 'overwrite' as the selector (src is the selected-when-set
+ * operand). No branch is taken on 'overwrite' here.
+ */
+static void ecc_weierstrass_cond_overwrite(
+    WeierstrassPoint *dest, WeierstrassPoint *src, unsigned overwrite)
+{
+    mp_select_into(dest->Z, dest->Z, src->Z, overwrite);
+    mp_select_into(dest->Y, dest->Y, src->Y, overwrite);
+    mp_select_into(dest->X, dest->X, src->X, overwrite);
+}
+
+/*
+ * Apply mp_cond_swap() to each pair of corresponding coordinates of P
+ * and Q, with 'swap' as the control value.
+ */
+static void ecc_weierstrass_cond_swap(
+    WeierstrassPoint *P, WeierstrassPoint *Q, unsigned swap)
+{
+    mp_cond_swap(P->Z, Q->Z, swap);
+    mp_cond_swap(P->Y, Q->Y, swap);
+    mp_cond_swap(P->X, Q->X, swap);
+}
+
+/*
+ * Shared code between all three of the basic arithmetic functions:
+ * once we've determined the slope of the line that we're intersecting
+ * the curve with, this takes care of finding the coordinates of the
+ * third intersection point (given the two input x-coordinates and one
+ * of the y-coords) and negating it to generate the output.
+ *
+ * Inputs: Px, Qx and Py are the input coordinates expressed over the
+ * denominator common_Z, and lambda_n / lambda_d is the slope written
+ * as a fraction. None of the inputs is freed or modified here.
+ *
+ * Outputs: out->X, out->Y and out->Z are assigned the results of new
+ * monty_* computations (whatever pointers 'out' held before are
+ * overwritten, not freed):
+ *
+ *   X = lambda_n^2 - lambda_d^2 (Px + Qx)
+ *   Y = lambda_n (lambda_d^2 Px - X) - lambda_d^3 Py
+ *   Z = common_Z lambda_d
+ */
+static inline void ecc_weierstrass_epilogue(
+    mp_int *Px, mp_int *Qx, mp_int *Py, mp_int *common_Z,
+    mp_int *lambda_n, mp_int *lambda_d, WeierstrassPoint *out)
+{
+    WeierstrassCurve *wc = out->wc;
+
+    /* Powers of the numerator and denominator of the slope lambda */
+    mp_int *lambda_n2 = monty_mul(wc->mc, lambda_n, lambda_n);
+    mp_int *lambda_d2 = monty_mul(wc->mc, lambda_d, lambda_d);
+    mp_int *lambda_d3 = monty_mul(wc->mc, lambda_d, lambda_d2);
+
+    /* Make the output x-coordinate */
+    mp_int *xsum = monty_add(wc->mc, Px, Qx);
+    mp_int *lambda_d2_xsum = monty_mul(wc->mc, lambda_d2, xsum);
+    out->X = monty_sub(wc->mc, lambda_n2, lambda_d2_xsum);
+
+    /* Make the output y-coordinate (this reads the out->X that was
+     * just computed above, so the order matters) */
+    mp_int *lambda_d2_Px = monty_mul(wc->mc, lambda_d2, Px);
+    mp_int *xdiff = monty_sub(wc->mc, lambda_d2_Px, out->X);
+    mp_int *lambda_n_xdiff = monty_mul(wc->mc, lambda_n, xdiff);
+    mp_int *lambda_d3_Py = monty_mul(wc->mc, lambda_d3, Py);
+    out->Y = monty_sub(wc->mc, lambda_n_xdiff, lambda_d3_Py);
+
+    /* Make the output z-coordinate */
+    out->Z = monty_mul(wc->mc, common_Z, lambda_d);
+
+    mp_free(lambda_n2);
+    mp_free(lambda_d2);
+    mp_free(lambda_d3);
+    mp_free(xsum);
+    mp_free(xdiff);
+    mp_free(lambda_d2_xsum);
+    mp_free(lambda_n_xdiff);
+    mp_free(lambda_d2_Px);
+    mp_free(lambda_d3_Py);
+}
+
+/*
+ * Shared code between add and add_general: put the two input points
+ * over a common denominator, and determine the slope lambda of the
+ * line through both of them. If the points have the same
+ * x-coordinate, then the slope will be returned with a zero
+ * denominator.
+ *
+ * On return:
+ *   *Px = P->X Q->Z^2,  *Py = P->Y Q->Z^3,  *Qx = Q->X P->Z^2
+ *   *denom = P->Z Q->Z
+ *   *lambda_n = Q->Y P->Z^3 - *Py
+ *   *lambda_d = *Qx - *Px
+ * All six outputs are results of new monty_* computations; both
+ * callers in this file mp_free() each of them when done.
+ */
+static inline void ecc_weierstrass_add_prologue(
+    WeierstrassPoint *P, WeierstrassPoint *Q,
+    mp_int **Px, mp_int **Py, mp_int **Qx, mp_int **denom,
+    mp_int **lambda_n, mp_int **lambda_d)
+{
+    WeierstrassCurve *wc = P->wc;
+
+    /* Powers of the points' denominators */
+    mp_int *Pz2 = monty_mul(wc->mc, P->Z, P->Z);
+    mp_int *Pz3 = monty_mul(wc->mc, Pz2, P->Z);
+    mp_int *Qz2 = monty_mul(wc->mc, Q->Z, Q->Z);
+    mp_int *Qz3 = monty_mul(wc->mc, Qz2, Q->Z);
+
+    /* Points' x,y coordinates scaled by the other one's denominator
+     * (raised to the appropriate power) */
+    *Px = monty_mul(wc->mc, P->X, Qz2);
+    *Py = monty_mul(wc->mc, P->Y, Qz3);
+    *Qx = monty_mul(wc->mc, Q->X, Pz2);
+    mp_int *Qy = monty_mul(wc->mc, Q->Y, Pz3);
+
+    /* Common denominator */
+    *denom = monty_mul(wc->mc, P->Z, Q->Z);
+
+    /* Slope of the line through the two points, if P != Q */
+    *lambda_n = monty_sub(wc->mc, Qy, *Py);
+    *lambda_d = monty_sub(wc->mc, *Qx, *Px);
+
+    mp_free(Pz2);
+    mp_free(Pz3);
+    mp_free(Qz2);
+    mp_free(Qz3);
+    mp_free(Qy);
+}
+
+WeierstrassPoint *ecc_weierstrass_add(WeierstrassPoint *P, WeierstrassPoint *Q)
+{
+    WeierstrassCurve *wc = P->wc;
+    assert(Q->wc == wc);
+
+    WeierstrassPoint *S = ecc_weierstrass_point_new_empty(wc);
+
+    mp_int *Px, *Py, *Qx, *denom, *lambda_n, *lambda_d;
+    ecc_weierstrass_add_prologue(
+        P, Q, &Px, &Py, &Qx, &denom, &lambda_n, &lambda_d);
+
+    /* Never expect to have received two mutually inverse inputs, or
+     * two identical ones (which would make this a doubling). In other
+     * words, the two input x-coordinates (after putting over a common
+     * denominator) should never have been equal.
+     *
+     * NOTE(review): the assertion below tests lambda_n, which the
+     * prologue computes as the difference of the scaled
+     * y-coordinates; the difference of the scaled x-coordinates is
+     * lambda_d. Confirm which of the two was intended before relying
+     * on this check. */
+    assert(!mp_eq_integer(lambda_n, 0));
+
+    /* Now go to the common epilogue code, which fills in S->X, S->Y
+     * and S->Z. The special cases handled by
+     * ecc_weierstrass_add_general are not handled here. */
+    ecc_weierstrass_epilogue(Px, Qx, Py, denom, lambda_n, lambda_d, S);
+
+    mp_free(Px);
+    mp_free(Py);
+    mp_free(Qx);
+    mp_free(denom);
+    mp_free(lambda_n);
+    mp_free(lambda_d);
+
+    return S;
+}
+
+/*
+ * Code to determine the slope of the line you need to intersect with
+ * the curve in the case where you're adding a point to itself. In
+ * this situation you can't just say "the line through both input
+ * points" because that's under-determined; instead, you have to take
+ * the _tangent_ to the curve at the given point, by differentiating
+ * the curve equation y^2=x^3+ax+b to get 2y dy/dx = 3x^2+a.
+ *
+ * The values actually computed from P's Jacobian coordinates are
+ *   *lambda_n = 3 X^2 + a Z^4
+ *   *lambda_d = 2 Y
+ * Both outputs are results of new monty_add calls; the callers in
+ * this file mp_free() them.
+ */
+static inline void ecc_weierstrass_tangent_slope(
+    WeierstrassPoint *P, mp_int **lambda_n, mp_int **lambda_d)
+{
+    WeierstrassCurve *wc = P->wc;
+
+    mp_int *X2 = monty_mul(wc->mc, P->X, P->X);
+    mp_int *twoX2 = monty_add(wc->mc, X2, X2);
+    mp_int *threeX2 = monty_add(wc->mc, twoX2, X2);
+    mp_int *Z2 = monty_mul(wc->mc, P->Z, P->Z);
+    mp_int *Z4 = monty_mul(wc->mc, Z2, Z2);
+    mp_int *aZ4 = monty_mul(wc->mc, wc->a, Z4);
+
+    *lambda_n = monty_add(wc->mc, threeX2, aZ4);
+    *lambda_d = monty_add(wc->mc, P->Y, P->Y);
+
+    mp_free(X2);
+    mp_free(twoX2);
+    mp_free(threeX2);
+    mp_free(Z2);
+    mp_free(Z4);
+    mp_free(aZ4);
+}
+
+/*
+ * Return a new point equal to P + P. The tangent slope at P is fed to
+ * the shared epilogue, with P supplying both input x-coordinates and
+ * P->Z acting as the common denominator.
+ */
+WeierstrassPoint *ecc_weierstrass_double(WeierstrassPoint *P)
+{
+    mp_int *slope_num, *slope_den;
+    ecc_weierstrass_tangent_slope(P, &slope_num, &slope_den);
+
+    WeierstrassPoint *result = ecc_weierstrass_point_new_empty(P->wc);
+    ecc_weierstrass_epilogue(
+        P->X, P->X, P->Y, P->Z, slope_num, slope_den, result);
+
+    mp_free(slope_den);
+    mp_free(slope_num);
+
+    return result;
+}
+
+/*
+ * Coordinate-wise mp_select_into(): dest receives P's coordinates or
+ * Q's, with choose_Q as the selector passed to each call.
+ */
+static inline void ecc_weierstrass_select_into(
+    WeierstrassPoint *dest, WeierstrassPoint *P, WeierstrassPoint *Q,
+    unsigned choose_Q)
+{
+    mp_select_into(dest->Z, P->Z, Q->Z, choose_Q);
+    mp_select_into(dest->Y, P->Y, Q->Y, choose_Q);
+    mp_select_into(dest->X, P->X, Q->X, choose_Q);
+}
+
+WeierstrassPoint *ecc_weierstrass_add_general(
+    WeierstrassPoint *P, WeierstrassPoint *Q)
+{
+    WeierstrassCurve *wc = P->wc;
+    assert(Q->wc == wc);
+
+    WeierstrassPoint *S = ecc_weierstrass_point_new_empty(wc);
+
+    /* Parameters for the epilogue, and slope of the line if P != Q.
+     *
+     * Unlike ecc_weierstrass_add, this function also copes with
+     * P == Q, P == -Q and either input being the identity. It does so
+     * by computing every candidate value unconditionally and then
+     * choosing between them with mp_select_into / mp_cond_clear; the
+     * function body contains no 'if' on the input values. */
+    mp_int *Px, *Py, *Qx, *denom, *lambda_n, *lambda_d;
+    ecc_weierstrass_add_prologue(
+        P, Q, &Px, &Py, &Qx, &denom, &lambda_n, &lambda_d);
+
+    /* Slope if P == Q */
+    mp_int *lambda_n_tangent, *lambda_d_tangent;
+    ecc_weierstrass_tangent_slope(P, &lambda_n_tangent, &lambda_d_tangent);
+
+    /* Select between those slopes depending on whether P == Q, which
+     * is detected as both the scaled x difference (lambda_d) and the
+     * scaled y difference (lambda_n) being zero */
+    unsigned same_x_coord = mp_eq_integer(lambda_d, 0);
+    unsigned same_y_coord = mp_eq_integer(lambda_n, 0);
+    unsigned equality = same_x_coord & same_y_coord;
+    mp_select_into(lambda_n, lambda_n, lambda_n_tangent, equality);
+    mp_select_into(lambda_d, lambda_d, lambda_d_tangent, equality);
+
+    /* Now go to the common code between addition and doubling */
+    ecc_weierstrass_epilogue(Px, Qx, Py, denom, lambda_n, lambda_d, S);
+
+    /* Check for the input identity cases, and overwrite the output if
+     * necessary: if P->Z is zero the result becomes Q, and if Q->Z is
+     * zero the result becomes P. */
+    ecc_weierstrass_select_into(S, S, Q, mp_eq_integer(P->Z, 0));
+    ecc_weierstrass_select_into(S, S, P, mp_eq_integer(Q->Z, 0));
+
+    /*
+     * In the case where P == -Q and so the output is the identity,
+     * we'll have calculated lambda_d = 0 and so the output will have
+     * z==0 already. Detect that and use it to normalise the other two
+     * coordinates to zero.
+     */
+    unsigned output_id = mp_eq_integer(S->Z, 0);
+    mp_cond_clear(S->X, output_id);
+    mp_cond_clear(S->Y, output_id);
+
+    mp_free(Px);
+    mp_free(Py);
+    mp_free(Qx);
+    mp_free(denom);
+    mp_free(lambda_n);
+    mp_free(lambda_d);
+    mp_free(lambda_n_tangent);
+    mp_free(lambda_d_tangent);
+
+    return S;
+}
+
+WeierstrassPoint *ecc_weierstrass_multiply(WeierstrassPoint *B, mp_int *n)
+{
+    WeierstrassPoint *two_B = ecc_weierstrass_double(B);
+    WeierstrassPoint *k_B = ecc_weierstrass_point_copy(B);
+    WeierstrassPoint *kplus1_B = ecc_weierstrass_point_copy(two_B);
+
+    /*
+     * This multiply routine more or less follows the shape of the
+     * 'Montgomery ladder' technique that you have to use under the
+     * extra constraint on addition in Montgomery curves, because it
+     * was fresh in my mind and easier to just do it the same way. See
+     * the comment in ecc_montgomery_multiply.
+     *
+     * k_B and kplus1_B start out as B and 2B. Each iteration looks at
+     * one bit of n, from the top downwards, and replaces the pair
+     * with (2k)B,(2k+1)B when the bit is 0 or (2k+1)B,(2k+2)B when it
+     * is 1. While not_started_yet is still 1 at the end of an
+     * iteration (i.e. up to and including the iteration that sees the
+     * first set bit), the pair is reset to B and 2B.
+     *
+     * The result is a new point owned by the caller; B and n are only
+     * read here.
+     */
+
+    unsigned not_started_yet = 1;
+    for (size_t bitindex = mp_max_bits(n); bitindex-- > 0 ;) {
+        unsigned nbit = mp_get_bit(n, bitindex);
+
+        WeierstrassPoint *sum = ecc_weierstrass_add(k_B, kplus1_B);
+        ecc_weierstrass_cond_swap(k_B, kplus1_B, nbit);
+        WeierstrassPoint *other = ecc_weierstrass_double(k_B);
+        ecc_weierstrass_point_free(k_B);
+        ecc_weierstrass_point_free(kplus1_B);
+        k_B = other;
+        kplus1_B = sum;
+        ecc_weierstrass_cond_swap(k_B, kplus1_B, nbit);
+
+        ecc_weierstrass_cond_overwrite(k_B, B, not_started_yet);
+        ecc_weierstrass_cond_overwrite(kplus1_B, two_B, not_started_yet);
+        not_started_yet &= ~nbit;
+    }
+
+    ecc_weierstrass_point_free(two_B);
+    ecc_weierstrass_point_free(kplus1_B);
+    return k_B;
+}
+
+/*
+ * Report whether wp is the point at infinity, judged solely by
+ * whether its Z coordinate equals zero.
+ */
+unsigned ecc_weierstrass_is_identity(WeierstrassPoint *wp)
+{
+    unsigned z_is_zero = mp_eq_integer(wp->Z, 0);
+    return z_is_zero;
+}
+
+/*
+ * Rescale a point's Jacobian coordinates in place so that Z becomes 1
+ * (stored as the Montgomery identity). The point being represented
+ * stays the same, but afterwards X and Y directly hold the affine x
+ * and y.
+ */
+static void ecc_weierstrass_normalise(WeierstrassPoint *wp)
+{
+    MontyContext *monty = wp->wc->mc;
+
+    mp_int *z_inv = monty_invert(monty, wp->Z);
+    mp_int *z_inv_sq = monty_mul(monty, z_inv, z_inv);
+    mp_int *z_inv_cu = monty_mul(monty, z_inv_sq, z_inv);
+
+    /* x = X / Z^2 and y = Y / Z^3 */
+    monty_mul_into(monty, wp->X, wp->X, z_inv_sq);
+    monty_mul_into(monty, wp->Y, wp->Y, z_inv_cu);
+    mp_copy_into(wp->Z, monty_identity(monty));
+
+    mp_free(z_inv_cu);
+    mp_free(z_inv_sq);
+    mp_free(z_inv);
+}
+
+/*
+ * Extract the affine coordinates of wp. The point is normalised in
+ * place first (so wp itself is modified). Each of x and y may be NULL
+ * if that coordinate is not wanted; otherwise it receives a value
+ * exported from Montgomery form.
+ */
+void ecc_weierstrass_get_affine(
+    WeierstrassPoint *wp, mp_int **x, mp_int **y)
+{
+    MontyContext *monty = wp->wc->mc;
+
+    ecc_weierstrass_normalise(wp);
+
+    if (x) {
+        *x = monty_export(monty, wp->X);
+    }
+    if (y) {
+        *y = monty_export(monty, wp->Y);
+    }
+}
+
+/*
+ * Check the curve equation for P by comparing Y^2 against the
+ * right-hand side evaluated at X.
+ * NOTE(review): Z is not consulted, so this presumably expects a
+ * point whose Z is 1 (e.g. one freshly built from affine
+ * coordinates) - confirm against callers.
+ */
+unsigned ecc_weierstrass_point_valid(WeierstrassPoint *P)
+{
+    WeierstrassCurve *curve = P->wc;
+
+    mp_int *expected = ecc_weierstrass_equation_rhs(curve, P->X);
+    mp_int *y_squared = monty_mul(curve->mc, P->Y, P->Y);
+    unsigned on_curve = mp_cmp_eq(y_squared, expected);
+
+    mp_free(y_squared);
+    mp_free(expected);
+
+    return on_curve;
+}
+
+/* ----------------------------------------------------------------------
+ * Montgomery curves.
+ */
+
+struct MontgomeryPoint {
+    /* Projective XZ coordinates: the affine x-coordinate of the
+     * point is the quotient X/Z. */
+    mp_int *X;
+    mp_int *Z;
+
+    /* The curve that this point was created on. */
+    MontgomeryCurve *mc;
+};
+
+struct MontgomeryCurve {
+    /* The prime p defining the finite field the curve lives in. */
+    mp_int *p;
+
+    /* Montgomery context used for all arithmetic mod p. */
+    MontyContext *mc;
+
+    /* Curve coefficients a and b, both stored in
+     * Montgomery-multiplication transformed form. */
+    mp_int *a;
+    mp_int *b;
+
+    /* The precomputed constant (a+2)/4, likewise in
+     * Montgomery-multiplication form. */
+    mp_int *aplus2over4;
+};
+
+/*
+ * Build a Montgomery curve object from the field modulus p and the
+ * coefficients a and b. p is copied, a and b are imported into
+ * Montgomery form, and the constant (a+2)/4 mod p is precomputed.
+ */
+MontgomeryCurve *ecc_montgomery_curve(
+    mp_int *p, mp_int *a, mp_int *b)
+{
+    MontgomeryCurve *curve = snew(MontgomeryCurve);
+    curve->p = mp_copy(p);
+    curve->mc = monty_new(p);
+    curve->a = monty_import(curve->mc, a);
+    curve->b = monty_import(curve->mc, b);
+
+    /* Work out (a+2)/4 with ordinary modular arithmetic first, and
+     * only then import the result into Montgomery form. */
+    mp_int *four = mp_from_integer(4);
+    mp_int *quarter = mp_invert(four, curve->p);
+    mp_int *a_plus_2 = mp_copy(a);
+    mp_add_integer_into(a_plus_2, a_plus_2, 2);
+    mp_int *ratio = mp_modmul(a_plus_2, quarter, curve->p);
+    curve->aplus2over4 = monty_import(curve->mc, ratio);
+
+    mp_free(ratio);
+    mp_free(a_plus_2);
+    mp_free(quarter);
+    mp_free(four);
+
+    return curve;
+}
+
+/*
+ * Release a curve made by ecc_montgomery_curve(), together with the
+ * values and Montgomery context it owns.
+ */
+void ecc_montgomery_curve_free(MontgomeryCurve *mc)
+{
+    monty_free(mc->mc);
+    mp_free(mc->aplus2over4);
+    mp_free(mc->b);
+    mp_free(mc->a);
+    mp_free(mc->p);
+    sfree(mc);
+}
+
+/*
+ * Allocate a point attached to mc with both coordinate pointers set
+ * to NULL; the caller is expected to fill them in.
+ */
+static MontgomeryPoint *ecc_montgomery_point_new_empty(MontgomeryCurve *mc)
+{
+    MontgomeryPoint *pt = snew(MontgomeryPoint);
+
+    pt->X = NULL;
+    pt->Z = NULL;
+    pt->mc = mc;
+
+    return pt;
+}
+
+/*
+ * Make a point from an ordinary (non-Montgomery-form) affine
+ * x-coordinate: X is the imported x and Z is a copy of the Montgomery
+ * identity.
+ */
+MontgomeryPoint *ecc_montgomery_point_new(MontgomeryCurve *mc, mp_int *x)
+{
+    MontgomeryPoint *pt = ecc_montgomery_point_new_empty(mc);
+
+    pt->Z = mp_copy(monty_identity(mc->mc));
+    pt->X = monty_import(mc->mc, x);
+
+    return pt;
+}
+
+/*
+ * Duplicate a point: the result is on the same curve as orig and has
+ * its own copies of X and Z.
+ */
+MontgomeryPoint *ecc_montgomery_point_copy(MontgomeryPoint *orig)
+{
+    MontgomeryPoint *dup = ecc_montgomery_point_new_empty(orig->mc);
+
+    dup->Z = mp_copy(orig->Z);
+    dup->X = mp_copy(orig->X);
+
+    return dup;
+}
+
+/*
+ * Free a point and both of its coordinates. The struct itself is
+ * wiped with smemclr() before being released.
+ */
+void ecc_montgomery_point_free(MontgomeryPoint *mp)
+{
+    mp_free(mp->Z);
+    mp_free(mp->X);
+
+    smemclr(mp, sizeof(*mp));
+    sfree(mp);
+}
+
+/*
+ * Coordinate-wise mp_select_into() between dest's current value and
+ * src's, with 'overwrite' as the selector (src is the selected-when-set
+ * operand).
+ */
+static void ecc_montgomery_cond_overwrite(
+    MontgomeryPoint *dest, MontgomeryPoint *src, unsigned overwrite)
+{
+    mp_select_into(dest->Z, dest->Z, src->Z, overwrite);
+    mp_select_into(dest->X, dest->X, src->X, overwrite);
+}
+
+/*
+ * Apply mp_cond_swap() to each pair of corresponding coordinates of P
+ * and Q, with 'swap' as the control value.
+ */
+static void ecc_montgomery_cond_swap(
+    MontgomeryPoint *P, MontgomeryPoint *Q, unsigned swap)
+{
+    mp_cond_swap(P->Z, Q->Z, swap);
+    mp_cond_swap(P->X, Q->X, swap);
+}
+
+MontgomeryPoint *ecc_montgomery_diff_add(
+    MontgomeryPoint *P, MontgomeryPoint *Q, MontgomeryPoint *PminusQ)
+{
+    MontgomeryCurve *mc = P->mc;
+    assert(Q->mc == mc);
+    assert(PminusQ->mc == mc);
+
+    /*
+     * Differential addition is achieved using the following formula
+     * that relates the affine x-coordinates of P, Q, P+Q and P-Q:
+     *
+     * x(P+Q) x(P-Q) (x(Q)-x(P))^2 = (x(P)x(Q) - 1)^2
+     *
+     * As with the Weierstrass coordinates, the code below transforms
+     * that affine relation into a projective one to avoid having to
+     * do a division during the main arithmetic.
+     *
+     * Writing each point as X/Z, the intermediate values are
+     *
+     *   Xpre = (Px-Pz)(Qx+Qz) + (Px+Pz)(Qx-Qz) = 2 (Px Qx - Pz Qz)
+     *   Zpre = (Px-Pz)(Qx+Qz) - (Px+Pz)(Qx-Qz) = 2 (Px Qz - Pz Qx)
+     *
+     * and the returned point has X = Xpre^2 * PminusQ->Z and
+     * Z = Zpre^2 * PminusQ->X. The result is a new point; none of the
+     * three inputs is modified.
+     */
+
+    MontgomeryPoint *S = ecc_montgomery_point_new_empty(mc);
+
+    mp_int *Px_m_Pz = monty_sub(mc->mc, P->X, P->Z);
+    mp_int *Px_p_Pz = monty_add(mc->mc, P->X, P->Z);
+    mp_int *Qx_m_Qz = monty_sub(mc->mc, Q->X, Q->Z);
+    mp_int *Qx_p_Qz = monty_add(mc->mc, Q->X, Q->Z);
+    mp_int *PmQp = monty_mul(mc->mc, Px_m_Pz, Qx_p_Qz);
+    mp_int *PpQm = monty_mul(mc->mc, Px_p_Pz, Qx_m_Qz);
+    mp_int *Xpre = monty_add(mc->mc, PmQp, PpQm);
+    mp_int *Zpre = monty_sub(mc->mc, PmQp, PpQm);
+    mp_int *Xpre2 = monty_mul(mc->mc, Xpre, Xpre);
+    mp_int *Zpre2 = monty_mul(mc->mc, Zpre, Zpre);
+    S->X = monty_mul(mc->mc, Xpre2, PminusQ->Z);
+    S->Z = monty_mul(mc->mc, Zpre2, PminusQ->X);
+
+    mp_free(Px_m_Pz);
+    mp_free(Px_p_Pz);
+    mp_free(Qx_m_Qz);
+    mp_free(Qx_p_Qz);
+    mp_free(PmQp);
+    mp_free(PpQm);
+    mp_free(Xpre);
+    mp_free(Zpre);
+    mp_free(Xpre2);
+    mp_free(Zpre2);
+
+    return S;
+}
+
+MontgomeryPoint *ecc_montgomery_double(MontgomeryPoint *P)
+{
+    MontgomeryCurve *mc = P->mc;
+    MontgomeryPoint *D = ecc_montgomery_point_new_empty(mc);
+
+    /*
+     * To double a point in affine coordinates, in principle you can
+     * use the same technique as for Weierstrass: differentiate the
+     * curve equation to get the tangent line at the input point, use
+     * that to get an expression for y which you substitute back into
+     * the curve equation, and subtract the known two roots (in this
+     * case both the same) from the x^2 coefficient of the resulting
+     * cubic.
+     *
+     * In this case, we don't have an input y-coordinate, so you have
+     * to do a bit of extra transformation to find a formula that can
+     * work without it. The tangent formula is (3x^2 + 2ax + 1)/(2y),
+     * and when that appears in the final formula it will be squared -
+     * so we can substitute the y^2 in the denominator for the RHS of
+     * the curve equation. Put together, that gives
+     *
+     *   x_out = (x+1)^2 (x-1)^2 / 4(x^3+ax^2+x)
+     *
+     * and, as usual, the code below transforms that into projective
+     * form to avoid the division.
+     *
+     * Concretely, with x = X/Z, the values computed are
+     *
+     *   X_out = (X-Z)^2 (X+Z)^2
+     *   Z_out = 4XZ ((X-Z)^2 + 4XZ (a+2)/4)
+     *
+     * where (a+2)/4 is the curve's stored aplus2over4 constant. The
+     * second factor of Z_out expands to X^2 + aXZ + Z^2.
+     */
+
+    mp_int *Px_m_Pz = monty_sub(mc->mc, P->X, P->Z);
+    mp_int *Px_p_Pz = monty_add(mc->mc, P->X, P->Z);
+    mp_int *Px_m_Pz_2 = monty_mul(mc->mc, Px_m_Pz, Px_m_Pz);
+    mp_int *Px_p_Pz_2 = monty_mul(mc->mc, Px_p_Pz, Px_p_Pz);
+    D->X = monty_mul(mc->mc, Px_m_Pz_2, Px_p_Pz_2);
+    mp_int *XZ = monty_mul(mc->mc, P->X, P->Z);
+    mp_int *twoXZ = monty_add(mc->mc, XZ, XZ);
+    mp_int *fourXZ = monty_add(mc->mc, twoXZ, twoXZ);
+    mp_int *fourXZ_scaled = monty_mul(mc->mc, fourXZ, mc->aplus2over4);
+    mp_int *Zpre = monty_add(mc->mc, Px_m_Pz_2, fourXZ_scaled);
+    D->Z = monty_mul(mc->mc, fourXZ, Zpre);
+
+    mp_free(Px_m_Pz);
+    mp_free(Px_p_Pz);
+    mp_free(Px_m_Pz_2);
+    mp_free(Px_p_Pz_2);
+    mp_free(XZ);
+    mp_free(twoXZ);
+    mp_free(fourXZ);
+    mp_free(fourXZ_scaled);
+    mp_free(Zpre);
+
+    return D;
+}
+
+/*
+ * Rescale a point's XZ coordinates in place so that Z becomes 1
+ * (stored as the Montgomery identity) and X holds the affine x.
+ */
+static void ecc_montgomery_normalise(MontgomeryPoint *mp)
+{
+    MontyContext *monty = mp->mc->mc;
+
+    mp_int *z_inv = monty_invert(monty, mp->Z);
+    monty_mul_into(monty, mp->X, mp->X, z_inv);
+    mp_copy_into(mp->Z, monty_identity(monty));
+
+    mp_free(z_inv);
+}
+
+MontgomeryPoint *ecc_montgomery_multiply(MontgomeryPoint *B, mp_int *n)
+{
+    /*
+     * 'Montgomery ladder' technique, to compute an arbitrary integer
+     * multiple of B under the constraint that you can only add two
+     * unequal points if you also know their difference.
+     *
+     * The setup is that you maintain two curve points one of which is
+     * always the other one plus B. Call them kB and (k+1)B, where k
+     * is some integer that evolves as we go along. We begin by
+     * doubling the input B, to initialise those points to B and 2B,
+     * so that k=1.
+     *
+     * At each stage, we add kB and (k+1)B together - which we can do
+     * under the differential-addition constraint because we know
+     * their difference is always just B - to give us (2k+1)B. Then we
+     * double one of kB or (k+1)B, and depending on which one we
+     * choose, we end up with (2k)B or (2k+2)B. Either way, that
+     * differs by B from the other value we've just computed. So in
+     * each iteration, we do one diff-add and one doubling, plus a
+     * couple of conditional swaps to choose which value we double and
+     * which way round we put the output points, and the effect is to
+     * replace k with either 2k or 2k+1, which we choose based on the
+     * appropriate bit of the desired exponent.
+     *
+     * This routine doesn't assume we know the exact location of the
+     * topmost set bit of the exponent. So to maintain constant time
+     * it does an iteration for every _potential_ bit, starting from
+     * the top downwards; after each iteration in which we haven't
+     * seen a set exponent bit yet, we just overwrite the two points
+     * with B and 2B again. (The overwrite also happens at the end of
+     * the iteration that sees the first set bit, because
+     * not_started_yet is only cleared after the overwrite.)
+     *
+     * The return value is a new point owned by the caller.
+     */
+
+    MontgomeryPoint *two_B = ecc_montgomery_double(B);
+    MontgomeryPoint *k_B = ecc_montgomery_point_copy(B);
+    MontgomeryPoint *kplus1_B = ecc_montgomery_point_copy(two_B);
+
+    unsigned not_started_yet = 1;
+    for (size_t bitindex = mp_max_bits(n); bitindex-- > 0 ;) {
+        unsigned nbit = mp_get_bit(n, bitindex);
+
+        MontgomeryPoint *sum = ecc_montgomery_diff_add(k_B, kplus1_B, B);
+        ecc_montgomery_cond_swap(k_B, kplus1_B, nbit);
+        MontgomeryPoint *other = ecc_montgomery_double(k_B);
+        ecc_montgomery_point_free(k_B);
+        ecc_montgomery_point_free(kplus1_B);
+        k_B = other;
+        kplus1_B = sum;
+        ecc_montgomery_cond_swap(k_B, kplus1_B, nbit);
+
+        ecc_montgomery_cond_overwrite(k_B, B, not_started_yet);
+        ecc_montgomery_cond_overwrite(kplus1_B, two_B, not_started_yet);
+        not_started_yet &= ~nbit;
+    }
+
+    ecc_montgomery_point_free(two_B);
+    ecc_montgomery_point_free(kplus1_B);
+    return k_B;
+}
+
+/*
+ * Extract the affine x-coordinate of mp. The point is normalised in
+ * place first (so mp itself is modified). x may be NULL, in which
+ * case nothing is returned.
+ */
+void ecc_montgomery_get_affine(MontgomeryPoint *mp, mp_int **x)
+{
+    MontyContext *monty = mp->mc->mc;
+
+    ecc_montgomery_normalise(mp);
+
+    if (x) {
+        *x = monty_export(monty, mp->X);
+    }
+}
+
+/* ----------------------------------------------------------------------
+ * Twisted Edwards curves.
+ */
+
+struct EdwardsPoint {
+    /*
+     * An Edwards curve point is held in 'extended coordinates'.
+     * Unfortunately several different coordinate systems go by that
+     * name. In this one, X, Y and Z are plain projective coordinates
+     * (meaning x=X/Z and y=Y/Z), and alongside them we keep one
+     * additional value T = xyZ = XY/Z.
+     */
+    mp_int *X;
+    mp_int *Y;
+    mp_int *Z;
+    mp_int *T;
+
+    /* The curve that this point was created on. */
+    EdwardsCurve *ec;
+};
+
+struct EdwardsCurve {
+    /* The prime p defining the finite field the curve lives in. */
+    mp_int *p;
+
+    /* Montgomery context used for all arithmetic mod p. */
+    MontyContext *mc;
+
+    /* Context for modular square roots, used when decompressing
+     * points. */
+    ModsqrtContext *sc;
+
+    /* Curve coefficients d and a, both stored in
+     * Montgomery-multiplication transformed form. */
+    mp_int *d;
+    mp_int *a;
+};
+
+/*
+ * Build an Edwards curve object from the field modulus p and the
+ * coefficients d and a. p is copied; d and a are imported into
+ * Montgomery form. nonsquare_mod_p may be NULL, in which case no
+ * modsqrt context is created and ec->sc is left NULL.
+ */
+EdwardsCurve *ecc_edwards_curve(mp_int *p, mp_int *d, mp_int *a,
+                                mp_int *nonsquare_mod_p)
+{
+    EdwardsCurve *curve = snew(EdwardsCurve);
+
+    curve->p = mp_copy(p);
+    curve->mc = monty_new(p);
+    curve->d = monty_import(curve->mc, d);
+    curve->a = monty_import(curve->mc, a);
+    curve->sc = nonsquare_mod_p ? modsqrt_new(p, nonsquare_mod_p) : NULL;
+
+    return curve;
+}
+
+/*
+ * Release a curve made by ecc_edwards_curve(), together with
+ * everything it owns. The modsqrt context is only freed if present.
+ */
+void ecc_edwards_curve_free(EdwardsCurve *ec)
+{
+    if (ec->sc) {
+        modsqrt_free(ec->sc);
+    }
+    monty_free(ec->mc);
+    mp_free(ec->a);
+    mp_free(ec->d);
+    mp_free(ec->p);
+    sfree(ec);
+}
+
+/*
+ * Allocate a point attached to ec with all four coordinate pointers
+ * set to NULL; the caller is expected to fill them in.
+ */
+static EdwardsPoint *ecc_edwards_point_new_empty(EdwardsCurve *ec)
+{
+    EdwardsPoint *pt = snew(EdwardsPoint);
+
+    pt->X = NULL;
+    pt->Y = NULL;
+    pt->Z = NULL;
+    pt->T = NULL;
+    pt->ec = ec;
+
+    return pt;
+}
+
+/*
+ * Make a point from affine coordinates that are already in Montgomery
+ * form. The two mp_ints are stored directly in the new point (not
+ * copied). Z is a copy of the Montgomery identity, and T is the
+ * product of X and Y.
+ */
+static EdwardsPoint *ecc_edwards_point_new_imported(
+    EdwardsCurve *ec, mp_int *monty_x, mp_int *monty_y)
+{
+    EdwardsPoint *pt = ecc_edwards_point_new_empty(ec);
+
+    pt->X = monty_x;
+    pt->Y = monty_y;
+    pt->Z = mp_copy(monty_identity(ec->mc));
+    pt->T = monty_mul(ec->mc, pt->X, pt->Y);
+
+    return pt;
+}
+
+/*
+ * Make a point from ordinary (non-Montgomery) affine coordinates x, y.
+ * Both are imported into Montgomery form; the inputs themselves are
+ * not stored in the point.
+ */
+EdwardsPoint *ecc_edwards_point_new(
+    EdwardsCurve *ec, mp_int *x, mp_int *y)
+{
+    mp_int *mx = monty_import(ec->mc, x);
+    mp_int *my = monty_import(ec->mc, y);
+    return ecc_edwards_point_new_imported(ec, mx, my);
+}
+
+/*
+ * Duplicate a point: the result is on the same curve as orig and has
+ * its own copies of X, Y, Z and T.
+ */
+EdwardsPoint *ecc_edwards_point_copy(EdwardsPoint *orig)
+{
+    EdwardsPoint *dup = ecc_edwards_point_new_empty(orig->ec);
+
+    dup->T = mp_copy(orig->T);
+    dup->Z = mp_copy(orig->Z);
+    dup->Y = mp_copy(orig->Y);
+    dup->X = mp_copy(orig->X);
+
+    return dup;
+}
+
+/*
+ * Free a point and its four coordinates. The struct itself is wiped
+ * with smemclr() before being released.
+ */
+void ecc_edwards_point_free(EdwardsPoint *ep)
+{
+    mp_free(ep->T);
+    mp_free(ep->Z);
+    mp_free(ep->Y);
+    mp_free(ep->X);
+
+    smemclr(ep, sizeof(*ep));
+    sfree(ep);
+}
+
+/*
+ * Decompress a point: given an affine y-coordinate and the wanted
+ * parity of x, reconstruct the full point. Requires the curve to have
+ * a modsqrt context. Returns NULL if the computed x^2 has no square
+ * root.
+ */
+EdwardsPoint *ecc_edwards_point_new_from_y(
+    EdwardsCurve *ec, mp_int *yorig, unsigned desired_x_parity)
+{
+    assert(ec->sc);
+
+    /*
+     * Starting from the curve equation ax^2 + y^2 = 1 + dx^2y^2 and
+     * collecting the x^2 terms gives x^2(dy^2-a) = y^2-1. Hence x is
+     * a square root of (y^2-1)/(dy^2-a).
+     */
+    unsigned sqrt_ok;
+
+    mp_int *my = monty_import(ec->mc, yorig);
+    mp_int *y_sq = monty_mul(ec->mc, my, my);
+    mp_int *d_y_sq = monty_mul(ec->mc, ec->d, y_sq);
+    mp_int *denom = monty_sub(ec->mc, d_y_sq, ec->a);
+    mp_int *numer = monty_sub(ec->mc, y_sq, monty_identity(ec->mc));
+    mp_int *denom_inv = monty_invert(ec->mc, denom);
+    mp_int *x_sq = monty_mul(ec->mc, numer, denom_inv);
+    mp_int *mx = monty_modsqrt(ec->sc, x_sq, &sqrt_ok);
+
+    mp_free(x_sq);
+    mp_free(denom_inv);
+    mp_free(numer);
+    mp_free(denom);
+    mp_free(d_y_sq);
+    mp_free(y_sq);
+
+    if (!sqrt_ok) {
+        /* The candidate x^2 turned out to be a non-residue mod p.
+         * There is no need to stay time-constant on this path: the
+         * protocol will diverge anyway once we report the bogus
+         * value to whoever supplied it. */
+        mp_free(mx);
+        mp_free(my);
+        return NULL;
+    }
+
+    /*
+     * Of the two roots x and p-x, keep the one whose lowest positive
+     * residue mod p has the requested parity.
+     */
+    mp_int *scratch = monty_export(ec->mc, mx);
+    unsigned negate = (mp_get_bit(scratch, 0) ^ desired_x_parity) & 1;
+    mp_sub_into(scratch, ec->p, mx);
+    mp_select_into(mx, mx, scratch, negate);
+    mp_free(scratch);
+
+    return ecc_edwards_point_new_imported(ec, mx, my);
+}
+
+/*
+ * Coordinate-wise mp_select_into() between dest's current value and
+ * src's, with 'overwrite' as the selector (src is the selected-when-set
+ * operand).
+ */
+static void ecc_edwards_cond_overwrite(
+    EdwardsPoint *dest, EdwardsPoint *src, unsigned overwrite)
+{
+    mp_select_into(dest->T, dest->T, src->T, overwrite);
+    mp_select_into(dest->Z, dest->Z, src->Z, overwrite);
+    mp_select_into(dest->Y, dest->Y, src->Y, overwrite);
+    mp_select_into(dest->X, dest->X, src->X, overwrite);
+}
+
+/*
+ * Apply mp_cond_swap() to each pair of corresponding coordinates of P
+ * and Q, with 'swap' as the control value.
+ */
+static void ecc_edwards_cond_swap(
+    EdwardsPoint *P, EdwardsPoint *Q, unsigned swap)
+{
+    mp_cond_swap(P->T, Q->T, swap);
+    mp_cond_swap(P->Z, Q->Z, swap);
+    mp_cond_swap(P->Y, Q->Y, swap);
+    mp_cond_swap(P->X, Q->X, swap);
+}
+
+/*
+ * Add two points on the same Edwards curve and return their sum as a
+ * newly created point. P and Q are only read, never modified, and
+ * they must belong to the same curve (checked by assertion). All the
+ * intermediate values allocated here are freed before returning.
+ */
+EdwardsPoint *ecc_edwards_add(EdwardsPoint *P, EdwardsPoint *Q)
+{
+    EdwardsCurve *ec = P->ec;
+    assert(Q->ec == ec);
+
+    EdwardsPoint *S = ecc_edwards_point_new_empty(ec);
+
+    /*
+     * The affine rule for Edwards addition of (x1,y1) and (x2,y2) is
+     *
+     *   x_out = (x1 y2 +   y1 x2) / (1 + d x1 x2 y1 y2)
+     *   y_out = (y1 y2 - a x1 x2) / (1 - d x1 x2 y1 y2)
+     *
+     * The formulae below are listed as 'add-2008-hwcd' in
+     * https://hyperelliptic.org/EFD/g1p/auto-twisted-extended.html
+     *
+     * and if you undo the careful optimisation to find out what
+     * they're actually computing, it comes out to
+     *
+     *   X_out = (X1 Y2 +   Y1 X2) (Z1 Z2 - d T1 T2)
+     *   Y_out = (Y1 Y2 - a X1 X2) (Z1 Z2 + d T1 T2)
+     *   Z_out = (Z1 Z2 - d T1 T2) (Z1 Z2 + d T1 T2)
+     *   T_out = (X1 Y2 +   Y1 X2) (Y1 Y2 - a X1 X2)
+     */
+    mp_int *PxQx = monty_mul(ec->mc, P->X, Q->X);
+    mp_int *PyQy = monty_mul(ec->mc, P->Y, Q->Y);
+    mp_int *PtQt = monty_mul(ec->mc, P->T, Q->T);
+    mp_int *PzQz = monty_mul(ec->mc, P->Z, Q->Z);
+    mp_int *Psum = monty_add(ec->mc, P->X, P->Y);
+    mp_int *Qsum = monty_add(ec->mc, Q->X, Q->Y);
+    mp_int *aPxQx = monty_mul(ec->mc, ec->a, PxQx);
+    mp_int *dPtQt = monty_mul(ec->mc, ec->d, PtQt);
+    mp_int *sumprod = monty_mul(ec->mc, Psum, Qsum);
+    mp_int *xx_p_yy = monty_add(ec->mc, PxQx, PyQy);
+    /* E = (Px+Py)(Qx+Qy) - PxQx - PyQy, i.e. the cross terms PxQy + PyQx */
+    mp_int *E = monty_sub(ec->mc, sumprod, xx_p_yy);
+    /* F = PzQz - d PtQt and G = PzQz + d PtQt: the two denominators */
+    mp_int *F = monty_sub(ec->mc, PzQz, dPtQt);
+    mp_int *G = monty_add(ec->mc, PzQz, dPtQt);
+    /* H = PyQy - a PxQx */
+    mp_int *H = monty_sub(ec->mc, PyQy, aPxQx);
+    /* Combine E,F,G,H pairwise to give the four output coordinates */
+    S->X = monty_mul(ec->mc, E, F);
+    S->Z = monty_mul(ec->mc, F, G);
+    S->Y = monty_mul(ec->mc, G, H);
+    S->T = monty_mul(ec->mc, H, E);
+
+    mp_free(PxQx);
+    mp_free(PyQy);
+    mp_free(PtQt);
+    mp_free(PzQz);
+    mp_free(Psum);
+    mp_free(Qsum);
+    mp_free(aPxQx);
+    mp_free(dPtQt);
+    mp_free(sumprod);
+    mp_free(xx_p_yy);
+    mp_free(E);
+    mp_free(F);
+    mp_free(G);
+    mp_free(H);
+
+    return S;
+}
+
+/*
+ * Rescale a point in place so that its Z coordinate becomes the
+ * value returned by monty_identity: X and Y are multiplied by the
+ * inverse of the old Z, and T is then recomputed as the product of
+ * the new X and Y. The point is modified, not copied.
+ */
+static void ecc_edwards_normalise(EdwardsPoint *ep)
+{
+    MontyContext *mc = ep->ec->mc;
+    mp_int *recip_z = monty_invert(mc, ep->Z);
+
+    /* Divide the numerator coordinates through by Z */
+    monty_mul_into(mc, ep->X, ep->X, recip_z);
+    monty_mul_into(mc, ep->Y, ep->Y, recip_z);
+
+    /* Z is now the identity, so T is simply X*Y */
+    mp_copy_into(ep->Z, monty_identity(mc));
+    monty_mul_into(mc, ep->T, ep->X, ep->Y);
+
+    mp_free(recip_z);
+}
+
+/*
+ * Compute the integer multiple n*B of an Edwards point and return it
+ * as a newly created point. B and n are only read.
+ *
+ * Every bit position of n's allocated size (mp_max_bits) is
+ * processed with the same sequence of operations; the bit value only
+ * feeds the conditional swap/overwrite helpers.
+ *
+ * NOTE(review): if n is zero (or has no words at all), k_B is still
+ * the copy of B when the loop ends, so the function returns B rather
+ * than the identity point - confirm that callers never pass n == 0.
+ */
+EdwardsPoint *ecc_edwards_multiply(EdwardsPoint *B, mp_int *n)
+{
+    EdwardsPoint *two_B = ecc_edwards_add(B, B);
+    EdwardsPoint *k_B = ecc_edwards_point_copy(B);
+    EdwardsPoint *kplus1_B = ecc_edwards_point_copy(two_B);
+
+    /*
+     * Another copy of the same exponentiation routine following the
+     * pattern of the Montgomery ladder, because it works as well as
+     * any other technique and this way I didn't have to debug two of
+     * them.
+     */
+
+    unsigned not_started_yet = 1;
+    for (size_t bitindex = mp_max_bits(n); bitindex-- > 0 ;) {
+        unsigned nbit = mp_get_bit(n, bitindex);
+
+        /* sum = kB + (k+1)B = (2k+1)B */
+        EdwardsPoint *sum = ecc_edwards_add(k_B, kplus1_B);
+        /* Arrange for k_B to hold whichever point must be doubled:
+         * (k+1)B if this bit is 1, kB if it is 0 */
+        ecc_edwards_cond_swap(k_B, kplus1_B, nbit);
+        EdwardsPoint *other = ecc_edwards_add(k_B, k_B);
+        ecc_edwards_point_free(k_B);
+        ecc_edwards_point_free(kplus1_B);
+        k_B = other;
+        kplus1_B = sum;
+        /* Swap back so that k_B < kplus1_B again in multiple order */
+        ecc_edwards_cond_swap(k_B, kplus1_B, nbit);
+
+        /* Until the first 1 bit of n has been processed, reset the
+         * pair to (B, 2B), discarding the work done above */
+        ecc_edwards_cond_overwrite(k_B, B, not_started_yet);
+        ecc_edwards_cond_overwrite(kplus1_B, two_B, not_started_yet);
+        not_started_yet &= ~nbit;
+    }
+
+    ecc_edwards_point_free(two_B);
+    ecc_edwards_point_free(kplus1_B);
+    return k_B;
+}
+
+/*
+ * Decide whether the two fractions An/Ad and Bn/Bd stand for the same
+ * value, without dividing: the fractions are cross-multiplied and the
+ * products An*Bd and Bn*Ad are compared with mp_cmp_eq, whose result
+ * is returned. The inputs are only read.
+ */
+static inline unsigned projective_eq(
+    MontyContext *mc, mp_int *An, mp_int *Ad,
+    mp_int *Bn, mp_int *Bd)
+{
+    mp_int *lhs = monty_mul(mc, An, Bd);
+    mp_int *rhs = monty_mul(mc, Bn, Ad);
+    unsigned equal = mp_cmp_eq(lhs, rhs);
+
+    mp_free(rhs);
+    mp_free(lhs);
+    return equal;
+}
+
+/*
+ * Compare two points on the same curve for equality. The points may
+ * have different Z coordinates, so X/Z and Y/Z are each compared via
+ * projective_eq; both comparisons are always performed and their
+ * results are ANDed together.
+ */
+unsigned ecc_edwards_eq(EdwardsPoint *P, EdwardsPoint *Q)
+{
+    EdwardsCurve *ec = P->ec;
+    assert(Q->ec == ec);
+
+    unsigned x_equal = projective_eq(ec->mc, P->X, P->Z, Q->X, Q->Z);
+    unsigned y_equal = projective_eq(ec->mc, P->Y, P->Z, Q->Y, Q->Z);
+
+    return x_equal & y_equal;
+}
+
+/*
+ * Retrieve the affine coordinates of a point. The point is first
+ * normalised in place (so ep itself is modified), and then X and/or
+ * Y are passed through monty_export into the caller's output
+ * pointers. Either of x and y may be NULL if that coordinate is not
+ * wanted.
+ */
+void ecc_edwards_get_affine(EdwardsPoint *ep, mp_int **x, mp_int **y)
+{
+    MontyContext *mc = ep->ec->mc;
+
+    ecc_edwards_normalise(ep);
+
+    if (x != NULL)
+        *x = monty_export(mc, ep->X);
+    if (y != NULL)
+        *y = monty_export(mc, ep->Y);
+}

+ 233 - 0
source/putty/ecc.h

@@ -0,0 +1,233 @@
+#ifndef PUTTY_ECC_H
+#define PUTTY_ECC_H
+
+/*
+ * Arithmetic functions for the various kinds of elliptic curves used
+ * by PuTTY's public-key cryptography.
+ *
+ * All of these elliptic curves are over the finite field whose order
+ * is a large prime p. (Elliptic curves over a field of order 2^n are
+ * also known, but PuTTY currently has no need of them.)
+ */
+
+/* ----------------------------------------------------------------------
+ * Weierstrass curves (or rather, 'short form' Weierstrass curves).
+ *
+ * A curve in this form is defined by two parameters a,b, and the
+ * non-identity points on the curve are represented by (x,y) (the
+ * 'affine coordinates') such that y^2 = x^3 + ax + b.
+ *
+ * The identity element of the curve's group is an additional 'point
+ * at infinity', which is considered to be the third point on the
+ * intersection of the curve with any vertical line. Hence, the
+ * inverse of the point (x,y) is (x,-y).
+ */
+
+/*
+ * Create and destroy Weierstrass curve data structures. The mandatory
+ * parameters to the constructor are the prime modulus p, and the
+ * curve parameters a,b.
+ *
+ * 'nonsquare_mod_p' is an optional extra parameter, only needed by
+ * ecc_weierstrass_point_new_from_x which has to take a modular
+ * square root. You can pass it as NULL if you don't need that
+ * function.
+ */
+WeierstrassCurve *ecc_weierstrass_curve(
+    mp_int *p, mp_int *a, mp_int *b, mp_int *nonsquare_mod_p);
+void ecc_weierstrass_curve_free(WeierstrassCurve *);
+
+/*
+ * Create points on a Weierstrass curve, given the curve.
+ *
+ * point_new_identity returns the special identity point.
+ * point_new(x,y) returns the non-identity point with the given affine
+ * coordinates.
+ *
+ * point_new_from_x constructs a non-identity point given only the
+ * x-coordinate, by using the curve equation to work out what y has to
+ * be. Of course the equation only tells you y^2, so it only
+ * determines y up to sign; the parameter desired_y_parity controls
+ * which of the two values of y you get, by saying whether you'd like
+ * its minimal non-negative residue mod p to be even or odd. (Of
+ * course, since p itself is odd, exactly one of y and p-y is odd.)
+ * This function has to take a modular square root, so it will only
+ * work if you passed in a non-square mod p when constructing the
+ * curve.
+ */
+WeierstrassPoint *ecc_weierstrass_point_new_identity(WeierstrassCurve *curve);
+WeierstrassPoint *ecc_weierstrass_point_new(
+    WeierstrassCurve *curve, mp_int *x, mp_int *y);
+WeierstrassPoint *ecc_weierstrass_point_new_from_x(
+    WeierstrassCurve *curve, mp_int *x, unsigned desired_y_parity);
+
+/* Memory management: copy and free points. */
+WeierstrassPoint *ecc_weierstrass_point_copy(WeierstrassPoint *wc);
+void ecc_weierstrass_point_free(WeierstrassPoint *point);
+
+/* Check whether a point is actually on the curve. */
+unsigned ecc_weierstrass_point_valid(WeierstrassPoint *);
+
+/*
+ * Add two points and return their sum. This function is fully
+ * general: it should do the right thing if the two inputs are the
+ * same, or if either (or both) of the input points is the identity,
+ * or if the two input points are inverses so the output is the
+ * identity. However, it pays for that generality by being slower than
+ * the special-purpose functions below.
+ */
+WeierstrassPoint *ecc_weierstrass_add_general(
+    WeierstrassPoint *, WeierstrassPoint *);
+
+/*
+ * Fast but less general arithmetic functions: add two points on the
+ * condition that they are not equal and neither is the identity, and
+ * add a point to itself.
+ */
+WeierstrassPoint *ecc_weierstrass_add(WeierstrassPoint *, WeierstrassPoint *);
+WeierstrassPoint *ecc_weierstrass_double(WeierstrassPoint *);
+
+/*
+ * Compute an integer multiple of a point. Not guaranteed to work
+ * unless the integer argument is less than the order of the point in
+ * the group (because it won't cope if an identity element shows up in
+ * any intermediate product).
+ */
+WeierstrassPoint *ecc_weierstrass_multiply(WeierstrassPoint *, mp_int *);
+
+/*
+ * Query functions to get the value of a point back out. is_identity
+ * tells you whether the point is the identity; if it isn't, then
+ * get_affine will retrieve one or both of its affine coordinates.
+ * (You can pass NULL as either output pointer, if you don't need that
+ * coordinate as output.)
+ */
+unsigned ecc_weierstrass_is_identity(WeierstrassPoint *wp);
+void ecc_weierstrass_get_affine(WeierstrassPoint *wp, mp_int **x, mp_int **y);
+
+/* ----------------------------------------------------------------------
+ * Montgomery curves.
+ *
+ * A curve in this form is defined by two parameters a,b, and the
+ * curve equation is y^2 = x^3 + ax^2 + bx.
+ *
+ * As with Weierstrass curves, there's an additional point at infinity
+ * that is the identity element, and the inverse of (x,y) is (x,-y).
+ *
+ * However, we don't actually work with full (x,y) pairs. We just
+ * store the x-coordinate (so what we're really representing is not a
+ * specific point on the curve but a two-point set {P,-P}). This means
+ * you can't quite do point addition, because if you're given {P,-P}
+ * and {Q,-Q} as input, you can work out a pair of x-coordinates that
+ * are those of P-Q and P+Q, but you don't know which is which.
+ *
+ * Instead, the basic operation is 'differential addition', in which
+ * you are given three parameters P, Q and P-Q and you return P+Q. (As
+ * well as disambiguating which of the possible answers you want, that
+ * extra input also enables a fast formula for computing it. This
+ * fast formula is more or less why Montgomery curves are useful in
+ * the first place.)
+ *
+ * Doubling a point is still possible to do unambiguously, so you can
+ * still compute an integer multiple of P if you start by making 2P
+ * and then doing a series of differential additions.
+ */
+
+/*
+ * Create and destroy Montgomery curve data structures.
+ */
+MontgomeryCurve *ecc_montgomery_curve(mp_int *p, mp_int *a, mp_int *b);
+void ecc_montgomery_curve_free(MontgomeryCurve *);
+
+/*
+ * Create, copy and free points on the curve. We don't need to
+ * explicitly represent the identity for this application.
+ */
+MontgomeryPoint *ecc_montgomery_point_new(MontgomeryCurve *mc, mp_int *x);
+MontgomeryPoint *ecc_montgomery_point_copy(MontgomeryPoint *orig);
+void ecc_montgomery_point_free(MontgomeryPoint *mp);
+
+/*
+ * Basic arithmetic routines: differential addition and point-
+ * doubling. Each of these assumes that no special cases come up - no
+ * input or output point should be the identity, and in diff_add, P
+ * and Q shouldn't be the same.
+ */
+MontgomeryPoint *ecc_montgomery_diff_add(
+    MontgomeryPoint *P, MontgomeryPoint *Q, MontgomeryPoint *PminusQ);
+MontgomeryPoint *ecc_montgomery_double(MontgomeryPoint *P);
+
+/*
+ * Compute an integer multiple of a point.
+ */
+MontgomeryPoint *ecc_montgomery_multiply(MontgomeryPoint *, mp_int *);
+
+/*
+ * Return the affine x-coordinate of a point.
+ */
+void ecc_montgomery_get_affine(MontgomeryPoint *mp, mp_int **x);
+
+/* ----------------------------------------------------------------------
+ * Twisted Edwards curves.
+ *
+ * A curve in this form is defined by two parameters d,a, and the
+ * curve equation is a x^2 + y^2 = 1 + d x^2 y^2.
+ *
+ * Apparently if you ask a proper algebraic geometer they'll tell you
+ * that this is technically not an actual elliptic curve. Certainly it
+ * doesn't work quite the same way as the other kinds: in this form,
+ * there is no need for a point at infinity, because the identity
+ * element is represented by the affine coordinates (0,1). And you
+ * invert a point by negating its x rather than y coordinate: the
+ * inverse of (x,y) is (-x,y).
+ *
+ * The usefulness of this representation is that the addition formula
+ * is 'strongly unified', meaning that the same formula works for any
+ * input and output points, without needing special cases for the
+ * identity or for doubling.
+ */
+
+/*
+ * Create and destroy Edwards curve data structures.
+ *
+ * Similarly to ecc_weierstrass_curve, you don't have to provide
+ * nonsquare_mod_p if you don't need ecc_edwards_point_new_from_y.
+ */
+EdwardsCurve *ecc_edwards_curve(
+    mp_int *p, mp_int *d, mp_int *a, mp_int *nonsquare_mod_p);
+void ecc_edwards_curve_free(EdwardsCurve *);
+
+/*
+ * Create points.
+ *
+ * There's no need to have a separate function to create the identity
+ * point, because you can just pass x=0 and y=1 to the usual function.
+ *
+ * Similarly to the Weierstrass curve, ecc_edwards_point_new_from_y
+ * creates a point given only its y-coordinate and the desired parity
+ * of its x-coordinate, and you can only call it if you provided the
+ * optional nonsquare_mod_p argument when creating the curve.
+ */
+EdwardsPoint *ecc_edwards_point_new(
+    EdwardsCurve *curve, mp_int *x, mp_int *y);
+EdwardsPoint *ecc_edwards_point_new_from_y(
+    EdwardsCurve *curve, mp_int *y, unsigned desired_x_parity);
+
+/* Copy and free points. */
+EdwardsPoint *ecc_edwards_point_copy(EdwardsPoint *ec);
+void ecc_edwards_point_free(EdwardsPoint *point);
+
+/*
+ * Arithmetic: add two points, and calculate an integer multiple of a
+ * point.
+ */
+EdwardsPoint *ecc_edwards_add(EdwardsPoint *, EdwardsPoint *);
+EdwardsPoint *ecc_edwards_multiply(EdwardsPoint *, mp_int *);
+
+/*
+ * Query functions: compare two points for equality, and return the
+ * affine coordinates of a point.
+ */
+unsigned ecc_edwards_eq(EdwardsPoint *, EdwardsPoint *);
+void ecc_edwards_get_affine(EdwardsPoint *wp, mp_int **x, mp_int **y);
+
+#endif /* PUTTY_ECC_H */

+ 21 - 20
source/putty/import.c

@@ -10,6 +10,7 @@
 
 #include "putty.h"
 #include "ssh.h"
+#include "mpint.h"
 #include "misc.h"
 
 static bool openssh_pem_encrypted(const Filename *file);
@@ -815,7 +816,7 @@ static bool openssh_pem_write(
          */
         if (ssh_key_alg(key->key) == &ssh_rsa) {
             ptrlen n, e, d, p, q, iqmp, dmp1, dmq1;
-            Bignum bd, bp, bq, bdmp1, bdmq1;
+            mp_int *bd, *bp, *bq, *bdmp1, *bdmq1;
 
             /*
              * These blobs were generated from inside PuTTY, so we needn't
@@ -834,29 +835,29 @@ static bool openssh_pem_write(
             assert(!get_err(src));     /* can't go wrong */
 
             /* We also need d mod (p-1) and d mod (q-1). */
-            bd = bignum_from_bytes(d.ptr, d.len);
-            bp = bignum_from_bytes(p.ptr, p.len);
-            bq = bignum_from_bytes(q.ptr, q.len);
-            decbn(bp);
-            decbn(bq);
-            bdmp1 = bigmod(bd, bp);
-            bdmq1 = bigmod(bd, bq);
-            freebn(bd);
-            freebn(bp);
-            freebn(bq);
-
-            dmp1.len = (bignum_bitcount(bdmp1)+8)/8;
-            dmq1.len = (bignum_bitcount(bdmq1)+8)/8;
+            bd = mp_from_bytes_be(d);
+            bp = mp_from_bytes_be(p);
+            bq = mp_from_bytes_be(q);
+            mp_sub_integer_into(bp, bp, 1);
+            mp_sub_integer_into(bq, bq, 1);
+            bdmp1 = mp_mod(bd, bp);
+            bdmq1 = mp_mod(bd, bq);
+            mp_free(bd);
+            mp_free(bp);
+            mp_free(bq);
+
+            dmp1.len = (mp_get_nbits(bdmp1)+8)/8;
+            dmq1.len = (mp_get_nbits(bdmq1)+8)/8;
             sparelen = dmp1.len + dmq1.len;
             spareblob = snewn(sparelen, unsigned char);
             dmp1.ptr = spareblob;
             dmq1.ptr = spareblob + dmp1.len;
             for (i = 0; i < dmp1.len; i++)
-                spareblob[i] = bignum_byte(bdmp1, dmp1.len-1 - i);
+                spareblob[i] = mp_get_byte(bdmp1, dmp1.len-1 - i);
             for (i = 0; i < dmq1.len; i++)
-                spareblob[i+dmp1.len] = bignum_byte(bdmq1, dmq1.len-1 - i);
-            freebn(bdmp1);
-            freebn(bdmq1);
+                spareblob[i+dmp1.len] = mp_get_byte(bdmq1, dmq1.len-1 - i);
+            mp_free(bdmp1);
+            mp_free(bdmq1);
 
             numbers[0] = make_ptrlen(zero, 1); zero[0] = '\0';
             numbers[1] = n;
@@ -913,7 +914,7 @@ static bool openssh_pem_write(
                ssh_key_alg(key->key) == &ssh_ecdsa_nistp384 ||
                ssh_key_alg(key->key) == &ssh_ecdsa_nistp521) {
         const unsigned char *oid;
-        struct ec_key *ec = container_of(key->key, struct ec_key, sshk);
+        struct ecdsa_key *ec = container_of(key->key, struct ecdsa_key, sshk);
         int oidlen;
         int pointlen;
         strbuf *seq, *sub;
@@ -929,7 +930,7 @@ static bool openssh_pem_write(
          *     BIT STRING (0x00 public key point)
          */
         oid = ec_alg_oid(ssh_key_alg(key->key), &oidlen);
-        pointlen = (ec->publicKey.curve->fieldBits + 7) / 8 * 2;
+        pointlen = (ec->curve->fieldBits + 7) / 8 * 2;
 
         seq = strbuf_new();
 

+ 5 - 1
source/putty/marshal.h

@@ -153,6 +153,8 @@ struct strbuf;
 void BinarySink_put_stringsb(BinarySink *, struct strbuf *);
 void BinarySink_put_asciz(BinarySink *, const char *str);
 bool BinarySink_put_pstring(BinarySink *, const char *str);
+void BinarySink_put_mp_ssh1(BinarySink *bs, mp_int *x);
+void BinarySink_put_mp_ssh2(BinarySink *bs, mp_int *x);
 
 /* ---------------------------------------------------------------------- */
 
@@ -195,7 +197,7 @@ struct BinarySource {
      * types.
      *
      * If the usual return value is dynamically allocated (e.g. a
-     * Bignum, or a normal C 'char *' string), then the error value is
+     * bignum, or a normal C 'char *' string), then the error value is
      * also dynamic in the same way. So you have to free exactly the
      * same set of things whether or not there was a decoding error,
      * which simplifies exit paths - for example, you could call a big
@@ -281,5 +283,7 @@ uint64_t BinarySource_get_uint64(BinarySource *);
 ptrlen BinarySource_get_string(BinarySource *);
 const char *BinarySource_get_asciz(BinarySource *);
 ptrlen BinarySource_get_pstring(BinarySource *);
+mp_int *BinarySource_get_mp_ssh1(BinarySource *src);
+mp_int *BinarySource_get_mp_ssh2(BinarySource *src);
 
 #endif /* PUTTY_MARSHAL_H */

+ 2340 - 0
source/putty/mpint.c

@@ -0,0 +1,2340 @@
+#include <assert.h>
+#include <limits.h>
+#include <stdio.h>
+
+#include "defs.h"
+#include "putty.h"
+
+#include "mpint.h"
+#include "mpint_i.h"
+
+/*
+ * Inline helpers to take min and max of size_t values, used
+ * throughout this code.
+ */
+/* Return the smaller of two size_t values. */
+static inline size_t size_t_min(size_t a, size_t b)
+{
+    if (b < a)
+        return b;
+    return a;
+}
+/* Return the larger of two size_t values. */
+static inline size_t size_t_max(size_t a, size_t b)
+{
+    if (b > a)
+        return b;
+    return a;
+}
+
+/*
+ * Bounds-checked read of word i of x. An index at or beyond x->nw
+ * reads as 0, as if the number were extended with zero words.
+ */
+static inline BignumInt mp_word(mp_int *x, size_t i)
+{
+    if (i >= x->nw)
+        return 0;
+    return x->w[i];
+}
+
+/*
+ * Create an mp_int with exactly nw words of storage, all cleared so
+ * that the value is zero. The word array is the auxiliary area
+ * requested from snew_plus alongside the mp_int structure itself.
+ */
+static mp_int *mp_make_sized(size_t nw)
+{
+    mp_int *result = snew_plus(mp_int, nw * sizeof(BignumInt));
+
+    result->w = snew_plus_get_aux(result);
+    result->nw = nw;
+    mp_clear(result);
+
+    return result;
+}
+
+/*
+ * Create a zero-valued mp_int with room for at least 'maxbits' bits.
+ * The bit count is rounded up to a whole number of BignumInt words.
+ */
+mp_int *mp_new(size_t maxbits)
+{
+    size_t words = (maxbits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
+    return mp_make_sized(words);
+}
+
+/*
+ * Create an mp_int holding the value of an ordinary C integer. The
+ * result is given enough words to cover sizeof(uintmax_t) bytes, and
+ * each word receives the correspondingly shifted slice of n.
+ */
+mp_int *mp_from_integer(uintmax_t n)
+{
+    size_t nwords = (sizeof(n) + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES;
+    mp_int *result = mp_make_sized(nwords);
+
+    for (size_t word = 0; word < result->nw; word++) {
+        result->w[word] = n >> (word * BIGNUM_INT_BITS);
+    }
+
+    return result;
+}
+
+/*
+ * Return the storage capacity of x in bytes: its word count times
+ * the size of a word. This depends only on how x was allocated, not
+ * on the value it currently holds.
+ */
+size_t mp_max_bytes(mp_int *x)
+{
+    return x->nw * BIGNUM_INT_BYTES;
+}
+
+/*
+ * Return the storage capacity of x in bits: its word count times the
+ * number of bits in a word. This depends only on how x was
+ * allocated, not on the value it currently holds.
+ */
+size_t mp_max_bits(mp_int *x)
+{
+    return x->nw * BIGNUM_INT_BITS;
+}
+
+/*
+ * Destroy an mp_int. The number's words and then the mp_int
+ * structure itself are wiped with smemclr before the memory is
+ * released with sfree (presumably so that the value does not linger
+ * in freed memory). x must not be NULL.
+ */
+void mp_free(mp_int *x)
+{
+    smemclr(x->w, x->nw * sizeof(BignumInt));
+    smemclr(x, sizeof(*x));
+    sfree(x);
+}
+
+/*
+ * Debugging aid: write x to fp in hex, most significant byte first,
+ * as prefix + "0x" + two hex digits for every byte of x's allocated
+ * size (so leading zero bytes are included) + suffix.
+ */
+void mp_dump(FILE *fp, const char *prefix, mp_int *x, const char *suffix)
+{
+    size_t remaining = mp_max_bytes(x);
+
+    fprintf(fp, "%s0x", prefix);
+    while (remaining > 0) {
+        remaining--;
+        fprintf(fp, "%02X", mp_get_byte(x, remaining));
+    }
+    fputs(suffix, fp);
+}
+
+/*
+ * Copy the value of src into the existing storage of dest. As many
+ * words as fit in both numbers are copied; if dest is larger, its
+ * remaining high words are cleared, and if dest is smaller, the high
+ * words of src are simply not copied.
+ */
+void mp_copy_into(mp_int *dest, mp_int *src)
+{
+    size_t shared = size_t_min(dest->nw, src->nw);
+    size_t excess = dest->nw - shared;
+
+    memmove(dest->w, src->w, shared * sizeof(BignumInt));
+    smemclr(dest->w + shared, excess * sizeof(BignumInt));
+}
+
+/*
+ * Conditional selection is done by negating 'which', to give a mask
+ * word which is all 1s if which==1 and all 0s if which==0. Then you
+ * can select between two inputs a,b without data-dependent control
+ * flow by XORing them to get their difference; ANDing with the mask
+ * word to replace that difference with 0 if which==0; and XORing that
+ * into a, which will either turn it into b or leave it alone.
+ *
+ * This trick will be used throughout this code and taken as read the
+ * rest of the time (or else I'd be here all week typing comments),
+ * but I felt I ought to explain it in words _once_.
+ *
+ * mp_select_into itself writes src0 into dest if the low bit of
+ * 'which' is 0, or src1 if it is 1. Exactly dest->nw words are
+ * written; a source with fewer words than dest is read as if it were
+ * zero-extended (via mp_word). dest may be the same object as either
+ * source, since each word is read before it is written.
+ */
+void mp_select_into(mp_int *dest, mp_int *src0, mp_int *src1,
+                    unsigned which)
+{
+    BignumInt mask = -(BignumInt)(1 & which);
+    for (size_t i = 0; i < dest->nw; i++) {
+        BignumInt srcword0 = mp_word(src0, i), srcword1 = mp_word(src1, i);
+        dest->w[i] = srcword0 ^ ((srcword1 ^ srcword0) & mask);
+    }
+}
+
+/*
+ * Exchange the values of x0 and x1 if the low bit of 'swap' is 1,
+ * and leave both unchanged if it is 0. The two numbers must have the
+ * same word count (checked by assertion). Every word is read and
+ * rewritten in both cases; only the mask differs.
+ */
+void mp_cond_swap(mp_int *x0, mp_int *x1, unsigned swap)
+{
+    assert(x0->nw == x1->nw);
+    BignumInt mask = -(BignumInt)(1 & swap);
+    for (size_t i = 0; i < x0->nw; i++) {
+        /* diff is x0^x1 when swapping and 0 otherwise, so XORing it
+         * into both words either exchanges them or does nothing */
+        BignumInt diff = (x0->w[i] ^ x1->w[i]) & mask;
+        x0->w[i] ^= diff;
+        x1->w[i] ^= diff;
+    }
+}
+
+/*
+ * Set x to zero by wiping all of its words with smemclr. The word
+ * count of x is unchanged.
+ */
+void mp_clear(mp_int *x)
+{
+    smemclr(x->w, x->nw * sizeof(BignumInt));
+}
+
+/*
+ * Set x to zero if the low bit of 'clear' is 1, and leave it
+ * unchanged if that bit is 0. Every word is ANDed with a mask in
+ * both cases: the mask is all 0s when clearing and all 1s otherwise.
+ */
+void mp_cond_clear(mp_int *x, unsigned clear)
+{
+    BignumInt mask = ~-(BignumInt)(1 & clear);
+    for (size_t i = 0; i < x->nw; i++)
+        x->w[i] &= mask;
+}
+
+/*
+ * Common code between mp_from_bytes_{le,be} which reads bytes in an
+ * arbitrary arithmetic progression.
+ *
+ * Byte i of the result, counting from the least significant, is
+ * taken from offset m*i+c of the input. The offset is computed in
+ * size_t arithmetic, so a caller can pass m = -1 (converted to
+ * size_t) and rely on unsigned wraparound to walk backwards from c.
+ * The result has just enough words to hold bytes.len bytes.
+ */
+static mp_int *mp_from_bytes_int(ptrlen bytes, size_t m, size_t c)
+{
+    mp_int *n = mp_make_sized(
+        (bytes.len + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES);
+    for (size_t i = 0; i < bytes.len; i++)
+        n->w[i / BIGNUM_INT_BYTES] |=
+            (BignumInt)(((const unsigned char *)bytes.ptr)[m*i+c]) <<
+            (8 * (i % BIGNUM_INT_BYTES));
+    return n;
+}
+
+/*
+ * Create an mp_int from a little-endian byte string: byte i of the
+ * input becomes byte i of the number (multiplier 1, offset 0).
+ */
+mp_int *mp_from_bytes_le(ptrlen bytes)
+{
+    return mp_from_bytes_int(bytes, 1, 0);
+}
+
+/*
+ * Create an mp_int from a big-endian byte string: byte i of the
+ * number, counting from the least significant, is read from input
+ * offset bytes.len-1-i. The multiplier -1 is converted to size_t,
+ * and the index arithmetic in mp_from_bytes_int wraps around to give
+ * that offset.
+ */
+mp_int *mp_from_bytes_be(ptrlen bytes)
+{
+    return mp_from_bytes_int(bytes, -1, bytes.len - 1);
+}
+
+/*
+ * Create an mp_int of nw words whose contents are copied from the
+ * array w, which must hold at least nw words. Word 0 is the least
+ * significant, matching the layout used throughout this file.
+ */
+static mp_int *mp_from_words(size_t nw, const BignumInt *w)
+{
+    mp_int *x = mp_make_sized(nw);
+    memcpy(x->w, w, x->nw * sizeof(BignumInt));
+    return x;
+}
+
+/*
+ * Decimal-to-binary conversion: just go through the input string
+ * adding on the decimal value of each digit, and then multiplying the
+ * number so far by 10.
+ *
+ * 'decimal' is a counted string of ASCII digits, most significant
+ * first. The characters are not validated here: each one simply has
+ * '0' subtracted from it. An empty string converts to zero. The
+ * result is newly allocated, with enough words for any number of
+ * that many digits.
+ */
+mp_int *mp_from_decimal_pl(ptrlen decimal)
+{
+    /* 196/59 is an upper bound (and also a continued-fraction
+     * convergent) for log2(10), so this conservatively estimates the
+     * number of bits that will be needed to store any number that can
+     * be written in this many decimal digits. */
+    assert(decimal.len < (~(size_t)0) / 196);
+    size_t bits = 196 * decimal.len / 59;
+
+    /* Now round that up to words. */
+    size_t words = bits / BIGNUM_INT_BITS + 1;
+
+    const char *digits = (const char *)decimal.ptr;
+
+    mp_int *x = mp_make_sized(words);
+
+    /*
+     * The loop condition matters for empty input: without it, a
+     * zero-length string would have its (nonexistent) first byte
+     * read, and the exit test below could never be satisfied.
+     */
+    for (size_t i = 0; i < decimal.len; i++) {
+        mp_add_integer_into(x, x, digits[i] - '0');
+
+        /* Don't multiply by 10 after the final digit. */
+        if (i+1 == decimal.len)
+            break;
+
+        mp_mul_integer_into(x, x, 10);
+    }
+    return x;
+}
+
+/*
+ * Convenience wrapper around mp_from_decimal_pl for a C string: the
+ * string is turned into a ptrlen with ptrlen_from_asciz and
+ * converted.
+ */
+mp_int *mp_from_decimal(const char *decimal)
+{
+    return mp_from_decimal_pl(ptrlen_from_asciz(decimal));
+}
+
+/*
+ * Hex-to-binary conversion: _algorithmically_ simpler than decimal
+ * (none of those multiplications by 10), but there's some fiddly
+ * bit-twiddling needed to process each hex digit without diverging
+ * control flow depending on whether it's a letter or a number.
+ *
+ * 'hex' is a counted string of hex digits, most significant first,
+ * in either case. The result is newly allocated with 4 bits of
+ * storage per input character, rounded up to whole words.
+ */
+mp_int *mp_from_hex_pl(ptrlen hex)
+{
+    assert(hex.len <= (~(size_t)0) / 4);
+    size_t bits = hex.len * 4;
+    size_t words = (bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
+    mp_int *x = mp_make_sized(words);
+    for (size_t nibble = 0; nibble < hex.len; nibble++) {
+        /* Nibble 0 is the least significant, i.e. the last character */
+        BignumInt digit = ((char *)hex.ptr)[hex.len-1 - nibble];
+
+        /* lmask is all 1s iff digit is in 'a'..'f', and umask iff it
+         * is in 'A'..'F': if the digit is outside the range, one of
+         * the two subtractions wraps and sets the top bit */
+        BignumInt lmask = ~-(((digit-'a')|('f'-digit)) >> (BIGNUM_INT_BITS-1));
+        BignumInt umask = ~-(((digit-'A')|('F'-digit)) >> (BIGNUM_INT_BITS-1));
+
+        /* Start from the value for '0'..'9', then use the masks to
+         * substitute the letter value where appropriate */
+        BignumInt digitval = digit - '0';
+        digitval ^= (digitval ^ (digit - 'a' + 10)) & lmask;
+        digitval ^= (digitval ^ (digit - 'A' + 10)) & umask;
+        digitval &= 0xF; /* at least be _slightly_ nice about weird input */
+
+        size_t word_idx = nibble / (BIGNUM_INT_BYTES*2);
+        size_t nibble_within_word = nibble % (BIGNUM_INT_BYTES*2);
+        x->w[word_idx] |= digitval << (nibble_within_word * 4);
+    }
+    return x;
+}
+
+/*
+ * Convenience wrapper around mp_from_hex_pl for a C string: the
+ * string is turned into a ptrlen with ptrlen_from_asciz and
+ * converted.
+ */
+mp_int *mp_from_hex(const char *hex)
+{
+    return mp_from_hex_pl(ptrlen_from_asciz(hex));
+}
+
+/*
+ * Return a newly allocated duplicate of x, with the same word count
+ * and the same contents.
+ */
+mp_int *mp_copy(mp_int *x)
+{
+    return mp_from_words(x->nw, x->w);
+}
+
+/*
+ * Return byte number 'byte' of x, where byte 0 is the least
+ * significant. Positions beyond the end of x read as 0, because the
+ * containing word is fetched with mp_word.
+ */
+uint8_t mp_get_byte(mp_int *x, size_t byte)
+{
+    BignumInt containing_word = mp_word(x, byte / BIGNUM_INT_BYTES);
+    size_t shift = 8 * (byte % BIGNUM_INT_BYTES);
+
+    return 0xFF & (containing_word >> shift);
+}
+
+/*
+ * Return bit number 'bit' of x (0 or 1), where bit 0 is the least
+ * significant. Positions beyond the end of x read as 0, because the
+ * containing word is fetched with mp_word.
+ */
+unsigned mp_get_bit(mp_int *x, size_t bit)
+{
+    BignumInt containing_word = mp_word(x, bit / BIGNUM_INT_BITS);
+    size_t shift = bit % BIGNUM_INT_BITS;
+
+    return 1 & (containing_word >> shift);
+}
+
+/*
+ * Set bit number 'bit' of x to the low bit of 'val'. Unlike the
+ * read accessors, the bit position must lie within x's storage; this
+ * is checked by assertion.
+ */
+void mp_set_bit(mp_int *x, size_t bit, unsigned val)
+{
+    size_t word = bit / BIGNUM_INT_BITS;
+    assert(word < x->nw);
+
+    unsigned shift = (bit % BIGNUM_INT_BITS);
+    BignumInt position = (BignumInt)1 << shift;
+    BignumInt newbit = (BignumInt)(val & 1) << shift;
+
+    /* Knock out the old bit, then merge in the new one */
+    x->w[word] = (x->w[word] & ~position) | newbit;
+}
+
+/*
+ * Helper function used here and there to normalise any nonzero input
+ * value to 1.
+ *
+ * Returns 0 for n == 0 and 1 for every other n, using only shifts,
+ * bitwise operations and negation (no comparison or branch). The
+ * first step folds the lowest bit into a right-shifted copy, giving
+ * a value that is nonzero exactly when n was but has its top bit
+ * clear; negating such a value sets the top bit iff it is nonzero,
+ * and the final shift moves that bit down to position 0.
+ */
+static inline unsigned normalise_to_1(BignumInt n)
+{
+    n = (n >> 1) | (n & 1);            /* ensure top bit is clear */
+    n = (-n) >> (BIGNUM_INT_BITS - 1); /* normalise to 0 or 1 */
+    return n;
+}
+
+/*
+ * Find the highest nonzero word in a number. Returns the index of the
+ * word in x->w, and also a pair of output uint64_t in which that word
+ * appears in the high one shifted left by 'shift_wanted' bits, the
+ * words immediately below it occupy the space to the right, and the
+ * words below _that_ fill up the low one.
+ *
+ * If there is no nonzero word at all, the passed-by-reference output
+ * variables retain their original values.
+ *
+ * Any of the output pointers 'index', 'hi' and 'lo' may be NULL, in
+ * which case that output is not written. Every word of x is visited
+ * whatever its value; the outputs are updated by masking rather than
+ * by branching on the word's contents.
+ */
+static inline void mp_find_highest_nonzero_word_pair(
+    mp_int *x, size_t shift_wanted, size_t *index,
+    uint64_t *hi, uint64_t *lo)
+{
+    uint64_t curr_hi = 0, curr_lo = 0;
+
+    for (size_t curr_index = 0; curr_index < x->nw; curr_index++) {
+        BignumInt curr_word = x->w[curr_index];
+        /* 1 if this word is nonzero, else 0 */
+        unsigned indicator = normalise_to_1(curr_word);
+
+        /* Slide the curr_hi:curr_lo window down by one word, and put
+         * the current word in at the top (shifted as requested) */
+        curr_lo = (BIGNUM_INT_BITS < 64 ? (curr_lo >> BIGNUM_INT_BITS) : 0) |
+            (curr_hi << (64 - BIGNUM_INT_BITS));
+        curr_hi = (BIGNUM_INT_BITS < 64 ? (curr_hi >> BIGNUM_INT_BITS) : 0) |
+            ((uint64_t)curr_word << shift_wanted);
+
+        /* If this word is nonzero, it is the highest one seen so
+         * far, so overwrite the outputs with the current state */
+        if (hi)    *hi    ^= (curr_hi    ^ *hi   ) & -(uint64_t)indicator;
+        if (lo)    *lo    ^= (curr_lo    ^ *lo   ) & -(uint64_t)indicator;
+        if (index) *index ^= (curr_index ^ *index) & -(size_t)  indicator;
+    }
+}
+
+/*
+ * Return the number of significant bits in x, i.e. one more than the
+ * index of its highest set bit, or 0 if x is zero. Unlike
+ * mp_max_bits, this depends on the value stored rather than on the
+ * allocated size.
+ */
+size_t mp_get_nbits(mp_int *x)
+{
+    /* Sentinel values in case there are no bits set at all: we
+     * imagine that there's a word at position -1 (i.e. the topmost
+     * fraction word) which is all 1s, because that way, we handle a
+     * zero input by considering its highest set bit to be the top one
+     * of that word, i.e. just below the units digit, i.e. at bit
+     * index -1, i.e. so we'll return 0 on output. */
+    size_t hiword_index = -(size_t)1;
+    uint64_t hiword64 = ~(BignumInt)0;
+
+    /*
+     * Find the highest nonzero word and its index.
+     */
+    mp_find_highest_nonzero_word_pair(x, 0, &hiword_index, &hiword64, NULL);
+    BignumInt hiword = hiword64; /* in case BignumInt is a narrower type */
+
+    /*
+     * Find the index of the highest set bit within hiword.
+     *
+     * This is a binary search over shift distances of half a word,
+     * a quarter of a word, and so on down to 1: whenever anything is
+     * left after shifting right by i, the shifted word is kept and i
+     * is added to the bit index. The choice is made by masking with
+     * 'indicator' rather than by branching.
+     */
+    BignumInt hibit_index = 0;
+    for (size_t i = (1 << (BIGNUM_INT_BITS_BITS-1)); i != 0; i >>= 1) {
+        BignumInt shifted_word = hiword >> i;
+        BignumInt indicator = (-shifted_word) >> (BIGNUM_INT_BITS-1);
+        hiword ^= (shifted_word ^ hiword ) & -indicator;
+        hibit_index += i & -(size_t)indicator;
+    }
+
+    /*
+     * Put together the result.
+     */
+    return (hiword_index << BIGNUM_INT_BITS_BITS) + hibit_index + 1;
+}
+
+/*
+ * Shared code between the hex and decimal output functions to get rid
+ * of leading zeroes on the output string. The idea is that we wrote
+ * out a fixed number of digits and a trailing \0 byte into 'buf', and
+ * now we want to shift it all left so that the first nonzero digit
+ * moves to buf[0] (or, if there are no nonzero digits at all, we move
+ * up by 'maxtrim', so that we return 0 as "0" instead of "").
+ *
+ * 'bufsize' is the full size of buf including the trailing \0, and
+ * 'maxtrim' is the largest number of leading characters that may be
+ * removed. The string is modified in place; the bytes shifted off
+ * the front end up beyond the new terminator.
+ */
+static void trim_leading_zeroes(char *buf, size_t bufsize, size_t maxtrim)
+{
+    size_t trim = maxtrim;
+
+    /*
+     * The mask below is computed in size_t, so the shift that pulls
+     * out its top bit must use the width of size_t, which need not
+     * match the width of a bignum word.
+     */
+    const unsigned size_t_top_bit = sizeof(size_t) * CHAR_BIT - 1;
+
+    /*
+     * Look for the first character not equal to '0', to find the
+     * shift count.
+     */
+    if (trim > 0) {
+        for (size_t pos = trim; pos-- > 0 ;) {
+            uint8_t diff = buf[pos] ^ '0';
+            /* diff-1 wraps to all 1s only when diff is 0, so mask is
+             * all 1s iff this character is '0' */
+            size_t mask = -((((size_t)diff) - 1) >> size_t_top_bit);
+            trim ^= (trim ^ pos) & ~mask;
+        }
+    }
+
+    /*
+     * Now do the shift, in log n passes each of which does a
+     * conditional shift by 2^i bytes if bit i is set in the shift
+     * count.
+     */
+    uint8_t *ubuf = (uint8_t *)buf;
+    for (size_t logd = 0; bufsize >> logd; logd++) {
+        uint8_t mask = -(uint8_t)((trim >> logd) & 1);
+        size_t d = (size_t)1 << logd;
+        for (size_t i = 0; i+d < bufsize; i++) {
+            uint8_t diff = mask & (ubuf[i] ^ ubuf[i+d]);
+            ubuf[i] ^= diff;
+            ubuf[i+d] ^= diff;
+        }
+    }
+}
+
+/*
+ * Binary to decimal conversion. Our strategy here is to extract each
+ * decimal digit by finding the input number's residue mod 10, then
+ * subtract that off to give an exact multiple of 10, which then means
+ * you can safely divide by 10 by means of shifting right one bit and
+ * then multiplying by the inverse of 5 mod 2^n.
+ */
+char *mp_get_decimal(mp_int *x_orig)
+{
+    /*
+     * Returns a freshly allocated (snewn) NUL-terminated decimal
+     * string. x_orig itself is never written: all arithmetic is done
+     * on the copy x, with y as a same-sized temporary for the
+     * halving step.
+     */
+    mp_int *x = mp_copy(x_orig), *y = mp_make_sized(x->nw);
+
+    /*
+     * The inverse of 5 mod 2^lots is 0xccccccccccccccccccccd, for an
+     * appropriate number of 'c's. Manually construct an integer the
+     * right size.
+     */
+    mp_int *inv5 = mp_make_sized(x->nw);
+    assert(BIGNUM_INT_BITS % 8 == 0);
+    /* BIGNUM_INT_MASK / 5 * 4 is the all-0xCC word; the increment
+     * below turns the lowest nibble of the whole number into 0xD. */
+    for (size_t i = 0; i < inv5->nw; i++)
+        inv5->w[i] = BIGNUM_INT_MASK / 5 * 4;
+    inv5->w[0]++;
+
+    /*
+     * 146/485 is an upper bound (and also a continued-fraction
+     * convergent) of log10(2), so this is a conservative estimate of
+     * the number of decimal digits needed to store a value that fits
+     * in this many binary bits.
+     */
+    assert(x->nw < (~(size_t)1) / (146 * BIGNUM_INT_BITS));
+    size_t bufsize = size_t_max(x->nw * (146 * BIGNUM_INT_BITS) / 485, 1) + 2;
+    char *outbuf = snewn(bufsize, char);
+    outbuf[bufsize - 1] = '\0';
+
+    /*
+     * Loop over the number generating digits from the least
+     * significant upwards, so that we write to outbuf in reverse
+     * order. The loop always fills every position from bufsize-2
+     * down to 0, regardless of the value of the number.
+     */
+    for (size_t pos = bufsize - 1; pos-- > 0 ;) {
+        /*
+         * Find the current residue mod 10. We do this by first
+         * summing the bytes of the number, with all but the lowest
+         * one multiplied by 6 (because 256^i == 6 mod 10 for all
+         * i>0). That gives us a single word congruent mod 10 to the
+         * input number, and then we reduce it further by manual
+         * multiplication and shifting, just in case the compiler
+         * target implements the C division operator in a way that has
+         * input-dependent timing.
+         */
+        uint32_t low_digit = 0, maxval = 0, mult = 1;
+        for (size_t i = 0; i < x->nw; i++) {
+            for (unsigned j = 0; j < BIGNUM_INT_BYTES; j++) {
+                low_digit += mult * (0xFF & (x->w[i] >> (8*j)));
+                /* maxval tracks the largest value low_digit could
+                 * hold so far, independent of the actual data. */
+                maxval += mult * 0xFF;
+                mult = 6;
+            }
+            /*
+             * For _really_ big numbers, prevent overflow of low_digit
+             * by periodically folding the top half of the accumulator
+             * into the bottom half, using the same rule 'multiply by
+             * 6 when shifting down by one or more whole bytes'. The
+             * decision is taken on maxval, not on low_digit, so it
+             * depends only on the length of the number.
+             */
+            if (maxval > UINT32_MAX - (6 * 0xFF * BIGNUM_INT_BYTES)) {
+                low_digit = (low_digit & 0xFFFF) + 6 * (low_digit >> 16);
+                maxval = (maxval & 0xFFFF) + 6 * (maxval >> 16);
+            }
+        }
+
+        /*
+         * Final reduction of low_digit. We multiply by 2^32 / 10
+         * (that's the constant 0x19999999) to get a 64-bit value
+         * whose top 32 bits are the approximate quotient
+         * low_digit/10; then we subtract off 10 times that; and
+         * finally we do one last trial subtraction of 10 by adding 6
+         * (which sets bit 4 if the number was just over 10) and then
+         * testing bit 4.
+         */
+        low_digit -= 10 * ((0x19999999ULL * low_digit) >> 32);
+        low_digit -= 10 * ((low_digit + 6) >> 4);
+
+        assert(low_digit < 10);        /* make sure we did reduce fully */
+        outbuf[pos] = '0' + low_digit;
+
+        /*
+         * Now subtract off that digit, divide by 2 (using a right
+         * shift) and by 5 (using the modular inverse), to get the
+         * next output digit into the units position.
+         */
+        mp_sub_integer_into(x, x, low_digit);
+        mp_rshift_fixed_into(y, x, 1);
+        mp_mul_into(x, y, inv5);
+    }
+
+    mp_free(x);
+    mp_free(y);
+    mp_free(inv5);
+
+    /* NOTE(review): the last argument looks like the index of the
+     * final digit, which must survive even for the value zero --
+     * confirm against trim_leading_zeroes. */
+    trim_leading_zeroes(outbuf, bufsize, bufsize - 2);
+    return outbuf;
+}
+
+/*
+ * Binary to hex conversion. Reasonably simple (only a spot of bit
+ * twiddling to choose whether to output a digit or a letter for each
+ * nibble).
+ */
+static char *mp_get_hex_internal(mp_int *x, uint8_t letter_offset)
+{
+    /* One output character per 4 bits of every word, plus the NUL. */
+    const size_t nibbles_per_word = BIGNUM_INT_BYTES*2;
+    size_t ndigits = x->nw * nibbles_per_word;
+    size_t buflen = ndigits + 1;
+    char *hex = snewn(buflen, char);
+    hex[ndigits] = '\0';
+
+    /* n counts nibbles from the least significant end; the string is
+     * filled from its last character backwards. */
+    for (size_t n = 0; n < ndigits; n++) {
+        size_t which_word = n / nibbles_per_word;
+        size_t bit_position = (n % nibbles_per_word) * 4;
+        uint8_t value = 0xF & (x->w[which_word] >> bit_position);
+
+        /* is_letter becomes 0xFF exactly when value >= 10, without
+         * any data-dependent branch. */
+        uint8_t is_letter = -((value + 6) >> 4);
+        char ch = value + '0' + (letter_offset & is_letter);
+        hex[ndigits-1 - n] = ch;
+    }
+
+    trim_leading_zeroes(hex, buflen, ndigits - 1);
+    return hex;
+}
+
+char *mp_get_hex(mp_int *x)
+{
+    /* Extra amount added to digit values 10..15 so they land on
+     * 'a'..'f' rather than on the characters following '9'. */
+    const uint8_t lowercase_offset = 'a' - ('0'+10);
+    return mp_get_hex_internal(x, lowercase_offset);
+}
+
+char *mp_get_hex_uppercase(mp_int *x)
+{
+    /* Extra amount added to digit values 10..15 so they land on
+     * 'A'..'F' rather than on the characters following '9'. */
+    const uint8_t uppercase_offset = 'A' - ('0'+10);
+    return mp_get_hex_internal(x, uppercase_offset);
+}
+
+/*
+ * Routines for reading and writing the SSH-1 and SSH-2 wire formats
+ * for multiprecision integers, declared in marshal.h.
+ *
+ * These can't avoid having control flow dependent on the true bit
+ * size of the number, because the wire format requires the number of
+ * output bytes to depend on that.
+ */
+void BinarySink_put_mp_ssh1(BinarySink *bs, mp_int *x)
+{
+    /* SSH-1 format: 16-bit bit count, then the minimal number of
+     * bytes needed for that many bits, most significant first. */
+    size_t nbits = mp_get_nbits(x);
+    size_t remaining = (nbits + 7) / 8;
+
+    assert(nbits < 0x10000);
+    put_uint16(bs, nbits);
+
+    while (remaining > 0) {
+        remaining--;
+        put_byte(bs, mp_get_byte(x, remaining));
+    }
+}
+
+void BinarySink_put_mp_ssh2(BinarySink *bs, mp_int *x)
+{
+    /* SSH-2 format: 32-bit byte count, then big-endian bytes. Using
+     * +8 rather than +7 always leaves a spare zero top bit, so the
+     * value reads as non-negative. */
+    size_t nbits = mp_get_nbits(x);
+    size_t remaining = (nbits + 8) / 8;
+
+    put_uint32(bs, remaining);
+
+    while (remaining > 0) {
+        remaining--;
+        put_byte(bs, mp_get_byte(x, remaining));
+    }
+}
+
+mp_int *BinarySource_get_mp_ssh1(BinarySource *src)
+{
+    unsigned declared_bits = get_uint16(src);
+    ptrlen payload = get_data(src, (declared_bits + 7) / 8);
+
+    /* On any read error, hand back a valid zero rather than NULL. */
+    if (get_err(src))
+        return mp_from_integer(0);
+
+    mp_int *value = mp_from_bytes_be(payload);
+
+    /* SSH-1.5 spec says that it's OK for the prefix uint16 to be
+     * _greater_ than the actual number of bits, so only the opposite
+     * case is rejected. */
+    if (mp_get_nbits(value) <= declared_bits)
+        return value;
+
+    src->err = BSE_INVALID;
+    mp_free(value);
+    return mp_from_integer(0);
+}
+
+mp_int *BinarySource_get_mp_ssh2(BinarySource *src)
+{
+    ptrlen bytes = get_string(src);
+
+    /* On any read error, hand back a valid zero rather than NULL. */
+    if (get_err(src))
+        return mp_from_integer(0);
+
+    const unsigned char *p = bytes.ptr;
+    int invalid = 0;
+
+    if (bytes.len > 0) {
+        if (p[0] & 0x80) {
+            /* Top bit set: the encoding denotes a negative number. */
+            invalid = 1;
+        } else if (p[0] == 0) {
+            /* A leading zero byte is only legitimate when it is
+             * needed to keep the following byte's top bit clear. */
+            if (bytes.len <= 1 || !(p[1] & 0x80))
+                invalid = 1;
+        }
+    }
+
+    if (invalid) {
+        src->err = BSE_INVALID;
+        return mp_from_integer(0);
+    }
+
+    return mp_from_bytes_be(bytes);
+}
+
+/*
+ * Make an mp_int structure whose words array aliases a subinterval of
+ * some other mp_int. This makes it easy to read or write just the low
+ * or high words of a number, e.g. to add a number starting from a
+ * high bit position, or to reduce mod 2^{n*BIGNUM_INT_BITS}.
+ *
+ * The convention throughout this code is that when we store an mp_int
+ * directly by value, we always expect it to be an alias of some kind,
+ * so its words array won't ever need freeing. Whereas an 'mp_int *'
+ * has an owner, who knows whether it needs freeing or whether it was
+ * created by address-taking an alias.
+ */
+static mp_int mp_make_alias(mp_int *in, size_t offset, size_t len)
+{
+    /*
+     * Clamp the requested window to what 'in' actually contains, so
+     * the alias is always valid even if it ends up shorter than the
+     * caller asked for.
+     */
+    size_t start = offset > in->nw ? in->nw : offset;
+    size_t available = in->nw - start;
+
+    mp_int alias;
+    alias.nw = len > available ? available : len;
+    alias.w = in->w + start;
+    return alias;
+}
+
+/*
+ * A special case of mp_make_alias: in some cases we preallocate a
+ * large mp_int to use as scratch space (to avoid pointless
+ * malloc/free churn in recursive or iterative work).
+ *
+ * mp_alloc_from_scratch creates an alias of size 'len' to part of
+ * 'pool', and adjusts 'pool' itself so that further allocations won't
+ * overwrite that space.
+ *
+ * There's no free function to go with this. Typically you just copy
+ * the pool mp_int by value, allocate from the copy, and when you're
+ * done with those allocations, throw the copy away and go back to the
+ * original value of pool. (A mark/release system.)
+ */
+static mp_int mp_alloc_from_scratch(mp_int *pool, size_t len)
+{
+    assert(len <= pool->nw);
+
+    /* Carve the first 'len' words off the pool... */
+    mp_int taken = mp_make_alias(pool, 0, len);
+
+    /* ...and shrink the pool to whatever follows them. */
+    mp_int remainder = mp_make_alias(pool, len, pool->nw);
+    *pool = remainder;
+
+    return taken;
+}
+
+/*
+ * Internal component common to lots of assorted add/subtract code.
+ * Reads words from a,b; writes into w_out (which might be NULL if the
+ * output isn't even needed). Takes an input carry flag in 'carry',
+ * and returns the output carry. Each word read from b is ANDed with
+ * b_and and then XORed with b_xor.
+ *
+ * So you can implement addition by setting b_and to all 1s and b_xor
+ * to 0; you can subtract by making b_xor all 1s too (effectively
+ * bit-flipping b) and also passing 1 as the input carry (to turn
+ * one's complement into two's complement). And you can do conditional
+ * add/subtract by choosing b_and to be all 1s or all 0s based on a
+ * condition, because the value of b will be totally ignored if b_and
+ * == 0.
+ */
+static BignumCarry mp_add_masked_into(
+    BignumInt *w_out, size_t rw, mp_int *a, mp_int *b,
+    BignumInt b_and, BignumInt b_xor, BignumCarry carry)
+{
+    size_t idx = 0;
+    while (idx < rw) {
+        BignumInt lhs = mp_word(a, idx);
+        /* Apply the AND mask first, then the XOR mask. */
+        BignumInt rhs = (mp_word(b, idx) & b_and) ^ b_xor;
+        BignumInt sum;
+        BignumADC(sum, carry, lhs, rhs, carry);
+        /* A NULL output means the caller only wants the carry. */
+        if (w_out)
+            w_out[idx] = sum;
+        idx++;
+    }
+    return carry;
+}
+
+/*
+ * Like the public mp_add_into except that it returns the output carry.
+ */
+static inline BignumCarry mp_add_into_internal(mp_int *r, mp_int *a, mp_int *b)
+{
+    /* Plain addition: keep every bit of b, flip none, no carry in. */
+    const BignumInt keep_all = ~(BignumInt)0;
+    const BignumInt flip_none = 0;
+    return mp_add_masked_into(r->w, r->nw, a, b, keep_all, flip_none, 0);
+}
+
+void mp_add_into(mp_int *r, mp_int *a, mp_int *b)
+{
+    /* Public form of the addition: the carry out of the top word of
+     * r is deliberately dropped. */
+    (void)mp_add_into_internal(r, a, b);
+}
+
+void mp_sub_into(mp_int *r, mp_int *a, mp_int *b)
+{
+    /* a - b computed as a + ~b + 1: keep all bits of b, flip all of
+     * them, and feed in a carry of 1. */
+    const BignumInt all_ones = ~(BignumInt)0;
+    mp_add_masked_into(r->w, r->nw, a, b, all_ones, all_ones, 1);
+}
+
+static void mp_cond_negate(mp_int *r, mp_int *x, unsigned yes)
+{
+    /*
+     * Two's complement negation is 'invert every bit, then add 1'.
+     * Making both the inversion mask and the initial carry depend on
+     * 'yes' turns this into a plain copy when yes == 0.
+     */
+    BignumInt invert_mask = -(BignumInt)yes;
+    BignumCarry c = yes;
+
+    for (size_t idx = 0; idx < r->nw; idx++) {
+        BignumInt flipped = mp_word(x, idx);
+        flipped ^= invert_mask;
+        BignumADC(r->w[idx], c, 0, flipped, c);
+    }
+}
+
+/*
+ * Similar to mp_add_masked_into, but takes a C integer instead of an
+ * mp_int as the masked operand.
+ */
+static BignumCarry mp_add_masked_integer_into(
+    BignumInt *w_out, size_t rw, mp_int *a, uintmax_t b,
+    BignumInt b_and, BignumInt b_xor, BignumCarry carry)
+{
+    /*
+     * Computes a + ((b ^ b_xor) & b_and) + carry over rw words,
+     * writing the words to w_out (if non-NULL) and returning the
+     * carry out of the top word.
+     *
+     * The words of b are peeled off from the bottom, one per
+     * iteration, by shifting b itself down. The earlier approach of
+     * computing 'b >> (i * BIGNUM_INT_BITS)' guarded the shift with a
+     * comparison against BIGNUM_INT_BYTES, i.e. compared a bit count
+     * with a byte count, which discarded every word of b above the
+     * lowest whenever BignumInt was narrower than uintmax_t.
+     */
+    for (size_t i = 0; i < rw; i++) {
+        BignumInt aword = mp_word(a, i);
+        BignumInt bword = (BignumInt)b;
+        BignumInt out;
+
+        /* Shift down by one whole word in two half-word steps, so
+         * that no single shift count can equal the width of
+         * uintmax_t even when BignumInt is just as wide. */
+        b = (b >> (BIGNUM_INT_BITS / 2)) >> (BIGNUM_INT_BITS / 2);
+
+        bword = (bword ^ b_xor) & b_and;
+        BignumADC(out, carry, aword, bword, carry);
+        if (w_out)
+            w_out[i] = out;
+    }
+    return carry;
+}
+
+void mp_add_integer_into(mp_int *r, mp_int *a, uintmax_t n)
+{
+    /* Plain addition of n: AND mask all ones, no bits flipped, no
+     * carry in. The carry out of r is discarded. */
+    const BignumInt keep_all = ~(BignumInt)0;
+    const BignumInt flip_none = 0;
+    mp_add_masked_integer_into(r->w, r->nw, a, n, keep_all, flip_none, 0);
+}
+
+void mp_sub_integer_into(mp_int *r, mp_int *a, uintmax_t n)
+{
+    /* a - n computed as a + ~n + 1: every bit of n is flipped and a
+     * carry of 1 is fed in. The carry out of r is discarded. */
+    const BignumInt all_ones = ~(BignumInt)0;
+    mp_add_masked_integer_into(r->w, r->nw, a, n, all_ones, all_ones, 1);
+}
+
+/*
+ * Sets r to a + n << (word_index * BIGNUM_INT_BITS), treating
+ * word_index as secret data.
+ */
+static void mp_add_integer_into_shifted_by_words(
+    mp_int *r, mp_int *a, uintmax_t n, size_t word_index)
+{
+    /*
+     * Every word of r is visited and the same operations are done at
+     * each one, whatever word_index is; only masks differ. The carry
+     * out of the top word of r is discarded.
+     */
+    unsigned indicator = 0;
+    BignumCarry carry = 0;
+
+    for (size_t i = 0; i < r->nw; i++) {
+        /* indicator becomes 1 when we reach the index that the least
+         * significant bits of n want to be placed at, and it stays 1
+         * thereafter. */
+        indicator |= 1 ^ normalise_to_1(i ^ word_index);
+
+        /* If indicator is 1, we add the low bits of n into r, and
+         * shift n down. If it's 0, we add zero bits into r, and
+         * leave n alone. */
+        BignumInt bword = n & -(BignumInt)indicator;
+        /* The conditional avoids a shift by the full width of n when
+         * BignumInt is 64 bits wide.
+         * NOTE(review): this presumes uintmax_t is exactly 64 bits --
+         * confirm for the supported targets. */
+        uintmax_t new_n = (BIGNUM_INT_BITS < 64 ? n >> BIGNUM_INT_BITS : 0);
+        /* Branch-free select: n becomes new_n only if indicator is 1. */
+        n ^= (n ^ new_n) & -(uintmax_t)indicator;
+
+        BignumInt aword = mp_word(a, i);
+        BignumInt out;
+        BignumADC(out, carry, aword, bword, carry);
+        r->w[i] = out;
+    }
+}
+
+void mp_mul_integer_into(mp_int *r, mp_int *a, uint16_t n)
+{
+    /* Schoolbook multiply by a single small factor: each step
+     * multiplies one word of a and adds the overflow of the previous
+     * step. */
+    BignumInt factor = n;
+    BignumInt overflow = 0;
+
+    size_t idx = 0;
+    while (idx < r->nw) {
+        BignumInt aword = mp_word(a, idx);
+        BignumMULADD(overflow, r->w[idx], aword, factor, overflow);
+        idx++;
+    }
+
+    /* The caller must have made r large enough for the product. */
+    assert(!overflow);
+}
+
+void mp_cond_add_into(mp_int *r, mp_int *a, mp_int *b, unsigned yes)
+{
+    /* All ones if the low bit of 'yes' is set, else zero; with a
+     * zero mask b contributes nothing and r is just a copy of a. */
+    BignumInt select = -(BignumInt)(yes & 1);
+    const BignumInt flip_none = 0;
+    mp_add_masked_into(r->w, r->nw, a, b, select, flip_none, 0);
+}
+
+void mp_cond_sub_into(mp_int *r, mp_int *a, mp_int *b, unsigned yes)
+{
+    /* When selected, b is kept, inverted and a carry of 1 is added
+     * (a + ~b + 1); otherwise all three collapse to zero and r is
+     * just a copy of a. */
+    BignumInt select = -(BignumInt)(yes & 1);
+    BignumCarry carry_in = 1 & select;
+    mp_add_masked_into(r->w, r->nw, a, b, select, select, carry_in);
+}
+
+/*
+ * Ordered comparison between unsigned numbers is done by subtracting
+ * one from the other and looking at the output carry.
+ */
+unsigned mp_cmp_hs(mp_int *a, mp_int *b)
+{
+    /* Compute a - b over the wider of the two lengths, throwing the
+     * difference away (NULL output): the carry out is 1 exactly when
+     * no borrow occurred, i.e. when a >= b. */
+    const BignumInt all_ones = ~(BignumInt)0;
+    size_t width = size_t_max(a->nw, b->nw);
+    BignumCarry no_borrow =
+        mp_add_masked_into(NULL, width, a, b, all_ones, all_ones, 1);
+    return no_borrow;
+}
+
+unsigned mp_hs_integer(mp_int *x, uintmax_t n)
+{
+    /*
+     * Returns 1 if x >= n, else 0, by computing x - n (as x + ~n + 1)
+     * and keeping only the final carry.
+     *
+     * Two things are needed to get this right for every n:
+     *
+     *  - the words of n are taken by repeatedly shifting n down one
+     *    word, rather than by 'n >> (i * BIGNUM_INT_BITS)' guarded by
+     *    a comparison against BIGNUM_INT_BYTES, which mixed up bits
+     *    and bytes and lost the upper words of n whenever BignumInt
+     *    is narrower than uintmax_t;
+     *
+     *  - the loop covers at least as many words as n occupies, so
+     *    that a short x is not wrongly reported as >= a large n.
+     *    Words of x beyond its length read as zero via mp_word.
+     */
+    BignumInt carry = 1;
+    size_t n_words = sizeof(n) / BIGNUM_INT_BYTES;
+    size_t limit = size_t_max(x->nw, n_words);
+
+    for (size_t i = 0; i < limit; i++) {
+        BignumInt xword = mp_word(x, i);
+        BignumInt nword = (BignumInt)n;
+        BignumInt dummy_out;
+
+        /* One-word shift in two halves, so the shift count is always
+         * smaller than the width of uintmax_t. */
+        n = (n >> (BIGNUM_INT_BITS / 2)) >> (BIGNUM_INT_BITS / 2);
+
+        BignumADC(dummy_out, carry, xword, ~nword, carry);
+        (void)dummy_out;
+    }
+    return carry;
+}
+
+/*
+ * Equality comparison is done by bitwise XOR of the input numbers,
+ * ORing together all the output words, and normalising the result
+ * using our careful normalise_to_1 helper function.
+ */
+unsigned mp_cmp_eq(mp_int *a, mp_int *b)
+{
+    /* Accumulate every differing bit across the longer of the two
+     * lengths; the shorter operand reads as zero beyond its end. */
+    size_t width = size_t_max(a->nw, b->nw);
+    BignumInt accumulated = 0;
+
+    for (size_t idx = 0; idx < width; idx++) {
+        BignumInt aword = mp_word(a, idx);
+        BignumInt bword = mp_word(b, idx);
+        accumulated |= aword ^ bword;
+    }
+
+    /* Equal means no bit differed anywhere. */
+    unsigned any_difference = normalise_to_1(accumulated);
+    return 1 ^ any_difference;
+}
+
+unsigned mp_eq_integer(mp_int *x, uintmax_t n)
+{
+    /*
+     * Returns 1 if x == n, else 0.
+     *
+     * The words of n are taken by repeatedly shifting n down one
+     * word. (Computing 'n >> (i * BIGNUM_INT_BITS)' under a guard
+     * that compared the bit count with BIGNUM_INT_BYTES lost the
+     * upper words of n whenever BignumInt is narrower than
+     * uintmax_t.) The loop also covers every word that n can
+     * occupy, so a short x cannot compare equal to a larger n that
+     * merely shares its low word; words of x beyond its length read
+     * as zero via mp_word.
+     */
+    BignumInt diff = 0;
+    size_t n_words = sizeof(n) / BIGNUM_INT_BYTES;
+
+    for (size_t i = 0, limit = size_t_max(x->nw, n_words); i < limit; i++) {
+        BignumInt nword = (BignumInt)n;
+
+        /* One-word shift in two halves, so the shift count is always
+         * smaller than the width of uintmax_t. */
+        n = (n >> (BIGNUM_INT_BITS / 2)) >> (BIGNUM_INT_BITS / 2);
+
+        diff |= mp_word(x, i) ^ nword;
+    }
+    return 1 ^ normalise_to_1(diff);   /* return 1 if diff _is_ zero */
+}
+
+void mp_neg_into(mp_int *r, mp_int *a)
+{
+    /* Negation is 0 - a. A zero-length alias stands in for the
+     * constant zero; only its length is set, since a number with no
+     * words has nothing to read. */
+    mp_int nothing;
+    nothing.nw = 0;
+    mp_sub_into(r, &nothing, a);
+}
+
+mp_int *mp_add(mp_int *x, mp_int *y)
+{
+    /* One word more than the longer input, to hold a possible carry. */
+    size_t width = size_t_max(x->nw, y->nw) + 1;
+    mp_int *sum = mp_make_sized(width);
+    mp_add_into(sum, x, y);
+    return sum;
+}
+
+mp_int *mp_sub(mp_int *x, mp_int *y)
+{
+    /* The result is as long as the longer input; any borrow out of
+     * the top word is lost, so the difference wraps at that size. */
+    size_t width = size_t_max(x->nw, y->nw);
+    mp_int *difference = mp_make_sized(width);
+    mp_sub_into(difference, x, y);
+    return difference;
+}
+
+mp_int *mp_neg(mp_int *a)
+{
+    /* Two's complement negation at the same word length as a. */
+    mp_int *negated = mp_make_sized(a->nw);
+    mp_neg_into(negated, a);
+    return negated;
+}
+
+/*
+ * Internal routine: multiply and accumulate in the trivial O(N^2)
+ * way. Sets r <- r + a*b.
+ */
+static void mp_mul_add_simple(mp_int *r, mp_int *a, mp_int *b)
+{
+    /*
+     * For each word a[i], add a[i]*b into r starting at word i.
+     * Everything is truncated at the end of r, so r may be shorter
+     * than the full product.
+     */
+    for (size_t i = 0; i < a->nw && i < r->nw; i++) {
+        BignumInt adata = a->w[i], carry = 0;
+        size_t k = i;               /* current output position in r */
+
+        /* Multiply-accumulate a[i] against each word of b. */
+        for (size_t j = 0; j < b->nw && k < r->nw; j++, k++) {
+            BignumInt bdata = b->w[j];
+            BignumMULADD2(carry, r->w[k], adata, bdata, r->w[k], carry);
+        }
+
+        /* Ripple whatever carry is left through the rest of r. */
+        for (; k < r->nw; k++)
+            BignumADC(r->w[k], carry, 0, r->w[k], carry);
+    }
+}
+
+/*
+ * Operand length, in words, below which mp_mul_internal falls back to
+ * the simple O(N^2) multiply instead of recursing via Karatsuba.
+ */
+#ifndef KARATSUBA_THRESHOLD      /* allow redefinition via -D for testing */
+#define KARATSUBA_THRESHOLD 50
+#endif
+
+static inline size_t mp_mul_scratchspace_unary(size_t n)
+{
+    /*
+     * Simplistic and overcautious bound on the amount of scratch
+     * space that the recursive multiply function will need for
+     * inputs of n words.
+     *
+     * The rationale is: on the main Karatsuba branch of
+     * mp_mul_internal, which is the most space-intensive one, we
+     * allocate space for (a0+a1) and (b0+b1) (each just over half the
+     * input length n) and their product (the sum of those sizes, i.e.
+     * just over n itself). Then in order to actually compute the
+     * product, we do a recursive multiplication of size just over n.
+     *
+     * If all those 'just over' weren't there, and everything was
+     * _exactly_ half the length, you'd get the amount of space for a
+     * size-n multiply defined by the recurrence M(n) = 2n + M(n/2),
+     * which is satisfied by M(n) = 4n. But instead it's (2n plus a
+     * word or two) and M(n/2 plus a word or two). On the assumption
+     * that there's still some constant k such that M(n) <= kn, this
+     * gives us kn = 2n + w + k(n/2 + w), where w is a small constant
+     * (one or two words). That simplifies to kn/2 = 2n + (k+1)w, and
+     * since we don't even _start_ needing scratch space until n is at
+     * least 50, we can bound 2n + (k+1)w above by 3n, giving k=6.
+     *
+     * So the claim is that 6n words of scratch space will suffice,
+     * and that is checked by assertion at every stage of the
+     * recursion.
+     */
+    const size_t scratch_words_per_input_word = 6;
+    return n * scratch_words_per_input_word;
+}
+
+static size_t mp_mul_scratchspace(size_t rw, size_t aw, size_t bw)
+{
+    /* The multiply never works on more words than the longer input,
+     * nor on more than the output can hold. */
+    size_t longest_input = size_t_max(aw, bw);
+    size_t effective_len = size_t_min(rw, longest_input);
+    return mp_mul_scratchspace_unary(effective_len);
+}
+
+/*
+ * Core multiply: sets r to the low r->nw words of a*b. 'scratch' is
+ * passed by value, so allocations made from it here are automatically
+ * released when this call returns; it must provide at least
+ * mp_mul_scratchspace_unary(inlen) words, which is asserted.
+ */
+static void mp_mul_internal(mp_int *r, mp_int *a, mp_int *b, mp_int scratch)
+{
+    size_t inlen = size_t_min(r->nw, size_t_max(a->nw, b->nw));
+    assert(scratch.nw >= mp_mul_scratchspace_unary(inlen));
+
+    /* The simple multiply below accumulates into r, so start from 0. */
+    mp_clear(r);
+
+    if (inlen < KARATSUBA_THRESHOLD || a->nw == 0 || b->nw == 0) {
+        /*
+         * The input numbers are too small to bother optimising. Go
+         * straight to the simple primitive approach.
+         */
+        mp_mul_add_simple(r, a, b);
+        return;
+    }
+
+    /*
+     * Karatsuba divide-and-conquer algorithm. We cut each input in
+     * half, so that it's expressed as two big 'digits' in a giant
+     * base D:
+     *
+     *   a = a_1 D + a_0
+     *   b = b_1 D + b_0
+     *
+     * Then the product is of course
+     *
+     *   ab = a_1 b_1 D^2 + (a_1 b_0 + a_0 b_1) D + a_0 b_0
+     *
+     * and we compute the three coefficients by recursively calling
+     * ourself to do half-length multiplications.
+     *
+     * The clever bit that makes this worth doing is that we only need
+     * _one_ half-length multiplication for the central coefficient
+     * rather than the two that it obviously looks like, because we can
+     * use a single multiplication to compute
+     *
+     *   (a_1 + a_0) (b_1 + b_0) = a_1 b_1 + a_1 b_0 + a_0 b_1 + a_0 b_0
+     *
+     * and then we subtract the other two coefficients (a_1 b_1 and
+     * a_0 b_0) which we were computing anyway.
+     *
+     * Hence we get to multiply two numbers of length N in about three
+     * times as much work as it takes to multiply numbers of length
+     * N/2, which is obviously better than the four times as much work
+     * it would take if we just did a long conventional multiply.
+     */
+
+    /* Break up the input as botlen + toplen, with botlen >= toplen.
+     * The 'base' D is equal to 2^{botlen * BIGNUM_INT_BITS}. */
+    size_t toplen = inlen / 2;
+    size_t botlen = inlen - toplen;
+
+    /* Alias bignums that address the two halves of a,b, and useful
+     * pieces of r. (mp_make_alias clamps each window to the words
+     * that really exist, so over-long requests are harmless.) */
+    mp_int a0 = mp_make_alias(a, 0, botlen);
+    mp_int b0 = mp_make_alias(b, 0, botlen);
+    mp_int a1 = mp_make_alias(a, botlen, toplen);
+    mp_int b1 = mp_make_alias(b, botlen, toplen);
+    mp_int r0 = mp_make_alias(r, 0, botlen*2);
+    mp_int r1 = mp_make_alias(r, botlen, r->nw);
+    mp_int r2 = mp_make_alias(r, botlen*2, r->nw);
+
+    /* Recurse to compute a0*b0 and a1*b1, in their correct positions
+     * in the output bignum. They can't overlap. */
+    mp_mul_internal(&r0, &a0, &b0, scratch);
+    mp_mul_internal(&r2, &a1, &b1, scratch);
+
+    if (r->nw < inlen*2) {
+        /*
+         * The output buffer isn't large enough to require the whole
+         * product, so some of a1*b1 won't have been stored. In that
+         * case we won't try to do the full Karatsuba optimisation;
+         * we'll just recurse again to compute a0*b1 and a1*b0 - or at
+         * least as much of them as the output buffer size requires -
+         * and add each one in.
+         */
+        mp_int s = mp_alloc_from_scratch(
+            &scratch, size_t_min(botlen+toplen, r1.nw));
+
+        mp_mul_internal(&s, &a0, &b1, scratch);
+        mp_add_into(&r1, &r1, &s);
+        mp_mul_internal(&s, &a1, &b0, scratch);
+        mp_add_into(&r1, &r1, &s);
+        return;
+    }
+
+    /* a0+a1 and b0+b1 (one extra word each, for the carry) */
+    mp_int asum = mp_alloc_from_scratch(&scratch, botlen+1);
+    mp_int bsum = mp_alloc_from_scratch(&scratch, botlen+1);
+    mp_add_into(&asum, &a0, &a1);
+    mp_add_into(&bsum, &b0, &b1);
+
+    /* Their product */
+    mp_int product = mp_alloc_from_scratch(&scratch, botlen*2+1);
+    mp_mul_internal(&product, &asum, &bsum, scratch);
+
+    /* Subtract off the outer terms we already have */
+    mp_sub_into(&product, &product, &r0);
+    mp_sub_into(&product, &product, &r2);
+
+    /* And add it in with the right offset. */
+    mp_add_into(&r1, &r1, &product);
+}
+
+void mp_mul_into(mp_int *r, mp_int *a, mp_int *b)
+{
+    /* Allocate exactly the workspace this multiply can need, run the
+     * recursive routine, then release the workspace again. */
+    size_t scratch_words = mp_mul_scratchspace(r->nw, a->nw, b->nw);
+    mp_int *workspace = mp_make_sized(scratch_words);
+    mp_mul_internal(r, a, b, *workspace);
+    mp_free(workspace);
+}
+
+mp_int *mp_mul(mp_int *x, mp_int *y)
+{
+    /* A product never needs more words than both inputs together. */
+    size_t width = x->nw + y->nw;
+    mp_int *product = mp_make_sized(width);
+    mp_mul_into(product, x, y);
+    return product;
+}
+
+void mp_lshift_fixed_into(mp_int *r, mp_int *a, size_t bits)
+{
+    /* Split the shift into whole words plus a sub-word remainder. */
+    size_t wordshift = bits / BIGNUM_INT_BITS;
+    size_t bitshift = bits % BIGNUM_INT_BITS;
+
+    for (size_t i = 0; i < r->nw; i++) {
+        BignumInt *dst = &r->w[i];
+
+        /* Words below the shift distance are vacated. */
+        if (i < wordshift) {
+            *dst = 0;
+            continue;
+        }
+
+        *dst = mp_word(a, i - wordshift);
+        if (bitshift == 0)
+            continue;
+
+        /* Shift up, and bring in the bits that fall off the top of
+         * the next word down (if there is one). */
+        *dst <<= bitshift;
+        if (i > wordshift)
+            *dst |= mp_word(a, i - wordshift - 1) >>
+                (BIGNUM_INT_BITS - bitshift);
+    }
+}
+
+void mp_rshift_fixed_into(mp_int *r, mp_int *a, size_t bits)
+{
+    /* Split the shift into whole words plus a sub-word remainder. */
+    size_t wordshift = bits / BIGNUM_INT_BITS;
+    size_t bitshift = bits % BIGNUM_INT_BITS;
+
+    for (size_t i = 0; i < r->nw; i++) {
+        BignumInt *dst = &r->w[i];
+
+        *dst = mp_word(a, i + wordshift);
+        if (bitshift == 0)
+            continue;
+
+        /* Shift down, and bring in the low bits of the next word up
+         * (mp_word supplies zero beyond the end of a). */
+        *dst >>= bitshift;
+        *dst |= mp_word(a, i + wordshift + 1) << (BIGNUM_INT_BITS - bitshift);
+    }
+}
+
+/*
+ * Returns a newly allocated mp_int holding x >> bits. 'bits' is
+ * treated as public: it determines the size of the result.
+ */
+mp_int *mp_rshift_fixed(mp_int *x, size_t bits)
+{
+    size_t words = bits / BIGNUM_INT_BITS;
+
+    /*
+     * Drop the words that are shifted out entirely, but never go
+     * below one word: if the shift swallows all of x, the result is
+     * a one-word zero rather than a zero-length number, which other
+     * routines (anything that touches w[0] or requires nw > 0) are
+     * not prepared to receive.
+     */
+    size_t nw = x->nw - size_t_min(x->nw, words);
+    mp_int *r = mp_make_sized(size_t_max(nw, 1));
+
+    mp_rshift_fixed_into(r, x, bits);
+    return r;
+}
+
+/*
+ * Safe right shift is done using the same technique as
+ * trim_leading_zeroes above: you make an n-word right shift by
+ * composing an appropriate subset of power-of-2-sized shifts, so it
+ * takes log_2(n) loop iterations each of which does a different shift
+ * by a power of 2 words, using the usual bit twiddling to make the
+ * whole shift conditional on the appropriate bit of n.
+ *
+ * Returns a newly allocated mp_int of the same length as x.
+ */
+mp_int *mp_rshift_safe(mp_int *x, size_t bits)
+{
+    size_t wordshift = bits / BIGNUM_INT_BITS;
+    size_t bitshift = bits % BIGNUM_INT_BITS;
+
+    mp_int *r = mp_copy(x);
+
+    /* If the word shift exceeds the length of the number, the
+     * subtraction wraps and sets the top bit, so 'clear' becomes 1
+     * and the whole value is zeroed up front. */
+    unsigned clear = (r->nw - wordshift) >> (CHAR_BIT * sizeof(size_t) - 1);
+    mp_cond_clear(r, clear);
+
+    for (unsigned bit = 0; r->nw >> bit; bit++) {
+        /* Do the shift in size_t: a plain '1 << bit' is an int shift
+         * and is undefined once bit reaches the width of int. */
+        size_t word_offset = (size_t)1 << bit;
+        BignumInt mask = -(BignumInt)((wordshift >> bit) & 1);
+        for (size_t i = 0; i < r->nw; i++) {
+            BignumInt w = mp_word(r, i + word_offset);
+            r->w[i] ^= (r->w[i] ^ w) & mask;
+        }
+    }
+
+    /*
+     * That's done the shifting by words; now we do the shifting by
+     * bits.
+     *
+     * I assume here that register-controlled right shifts are
+     * time-constant. If they're not, I could replace this with
+     * another loop over bit positions.
+     */
+    size_t upshift = BIGNUM_INT_BITS - bitshift;
+    /* no_shift is 1 exactly when bitshift == 0, in which case the
+     * up-shift count would equal the word width: force it to 0 and
+     * mask the contribution of the next word out altogether. */
+    size_t no_shift = (upshift >> BIGNUM_INT_BITS_BITS);
+    upshift &= ~-(size_t)no_shift;
+    BignumInt upshifted_mask = ~-(BignumInt)no_shift;
+
+    for (size_t i = 0; i < r->nw; i++) {
+        r->w[i] = (r->w[i] >> bitshift) |
+            ((mp_word(r, i+1) << upshift) & upshifted_mask);
+    }
+
+    return r;
+}
+
+/*
+ * Reduces x mod 2^p in place, i.e. clears every bit of x from bit
+ * position p upwards. If p is beyond the end of x, nothing changes.
+ */
+void mp_reduce_mod_2to(mp_int *x, size_t p)
+{
+    size_t word = p / BIGNUM_INT_BITS;
+
+    /* Mask for the word containing bit p: keeps the bits below it. */
+    BignumInt mask = ((BignumInt)1 << (p % BIGNUM_INT_BITS)) - 1;
+
+    for (; word < x->nw; word++) {
+        x->w[word] &= mask;
+        /* Every word after the partial one lies wholly above bit p
+         * and must be cleared completely; an all-ones mask here
+         * would leave those high words in place. */
+        mask = 0;
+    }
+}
+
+/*
+ * Inverse mod 2^n is computed by an iterative technique which doubles
+ * the number of bits at each step.
+ */
+mp_int *mp_invert_mod_2to(mp_int *x, size_t p)
+{
+    /* Input checks: x must be coprime to the modulus, i.e. odd, and p
+     * can't be zero */
+    assert(x->nw > 0);
+    assert(x->w[0] & 1);
+    assert(p > 0);
+
+    /* The result is a newly allocated number just big enough to hold
+     * p bits. */
+    size_t rw = (p + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
+    mp_int *r = mp_make_sized(rw);
+
+    /* One scratch buffer serves the whole computation: the first
+     * mul_scratchsize words are reserved for the multiply routine,
+     * and the remainder (scratch_per_iter) is re-carved from the
+     * beginning on every loop iteration. */
+    size_t mul_scratchsize = mp_mul_scratchspace(2*rw, rw, rw);
+    mp_int *scratch_orig = mp_make_sized(6 * rw + mul_scratchsize);
+    mp_int scratch_per_iter = *scratch_orig;
+    mp_int mul_scratch = mp_alloc_from_scratch(
+        &scratch_per_iter, mul_scratchsize);
+
+    /* The inverse of any odd number mod 2^1 is 1. */
+    r->w[0] = 1;
+
+    for (size_t b = 1; b < p; b <<= 1) {
+        /*
+         * In each step of this iteration, we have the inverse of x
+         * mod 2^b, and we want the inverse of x mod 2^{2b}.
+         *
+         * Write B = 2^b for convenience, so we want x^{-1} mod B^2.
+         * Let x = x_0 + B x_1 + k B^2, with 0 <= x_0,x_1 < B.
+         *
+         * We want to find r_0 and r_1 such that
+         *    (r_1 B + r_0) (x_1 B + x_0) == 1 (mod B^2)
+         *
+         * To begin with, we know r_0 must be the inverse mod B of
+         * x_0, i.e. of x, i.e. it is the inverse we computed in the
+         * previous iteration. So now all we need is r_1.
+         *
+         * Multiplying out, neglecting multiples of B^2, and writing
+         * x_0 r_0 = K B + 1, we have
+         *
+         *    r_1 x_0 B + r_0 x_1 B + K B == 0                    (mod B^2)
+         * =>                   r_1 x_0 B == - r_0 x_1 B - K B    (mod B^2)
+         * =>                     r_1 x_0 == - r_0 x_1 - K        (mod B)
+         * =>                         r_1 == r_0 (- r_0 x_1 - K)  (mod B)
+         *
+         * (the last step because we multiply through by the inverse
+         * of x_0, which we already know is r_0).
+         */
+
+        /* Fresh copy of the pool: everything allocated below is
+         * implicitly released at the end of this iteration. */
+        mp_int scratch_this_iter = scratch_per_iter;
+        /* Word counts needed to hold b bits and 2b bits. */
+        size_t Bw = (b + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
+        size_t B2w = (2*b + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
+
+        /* Start by finding K: multiply x_0 by r_0, and shift down. */
+        mp_int x0 = mp_alloc_from_scratch(&scratch_this_iter, Bw);
+        mp_copy_into(&x0, x);
+        mp_reduce_mod_2to(&x0, b);
+        mp_int r0 = mp_make_alias(r, 0, Bw);
+        mp_int Kshift = mp_alloc_from_scratch(&scratch_this_iter, B2w);
+        mp_mul_internal(&Kshift, &x0, &r0, mul_scratch);
+        mp_int K = mp_alloc_from_scratch(&scratch_this_iter, Bw);
+        mp_rshift_fixed_into(&K, &Kshift, b);
+
+        /* Now compute the product r_0 x_1, reusing the space of Kshift. */
+        mp_int x1 = mp_alloc_from_scratch(&scratch_this_iter, Bw);
+        mp_rshift_fixed_into(&x1, x, b);
+        mp_reduce_mod_2to(&x1, b);
+        mp_int r0x1 = mp_make_alias(&Kshift, 0, Bw);
+        mp_mul_internal(&r0x1, &r0, &x1, mul_scratch);
+
+        /* Add K to that. */
+        mp_add_into(&r0x1, &r0x1, &K);
+
+        /* Negate it. */
+        mp_neg_into(&r0x1, &r0x1);
+
+        /* Multiply by r_0. */
+        mp_int r1 = mp_alloc_from_scratch(&scratch_this_iter, Bw);
+        mp_mul_internal(&r1, &r0, &r0x1, mul_scratch);
+        mp_reduce_mod_2to(&r1, b);
+
+        /* That's our r_1, so add it on to r_0 to get the full inverse
+         * output from this iteration. K is no longer needed, so its
+         * space is reused to hold r_1 shifted up to the right bit
+         * offset within the word where bit b lives. */
+        mp_lshift_fixed_into(&K, &r1, (b % BIGNUM_INT_BITS));
+        size_t Bpos = b / BIGNUM_INT_BITS;
+        mp_int r1_position = mp_make_alias(r, Bpos, B2w-Bpos);
+        mp_add_into(&r1_position, &r1_position, &K);
+    }
+
+    /* Finally, reduce mod the precise desired number of bits. */
+    mp_reduce_mod_2to(r, p);
+
+    mp_free(scratch_orig);
+    return r;
+}
+
+static size_t monty_scratch_size(MontyContext *mc)
+{
+    /* Fixed working space, plus whatever the multiply routine needs
+     * for a pw-word product of two rw-word inputs. */
+    size_t fixed_part = 3*mc->rw + mc->pw;
+    size_t multiply_part = mp_mul_scratchspace(mc->pw, mc->rw, mc->rw);
+    return fixed_part + multiply_part;
+}
+
+MontyContext *monty_new(mp_int *modulus)
+{
+    MontyContext *ctx = snew(MontyContext);
+
+    /* Sizes: r is 2^rbits, i.e. one bit more than fits in rw words;
+     * pw is the word count used for full products. */
+    ctx->rw = modulus->nw;
+    ctx->rbits = ctx->rw * BIGNUM_INT_BITS;
+    ctx->pw = ctx->rw * 2 + 1;
+
+    ctx->m = mp_copy(modulus);
+
+    /* -(m^{-1}) mod r: invert first, then negate in place. */
+    mp_int *neg_inverse = mp_invert_mod_2to(ctx->m, ctx->rbits);
+    mp_neg_into(neg_inverse, neg_inverse);
+    ctx->minus_minv_mod_r = neg_inverse;
+
+    /* Build r itself by setting the single bit just above rw words
+     * (this relies on mp_make_sized returning a zeroed number --
+     * TODO confirm), and reduce it mod m. */
+    mp_int *r_value = mp_make_sized(ctx->rw + 1);
+    r_value->w[ctx->rw] = 1;
+    ctx->powers_of_r_mod_m[0] = mp_mod(r_value, ctx->m);
+    mp_free(r_value);
+
+    /* Each further entry is the previous one times r, mod m. */
+    size_t npowers = lenof(ctx->powers_of_r_mod_m);
+    for (size_t k = 1; k < npowers; k++) {
+        mp_int *r_mod_m = ctx->powers_of_r_mod_m[0];
+        mp_int *previous = ctx->powers_of_r_mod_m[k-1];
+        ctx->powers_of_r_mod_m[k] = mp_modmul(r_mod_m, previous, ctx->m);
+    }
+
+    ctx->scratch = mp_make_sized(monty_scratch_size(ctx));
+
+    return ctx;
+}
+
+/*
+ * Makes an independent duplicate of a Montgomery context: every
+ * number it holds is copied, and it gets its own scratch space.
+ */
+MontyContext *monty_copy(MontyContext *orig)
+{
+    MontyContext *mc = snew(MontyContext);
+
+    mc->rw = orig->rw;
+    mc->pw = orig->pw;
+    mc->rbits = orig->rbits;
+    mc->m = mp_copy(orig->m);
+    mc->minus_minv_mod_r = mp_copy(orig->minus_minv_mod_r);
+
+    /* Take the number of stored powers from the array itself, as
+     * monty_new does, instead of repeating its length as a literal. */
+    for (size_t j = 0; j < lenof(mc->powers_of_r_mod_m); j++)
+        mc->powers_of_r_mod_m[j] = mp_copy(orig->powers_of_r_mod_m[j]);
+
+    /* Scratch contents are not state worth copying; allocate fresh. */
+    mc->scratch = mp_make_sized(monty_scratch_size(mc));
+    return mc;
+}
+
+/*
+ * Frees a Montgomery context together with every number it owns.
+ */
+void monty_free(MontyContext *mc)
+{
+    mp_free(mc->m);
+
+    /* Take the number of stored powers from the array itself, as
+     * monty_new does, instead of repeating its length as a literal. */
+    for (size_t j = 0; j < lenof(mc->powers_of_r_mod_m); j++)
+        mp_free(mc->powers_of_r_mod_m[j]);
+
+    mp_free(mc->minus_minv_mod_r);
+    mp_free(mc->scratch);
+
+    /* Wipe the structure before releasing it. */
+    smemclr(mc, sizeof(*mc));
+    sfree(mc);
+}
+
+/*
+ * The main Montgomery reduction step.
+ */
+static mp_int monty_reduce_internal(MontyContext *mc, mp_int *x, mp_int scratch)
+{
+    /*
+     * Returns x / r mod m. The returned mp_int is an alias into
+     * 'scratch', so the caller must copy it out before reusing or
+     * clearing that space.
+     *
+     * The trick with Montgomery reduction is that on the one hand we
+     * want to reduce the size of the input by a factor of about r,
+     * and on the other hand, the two numbers we just multiplied were
+     * both stored with an extra factor of r multiplied in. So we
+     * computed ar*br = ab r^2, but we want to return abr, so we need
+     * to divide by r - and if we can do that by _actually dividing_
+     * by r then this also reduces the size of the number.
+     *
+     * But we can only do that if the number we're dividing by r is a
+     * multiple of r. So first we must add an adjustment to it which
+     * clears its bottom 'rbits' bits. That adjustment must be a
+     * multiple of m in order to leave the residue mod m unchanged, so
+     * the question is, what multiple of m can we add to x to make it
+     * congruent to 0 mod r? And the answer is, x * (-m)^{-1} mod r.
+     */
+
+    /* x mod r: the low rw words of x. (The length is a word count;
+     * passing the bit count rbits only worked because mp_make_alias
+     * clamps over-long requests.) */
+    mp_int x_lo = mp_make_alias(x, 0, mc->rw);
+
+    /* x * (-m)^{-1}, i.e. the number we want to multiply by m. k has
+     * only rw words, so the product is implicitly reduced mod r. */
+    mp_int k = mp_alloc_from_scratch(&scratch, mc->rw);
+    mp_mul_internal(&k, &x_lo, mc->minus_minv_mod_r, scratch);
+
+    /* m times that, i.e. the number we want to add to x */
+    mp_int mk = mp_alloc_from_scratch(&scratch, mc->pw);
+    mp_mul_internal(&mk, mc->m, &k, scratch);
+
+    /* Add it to x */
+    mp_add_into(&mk, x, &mk);
+
+    /* Divide by r, by simply making an alias to the upper words of
+     * the sum (its low rw words are now all zero) */
+    mp_int toret = mp_make_alias(&mk, mc->rw, mk.nw - mc->rw);
+
+    /*
+     * We'll generally be doing this after a multiplication of two
+     * fully reduced values. So our input could be anything up to m^2,
+     * and then we added up to rm to it. Hence, the maximum value is
+     * rm+m^2, and after dividing by r, that becomes r + m(m/r) < 2r.
+     * So a single trial-subtraction will finish reducing to the
+     * interval [0,m).
+     */
+    mp_cond_sub_into(&toret, &toret, mc->m, mp_cmp_hs(&toret, mc->m));
+    return toret;
+}
+
+void monty_mul_into(MontyContext *mc, mp_int *r, mp_int *x, mp_int *y)
+{
+    /* Both inputs must fit in the modulus's word count. */
+    assert(x->nw <= mc->rw);
+    assert(y->nw <= mc->rw);
+
+    /* Work from a by-value copy of the context's scratch pool, so
+     * the context itself keeps the full pool. */
+    mp_int pool = *mc->scratch;
+
+    /* Full double-length product first, then Montgomery-reduce it. */
+    mp_int product = mp_alloc_from_scratch(&pool, 2*mc->rw);
+    mp_mul_into(&product, x, y);
+    mp_int result = monty_reduce_internal(mc, &product, pool);
+
+    /* The reduced value lives inside the scratch space: copy it out,
+     * then wipe the scratch. */
+    mp_copy_into(r, &result);
+    mp_clear(mc->scratch);
+}
+
+/*
+ * Allocating form of monty_mul_into: makes a new mp_int of mc->rw
+ * words, fills it with the Montgomery product of x and y, and
+ * returns it.
+ */
+mp_int *monty_mul(MontyContext *mc, mp_int *x, mp_int *y)
+{
+    mp_int *product = mp_make_sized(mc->rw);
+    monty_mul_into(mc, product, x, y);
+    return product;
+}
+
+/*
+ * Accessor for the modulus held in the context. The pointer returned
+ * is the context's own mc->m, not a copy.
+ */
+mp_int *monty_modulus(MontyContext *mc)
+{
+    mp_int *modulus = mc->m;
+    return modulus;
+}
+
+/*
+ * Accessor for the first entry of the context's powers_of_r_mod_m
+ * table, which monty_pow uses as the value 1 in Montgomery
+ * representation. The pointer returned is the context's own entry,
+ * not a copy.
+ */
+mp_int *monty_identity(MontyContext *mc)
+{
+    mp_int *one = mc->powers_of_r_mod_m[0];
+    return one;
+}
+
+/*
+ * Modular inverse of a number held in Montgomery representation,
+ * returned as a new mp_int, also in Montgomery representation.
+ */
+mp_int *monty_invert(MontyContext *mc, mp_int *x)
+{
+    /*
+     * The input is xr. Inverting that mod m yields (xr)^{-1}, but the
+     * value we have to hand back is x^{-1}r = (xr)^{-1} r^2. A
+     * Montgomery multiplication by powers_of_r_mod_m[2] does that:
+     * it computes monty_reduce((xr)^{-1} r^3).
+     */
+    mp_int *plain_inverse = mp_invert(x, mc->m);
+    mp_int *monty_inverse = monty_mul(
+        mc, plain_inverse, mc->powers_of_r_mod_m[2]);
+
+    mp_free(plain_inverse);
+    return monty_inverse;
+}
+
+/*
+ * Importing a number into Montgomery representation involves
+ * multiplying it by r and reducing mod m. We could do this using the
+ * straightforward mp_modmul, but since we have the machinery to avoid
+ * division, why don't we use it? If we multiply the number not by r
+ * itself, but by the residue of r^2 mod m, then we can do an actual
+ * Montgomery reduction to reduce the result and remove the extra
+ * factor of r.
+ */
+void monty_import_into(MontyContext *mc, mp_int *r, mp_int *x)
+{
+    /* Multiplying by this table entry and Montgomery-reducing is how
+     * the plain value x gets converted into Montgomery form. */
+    mp_int *conversion_factor = mc->powers_of_r_mod_m[1];
+    monty_mul_into(mc, r, x, conversion_factor);
+}
+
+/*
+ * Allocating counterpart of monty_import_into: returns the
+ * Montgomery form of x as a new mp_int.
+ */
+mp_int *monty_import(MontyContext *mc, mp_int *x)
+{
+    mp_int *conversion_factor = mc->powers_of_r_mod_m[1];
+    mp_int *imported = monty_mul(mc, x, conversion_factor);
+    return imported;
+}
+
+/*
+ * Exporting a number means multiplying it by r^{-1}, which is exactly
+ * what monty_reduce does anyway, so we just do that.
+ */
+void monty_export_into(MontyContext *mc, mp_int *r, mp_int *x)
+{
+    /* The reduction accepts anything up to double width */
+    assert(x->nw <= 2*mc->rw);
+
+    /* One Montgomery reduction, with the context's scratch buffer
+     * passed by value as its working space */
+    mp_int exported = monty_reduce_internal(mc, x, *mc->scratch);
+    mp_copy_into(r, &exported);
+
+    mp_clear(mc->scratch);
+}
+
+/*
+ * Allocating counterpart of monty_export_into: returns a new mp_int
+ * of mc->rw words holding x converted out of Montgomery form.
+ */
+mp_int *monty_export(MontyContext *mc, mp_int *x)
+{
+    mp_int *exported = mp_make_sized(mc->rw);
+    monty_export_into(mc, exported, x);
+    return exported;
+}
+
+/*
+ * In-place Montgomery reduction: x is replaced by the result of
+ * running monty_reduce_internal on it.
+ */
+static void monty_reduce(MontyContext *mc, mp_int *x)
+{
+    mp_int result = monty_reduce_internal(mc, x, *mc->scratch);
+
+    /* The result is copied back over the input before the scratch
+     * buffer is cleared. */
+    mp_copy_into(x, &result);
+    mp_clear(mc->scratch);
+}
+
+/*
+ * Raise base (in Montgomery representation) to the power exponent,
+ * returning a new mp_int in Montgomery representation.
+ *
+ * Every bit position of the exponent is processed in the same way:
+ * the candidate product is always computed, and mp_select_into
+ * decides whether to keep it.
+ */
+mp_int *monty_pow(MontyContext *mc, mp_int *base, mp_int *exponent)
+{
+    /* Total number of bit positions the exponent has room for */
+    const size_t nbits = exponent->nw * BIGNUM_INT_BITS;
+
+    /* Holds base^{2^bit} as the loop advances */
+    mp_int *power = mp_copy(base);
+
+    /* Running result, starting from the Montgomery form of 1 */
+    mp_int *result = mp_copy(mc->powers_of_r_mod_m[0]);
+
+    /* Double-width register for each product before reduction */
+    mp_int *wide = mp_make_sized(mc->rw * 2);
+
+    for (size_t bit = 0 ;;) {
+        /* Tentatively multiply the current power into the result,
+         * and keep the outcome only if this exponent bit is set. */
+        mp_mul_into(wide, result, power);
+        monty_reduce(mc, wide);
+        mp_select_into(result, result, wide, mp_get_bit(exponent, bit));
+
+        /* No need to square again after the final bit */
+        if (++bit >= nbits)
+            break;
+
+        /* Square the power ready for the next bit */
+        mp_mul_into(wide, power, power);
+        monty_reduce(mc, wide);
+        mp_copy_into(power, wide);
+    }
+
+    mp_free(power);
+    mp_free(wide);
+    mp_clear(mc->scratch);
+    return result;
+}
+
+/*
+ * Compute base^exponent mod modulus, returning a new mp_int.
+ *
+ * The work is done in a temporary Montgomery context. The assertions
+ * require that the modulus is odd and has at least one word, and
+ * that base has no more words than the modulus.
+ */
+mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus)
+{
+    assert(base->nw <= modulus->nw);
+    assert(modulus->nw > 0);
+    assert(modulus->w[0] & 1);
+
+    MontyContext *mc = monty_new(modulus);
+
+    /* Into Montgomery form, exponentiate, and back out again */
+    mp_int *base_monty = monty_import(mc, base);
+    mp_int *power_monty = monty_pow(mc, base_monty, exponent);
+    mp_int *power = monty_export(mc, power_monty);
+
+    mp_free(power_monty);
+    mp_free(base_monty);
+    monty_free(mc);
+
+    return power;
+}
+
+/*
+ * Given two coprime nonzero input integers a,b, returns two integers
+ * A,B such that A*a - B*b = 1. A,B will be the minimal non-negative
+ * pair satisfying that criterion, which is equivalent to saying that
+ * 0<=A<b and 0<=B<a.
+ *
+ * This algorithm is an adapted form of Stein's algorithm, which
+ * computes gcd(a,b) using only addition and bit shifts (i.e. without
+ * needing general division), using the following rules:
+ *
+ *  - if both of a,b are even, divide off a common factor of 2
+ *  - if one of a,b (WLOG a) is even, then gcd(a,b) = gcd(a/2,b), so
+ *    just divide a by 2
+ *  - if both of a,b are odd, then WLOG a>b, and gcd(a,b) =
+ *    gcd(b,(a-b)/2).
+ *
+ * For this application, I always expect the inputs to be coprime
+ * (i.e. the actual gcd to be 1), so we can rule out the 'both even'
+ * initial case. For simplicity
+ * I've changed the 'both odd' case to turn (a,b) into (b,a-b) without
+ * the division by 2 (the next iteration would divide by 2 anyway).
+ *
+ * But the big change is that we need the Bezout coefficients as
+ * output, not just the gcd. So we need to know how to generate those
+ * in each case, based on the coefficients from the reduced pair of
+ * numbers:
+ *
+ *  - If a,b are both odd, and u,v are such that u*b + v*(a-b) = 1,
+ *    then v*a + (u-v)*b = 1.
+ *
+ *  - If a is even, and u,v are such that u*(a/2) + v*b = 1:
+ *     + if u is also even, then this is just (u/2)*a + v*b = 1
+ *     + otherwise, (u+b)*(a/2) + (v-a/2)*b is also equal to 1, and
+ *       since u and b are both odd, (u+b)/2 is an integer, so we have
+ *       ((u+b)/2)*a + (v-a/2)*b = 1.
+ *
+ * The code below transforms this from a recursive to an iterative
+ * algorithm. We first reduce a,b to 0,1, recording at each stage
+ * whether one of them was even, and whether we had to swap them; then
+ * we iterate backwards over that record of what we did, applying the
+ * above rules for building up the Bezout coefficients as we go. Of
+ * course, all the case analysis is done by the usual bit-twiddling
+ * conditionalisation to avoid data-dependent control flow.
+ *
+ * Also, since these mp_ints are generally treated as unsigned, we
+ * store the coefficients by absolute value, with the semantics that
+ * they always have opposite sign, and in the unwinding loop we keep a
+ * bit indicating whether Aa-Bb is currently expected to be +1 or -1,
+ * so that we can do one final conditional adjustment if it's -1.
+ *
+ * Once the reduction rules have managed to reduce the input numbers
+ * to (0,1), then they are stable (the next reduction will always
+ * divide the even one by 2, which maps 0 to 0). So it doesn't matter
+ * if we do more steps of the algorithm than necessary; hence, for
+ * constant time, we just need to find the maximum number we could
+ * _possibly_ require, and do that many.
+ *
+ * If a,b < 2^n, at most 3n iterations are required. Proof: consider
+ * the quantity Q = log_2(min(a,b)) + 2 log_2(max(a,b)).
+ *  - If the smaller number is even, then the next iteration halves
+ *    it, decreasing Q by 1.
+ *  - If the larger number is even, then the next iteration halves
+ *    it, decreasing Q by 2.
+ *  - If the two numbers are both odd, then the combined effect of the
+ *    next two steps will be to replace the larger number with
+ *    something less than half its original value.
+ * In any of these cases, the effect is that in k steps (where k = 1
+ * or 2 depending on the case) Q decreases by at least k. So on
+ * average it decreases by at least 1 per step, and since it starts
+ * off at 3n, that's how many steps it might take.
+ *
+ * The worst case inputs (I think) are where a=2^{n-1} and b=2^n-1
+ * (i.e. a is a power of 2 and b is all 1s). In that situation, the
+ * first n-1 steps repeatedly halve a until it's 1, and then there are
+ * n pairs of steps each of which subtracts 1 from b and then halves
+ * it.
+ */
+/*
+ * Parameters: a_in and b_in are the two input integers; they are
+ * copied into local working values and never written to.
+ * a_coeff_out and b_coeff_out receive the coefficients of a and b
+ * respectively (the values A,B with A*a - B*b = 1 produced by the
+ * final fix-up step below); either pointer may be NULL if that
+ * output is not wanted.
+ */
+static void mp_bezout_into(mp_int *a_coeff_out, mp_int *b_coeff_out,
+                           mp_int *a_in, mp_int *b_in)
+{
+    /* Working size: enough words for the larger input, at least one */
+    size_t nw = size_t_max(1, size_t_max(a_in->nw, b_in->nw));
+
+    /* Make mutable copies of the input numbers */
+    mp_int *a = mp_make_sized(nw), *b = mp_make_sized(nw);
+    mp_copy_into(a, a_in);
+    mp_copy_into(b, b_in);
+
+    /* Space to build up the output coefficients, with an extra word
+     * so that intermediate values can overflow off the top and still
+     * right-shift back down to the correct value */
+    mp_int *ac = mp_make_sized(nw + 1), *bc = mp_make_sized(nw + 1);
+
+    /* And a general-purpose temp register */
+    mp_int *tmp = mp_make_sized(nw);
+
+    /* Space to record the sequence of reduction steps to unwind. We
+     * make it a BignumInt for no particular reason except that (a)
+     * mp_make_sized conveniently zeroes the allocation and mp_free
+     * wipes it, and (b) this way I can use mp_dump() if I have to
+     * debug this code.
+     *
+     * Two bits are stored per step, hence the steps*2 in the size. */
+    size_t steps = 3 * nw * BIGNUM_INT_BITS;
+    mp_int *record = mp_make_sized(
+        (steps*2 + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
+
+    /*
+     * Forward pass: always run the full fixed number of reduction
+     * steps, recording what each one did.
+     */
+    for (size_t step = 0; step < steps; step++) {
+        /*
+         * If a and b are both odd, we want to sort them so that a is
+         * larger. But if one is even, we want to sort them so that a
+         * is the even one.
+         */
+        unsigned swap_if_both_odd = mp_cmp_hs(b, a);
+        unsigned swap_if_one_even = a->w[0] & 1;
+        unsigned both_odd = a->w[0] & b->w[0] & 1;
+        /* Branch-free choice: swap_if_both_odd when both_odd is set,
+         * otherwise swap_if_one_even */
+        unsigned swap = swap_if_one_even ^ (
+            (swap_if_both_odd ^ swap_if_one_even) & both_odd);
+
+        mp_cond_swap(a, b, swap);
+
+        /*
+         * Now, if we've made a the even number, divide it by two; if
+         * we've made it the larger of two odd numbers, subtract the
+         * smaller one from it.
+         *
+         * Both candidates are computed unconditionally (the halved
+         * value into tmp, the difference into a), and then one of
+         * them is selected.
+         */
+        mp_rshift_fixed_into(tmp, a, 1);
+        mp_sub_into(a, a, b);
+        mp_select_into(a, tmp, a, both_odd);
+
+        /*
+         * Record the two 1-bit values both_odd and swap.
+         */
+        mp_set_bit(record, step*2, both_odd);
+        mp_set_bit(record, step*2+1, swap);
+    }
+
+    /*
+     * Now we expect to have reduced the two numbers to 0 and 1,
+     * although we don't know which way round. (But we avoid checking
+     * this by assertion; sometimes we'll need to do this computation
+     * without giving away that we already know the inputs were bogus.
+     * So we'd prefer to just press on and return nonsense.)
+     */
+
+    /*
+     * So their Bezout coefficients at this point are simply
+     * themselves.
+     */
+    mp_copy_into(ac, a);
+    mp_copy_into(bc, b);
+
+    /*
+     * We'll maintain the invariant as we unwind that ac * a - bc * b
+     * is either +1 or -1, and we'll remember which. (We _could_ keep
+     * it at +1 the whole time, but it would cost more work every time
+     * round the loop, so it's cheaper to fix that up once at the
+     * end.)
+     *
+     * Initially, the result is +1 if a was the nonzero value after
+     * reduction, and -1 if b was.
+     */
+    unsigned minus_one = b->w[0];
+
+    /*
+     * Backward pass: replay the recorded steps in reverse order,
+     * restoring a,b and building up the coefficients as we go.
+     */
+    for (size_t step = steps; step-- > 0 ;) {
+        /*
+         * Recover the data from the step we're unwinding.
+         */
+        unsigned both_odd = mp_get_bit(record, step*2);
+        unsigned swap = mp_get_bit(record, step*2+1);
+
+        /*
+         * If this was a division step (!both_odd), and our
+         * coefficient of a is not the even one, we need to adjust the
+         * coefficients by +b and +a respectively.
+         */
+        unsigned adjust = (ac->w[0] & 1) & ~both_odd;
+        mp_cond_add_into(ac, ac, b, adjust);
+        mp_cond_add_into(bc, bc, a, adjust);
+
+        /*
+         * Now, if it was a division step, then ac is even, and we
+         * divide it by two.
+         */
+        mp_rshift_fixed_into(tmp, ac, 1);
+        mp_select_into(ac, tmp, ac, both_odd);
+
+        /*
+         * But if it was a subtraction step, we add ac to bc instead.
+         */
+        mp_cond_add_into(bc, bc, ac, both_odd);
+
+        /*
+         * Undo the transformation of the input numbers, by adding b
+         * to a (if both_odd) or multiplying a by 2 (otherwise).
+         */
+        mp_lshift_fixed_into(tmp, a, 1);
+        mp_add_into(a, a, b);
+        mp_select_into(a, tmp, a, both_odd);
+
+        /*
+         * Finally, undo the swap. If we do swap, this also reverses
+         * the sign of the current result ac*a-bc*b.
+         */
+        mp_cond_swap(a, b, swap);
+        mp_cond_swap(ac, bc, swap);
+        minus_one ^= swap;
+    }
+
+    /*
+     * Now we expect to have recovered the input a,b.
+     */
+    assert(mp_cmp_eq(a, a_in) & mp_cmp_eq(b, b_in));
+
+    /*
+     * But we might find that our current result is -1 instead of +1,
+     * that is, we have A',B' such that A'a - B'b = -1.
+     *
+     * In that situation, we set A = b-A' and B = a-B', giving us
+     * Aa-Bb = ab - A'a - ab + B'b = +1.
+     *
+     * As before, both alternatives are computed and the right one is
+     * selected according to minus_one.
+     */
+    mp_sub_into(tmp, b, ac);
+    mp_select_into(ac, ac, tmp, minus_one);
+    mp_sub_into(tmp, a, bc);
+    mp_select_into(bc, bc, tmp, minus_one);
+
+    /*
+     * Now we really are done. Return the outputs.
+     */
+    if (a_coeff_out)
+        mp_copy_into(a_coeff_out, ac);
+    if (b_coeff_out)
+        mp_copy_into(b_coeff_out, bc);
+
+    mp_free(a);
+    mp_free(b);
+    mp_free(ac);
+    mp_free(bc);
+    mp_free(tmp);
+    mp_free(record);
+}
+
+/*
+ * Modular inverse of x mod m, returned as a new mp_int with as many
+ * words as m.
+ */
+mp_int *mp_invert(mp_int *x, mp_int *m)
+{
+    mp_int *inverse = mp_make_sized(m->nw);
+
+    /* Only the coefficient of x is wanted, so the output slot for
+     * the coefficient of m is passed as NULL. */
+    mp_bezout_into(inverse, NULL, x, m);
+
+    return inverse;
+}
+
+/*
+ * Approximate reciprocal of a normalised 32-bit value; see the
+ * comment at the top of the function body for the precise contract
+ * and error bound.
+ */
+static uint32_t recip_approx_32(uint32_t x)
+{
+    /*
+     * Given an input x in [2^31,2^32), i.e. a uint32_t with its high
+     * bit set, this function returns an approximation to 2^63/x,
+     * computed using only multiplications and bit shifts just in case
+     * the C divide operator has non-constant time (either because the
+     * underlying machine instruction does, or because the operator
+     * expands to a library function on a CPU without hardware
+     * division).
+     *
+     * The coefficients are derived from those of the degree-9
+     * polynomial which is the minimax-optimal approximation to that
+     * function on the given interval (generated using the Remez
+     * algorithm), converted into integer arithmetic with shifts used
+     * to maximise the number of significant bits at every state. (A
+     * sort of 'static floating point' - the exponent is statically
+     * known at every point in the code, so it never needs to be
+     * stored at run time or to influence runtime decisions.)
+     *
+     * Exhaustive iteration over the whole input space shows the
+     * largest possible error to be 1686.54. (The input value
+     * attaining that bound is 4226800006 == 0xfbefd986, whose true
+     * reciprocal is 2182116973.540... == 0x8210766d.8a6..., whereas
+     * this function returns 2182115287 == 0x82106fd7.)
+     */
+
+    /*
+     * Each line below multiplies the running value by x in 64-bit
+     * arithmetic, shifts the product right by a fixed constant, and
+     * subtracts it from the next constant. Every shift count is a
+     * literal, so no step depends on a run-time shift amount.
+     */
+    uint64_t r = 0x92db03d6ULL;
+    r = 0xf63e71eaULL - ((r*x) >> 34);
+    r = 0xb63721e8ULL - ((r*x) >> 34);
+    r = 0x9c2da00eULL - ((r*x) >> 33);
+    r = 0xaada0bb8ULL - ((r*x) >> 32);
+    r = 0xf75cd403ULL - ((r*x) >> 31);
+    r = 0xecf97a41ULL - ((r*x) >> 31);
+    r = 0x90d876cdULL - ((r*x) >> 31);
+    r = 0x6682799a0ULL - ((r*x) >> 26);
+    /* The 64-bit accumulator is narrowed to uint32_t on return */
+    return r;
+}
+
+/*
+ * Divide n by d, writing the quotient to q_out and the remainder to
+ * r_out. Either output pointer may be NULL if that result is not
+ * wanted. d must be nonzero (checked by assertion).
+ */
+void mp_divmod_into(mp_int *n, mp_int *d, mp_int *q_out, mp_int *r_out)
+{
+    assert(!mp_eq_integer(d, 0));
+
+    /*
+     * We do division by using Newton-Raphson iteration to converge to
+     * the reciprocal of d (or rather, R/d for R a sufficiently large
+     * power of 2); then we multiply that reciprocal by n; and we
+     * finish up with conditional subtraction.
+     *
+     * But we have to do it in a fixed number of N-R iterations, so we
+     * need some error analysis to know how many we might need.
+     *
+     * The iteration is derived by defining f(r) = d - R/r.
+     * Differentiating gives f'(r) = R/r^2, and the Newton-Raphson
+     * formula applied to those functions gives
+     *
+     *      r_{i+1} = r_i - f(r_i) / f'(r_i)
+     *              = r_i - (d - R/r_i) r_i^2 / R
+     *              = r_i (2 R - d r_i) / R
+     *
+     * Now let e_i be the error in a given iteration, in the sense
+     * that
+     *
+     *        d r_i = R + e_i
+     *  i.e.  e_i/R = (r_i - r_true) / r_true
+     *
+     * so e_i is the _relative_ error in r_i.
+     *
+     * We must also introduce a rounding-error term, because the
+     * division by R always gives an integer. This might make the
+     * output off by up to 1 (in the negative direction, because
+     * right-shifting gives floor of the true quotient). So when we
+     * divide by R, we must imagine adding some f in [0,1). Then we
+     * have
+     *
+     *    d r_{i+1} = d r_i (2 R - d r_i) / R - d f
+     *              = (R + e_i) (R - e_i) / R - d f
+     *              = (R^2 - e_i^2) / R - d f
+     *              = R - (e_i^2 / R + d f)
+     * =>   e_{i+1} = - (e_i^2 / R + d f)
+     *
+     * The sum of two positive quantities is bounded above by twice
+     * their max, and max |f| = 1, so we can bound this as follows:
+     *
+     *               |e_{i+1}| <= 2 max (e_i^2/R, d)
+     *             |e_{i+1}/R| <= 2 max ((e_i/R)^2, d/R)
+     *        log2 |R/e_{i+1}| <= min (2 log2 |R/e_i|, log2 |R/d|) - 1
+     *
+     * which tells us that the number of 'good' bits - i.e.
+     * log2(R/e_i) - very nearly doubles at every iteration (apart
+     * from that subtraction of 1), until it gets to the same size as
+     * log2(R/d). In other words, the size of R in bits has to be the
+     * size of denominator we're putting in, _plus_ the amount of
+     * precision we want to get back out.
+     *
+     * So when we multiply n (the input numerator) by our final
+     * reciprocal approximation r, but actually r differs from R/d by
+     * up to 2, then it follows that
+     *
+     *   n/d - nr/R = n/d - [ n (R/d + e) ] / R
+     *              = n/d - [ (n/d) R + n e ] / R
+     *              = -ne/R
+     *      =>   0 <= n/d - nr/R < 2n/R
+     *
+     * so our computed quotient can differ from the true n/d by up to
+     * 2n/R. Hence, as long as we also choose R large enough that 2n/R
+     * is bounded above by a constant, we can guarantee a bounded
+     * number of final conditional-subtraction steps.
+     */
+
+    /*
+     * Get at least 32 of the most significant bits of the input
+     * number.
+     */
+    size_t hiword_index = 0;
+    uint64_t hibits = 0, lobits = 0;
+    mp_find_highest_nonzero_word_pair(d, 64 - BIGNUM_INT_BITS,
+                                      &hiword_index, &hibits, &lobits);
+
+    /*
+     * Make a shifted combination of those two words which puts the
+     * topmost bit of the number at bit 63.
+     *
+     * hibits and lobits are always 64 bits wide, whatever size
+     * BignumInt is. So the complementary right-shift count and the
+     * conditional-select masks below must be computed in terms of 64
+     * bits, not BIGNUM_INT_BITS; otherwise this goes wrong on any
+     * build where BignumInt is narrower than uint64_t.
+     */
+    size_t shift_up = 0;
+    for (size_t i = BIGNUM_INT_BITS_BITS; i-- > 0;) {
+        size_t sl = (size_t)1 << i;    /* left shift count */
+        size_t sr = 64 - sl;           /* complementary right-shift count */
+
+        /* Should we shift up? (hibits >> sr is the top sl bits of
+         * hibits, and sl is at most BIGNUM_INT_BITS/2, so that value
+         * always fits in a BignumInt.) */
+        unsigned indicator = 1 ^ normalise_to_1((BignumInt)(hibits >> sr));
+
+        /* If we do, what will we get? */
+        uint64_t new_hibits = (hibits << sl) | (lobits >> sr);
+        uint64_t new_lobits = lobits << sl;
+        size_t new_shift_up = shift_up + sl;
+
+        /* Conditionally swap those values in. */
+        hibits    ^= (hibits    ^ new_hibits   ) & -(uint64_t)indicator;
+        lobits    ^= (lobits    ^ new_lobits   ) & -(uint64_t)indicator;
+        shift_up  ^= (shift_up  ^ new_shift_up ) & -(size_t)  indicator;
+    }
+
+    /*
+     * So now we know the most significant 32 bits of d are at the top
+     * of hibits. Approximate the reciprocal of those bits.
+     */
+    lobits = (uint64_t)recip_approx_32(hibits >> 32) << 32;
+    hibits = 0;
+
+    /*
+     * And shift that up by as many bits as the input was shifted up
+     * just now, so that the product of this approximation and the
+     * actual input will be close to a fixed power of two regardless
+     * of where the MSB was.
+     *
+     * I do this in another log n individual passes, not so much
+     * because I'm worried about the time-invariance of the CPU's
+     * register-controlled shift operation, but in case the compiler
+     * code-generates uint64_t shifts out of a variable number of
+     * smaller-word shift instructions, e.g. by splitting up into
+     * cases.
+     *
+     * As in the previous loop, the shift counts and masks are in
+     * terms of the 64-bit width of hibits and lobits.
+     */
+    for (size_t i = BIGNUM_INT_BITS_BITS; i-- > 0;) {
+        size_t sl = (size_t)1 << i;    /* left shift count */
+        size_t sr = 64 - sl;           /* complementary right-shift count */
+
+        /* Should we shift up? */
+        unsigned indicator = 1 & (shift_up >> i);
+
+        /* If we do, what will we get? */
+        uint64_t new_hibits = (hibits << sl) | (lobits >> sr);
+        uint64_t new_lobits = lobits << sl;
+
+        /* Conditionally swap those values in. */
+        hibits    ^= (hibits    ^ new_hibits   ) & -(uint64_t)indicator;
+        lobits    ^= (lobits    ^ new_lobits   ) & -(uint64_t)indicator;
+    }
+
+    /*
+     * The product of the 128-bit value now in hibits:lobits with the
+     * 128-bit value we originally retrieved in the same variables
+     * will be in the vicinity of 2^191. So we'll take log2(R) to be
+     * 191, plus a multiple of BIGNUM_INT_BITS large enough to allow R
+     * to hold the combined sizes of n and d.
+     */
+    size_t log2_R;
+    {
+        size_t max_log2_n = (n->nw + d->nw) * BIGNUM_INT_BITS;
+        log2_R = max_log2_n + 3;
+        log2_R -= size_t_min(191, log2_R);
+        log2_R = (log2_R + BIGNUM_INT_BITS - 1) & ~(BIGNUM_INT_BITS - 1);
+        log2_R += 191;
+    }
+
+    /* Number of words in a bignum capable of holding numbers the size
+     * of twice R. */
+    size_t rw = ((log2_R+2) + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
+
+    /*
+     * Now construct our full-sized starting reciprocal approximation.
+     */
+    mp_int *r_approx = mp_make_sized(rw);
+    size_t output_bit_index;
+    {
+        /* Where in the input number did the input 128-bit value come from? */
+        size_t input_bit_index =
+            (hiword_index * BIGNUM_INT_BITS) - (128 - BIGNUM_INT_BITS);
+
+        /* So how far do we need to shift our 64-bit output, if the
+         * product of those two fixed-size values is 2^191 and we want
+         * to make it 2^log2_R instead? */
+        output_bit_index = log2_R - 191 - input_bit_index;
+
+        /* If we've done all that right, it should be a whole number
+         * of words. */
+        assert(output_bit_index % BIGNUM_INT_BITS == 0);
+        size_t output_word_index = output_bit_index / BIGNUM_INT_BITS;
+
+        mp_add_integer_into_shifted_by_words(
+            r_approx, r_approx, lobits, output_word_index);
+        mp_add_integer_into_shifted_by_words(
+            r_approx, r_approx, hibits,
+            output_word_index + 64 / BIGNUM_INT_BITS);
+    }
+
+    /*
+     * Make the constant 2*R, which we'll need in the iteration.
+     */
+    mp_int *two_R = mp_make_sized(rw);
+    mp_add_integer_into_shifted_by_words(
+        two_R, two_R, (BignumInt)1 << ((log2_R+1) % BIGNUM_INT_BITS),
+        (log2_R+1) / BIGNUM_INT_BITS);
+
+    /*
+     * Scratch space.
+     */
+    mp_int *dr = mp_make_sized(rw + d->nw);
+    mp_int *diff = mp_make_sized(size_t_max(rw, dr->nw));
+    mp_int *product = mp_make_sized(rw + diff->nw);
+    size_t scratchsize = size_t_max(
+        mp_mul_scratchspace(dr->nw, r_approx->nw, d->nw),
+        mp_mul_scratchspace(product->nw, r_approx->nw, diff->nw));
+    mp_int *scratch = mp_make_sized(scratchsize);
+    mp_int product_shifted = mp_make_alias(
+        product, log2_R / BIGNUM_INT_BITS, product->nw);
+
+    /*
+     * Initial error estimate: the 32-bit output of recip_approx_32
+     * differs by less than 2048 (== 2^11) from the true top 32 bits
+     * of the reciprocal, so the relative error is at most 2^11
+     * divided by the 32-bit reciprocal, which at worst is 2^11/2^31 =
+     * 2^-20. So even in the worst case, we have 20 good bits of
+     * reciprocal to start with.
+     */
+    size_t good_bits = 31 - 11;
+    size_t good_bits_needed = BIGNUM_INT_BITS * n->nw + 4; /* add a few */
+
+    /*
+     * Now do Newton-Raphson iterations until we have reason to think
+     * they're not converging any more. (The loop condition depends
+     * only on the sizes of the inputs, not on their values.)
+     */
+    while (good_bits < good_bits_needed) {
+        /*
+         * Compute the next iterate.
+         */
+        mp_mul_internal(dr, r_approx, d, *scratch);
+        mp_sub_into(diff, two_R, dr);
+        mp_mul_internal(product, r_approx, diff, *scratch);
+        mp_rshift_fixed_into(r_approx, &product_shifted,
+                             log2_R % BIGNUM_INT_BITS);
+
+        /*
+         * Adjust the error estimate.
+         */
+        good_bits = good_bits * 2 - 1;
+    }
+
+    mp_free(dr);
+    mp_free(diff);
+    mp_free(product);
+    mp_free(scratch);
+
+    /*
+     * Now we've got our reciprocal, we can compute the quotient, by
+     * multiplying in n and then shifting down by log2_R bits.
+     */
+    mp_int *quotient_full = mp_mul(r_approx, n);
+    mp_int quotient_alias = mp_make_alias(
+        quotient_full, log2_R / BIGNUM_INT_BITS, quotient_full->nw);
+    mp_int *quotient = mp_make_sized(n->nw);
+    mp_rshift_fixed_into(quotient, &quotient_alias, log2_R % BIGNUM_INT_BITS);
+
+    /*
+     * Next, compute the remainder.
+     */
+    mp_int *remainder = mp_make_sized(d->nw);
+    mp_mul_into(remainder, quotient, d);
+    mp_sub_into(remainder, n, remainder);
+
+    /*
+     * Finally, two conditional subtractions to fix up any remaining
+     * rounding error. (I _think_ one should be enough, but this
+     * routine isn't time-critical enough to take chances.)
+     */
+    unsigned q_correction = 0;
+    for (unsigned iter = 0; iter < 2; iter++) {
+        unsigned need_correction = mp_cmp_hs(remainder, d);
+        mp_cond_sub_into(remainder, remainder, d, need_correction);
+        q_correction += need_correction;
+    }
+    mp_add_integer_into(quotient, quotient, q_correction);
+
+    /*
+     * Now we should have a perfect answer, i.e. 0 <= r < d.
+     */
+    assert(!mp_cmp_hs(remainder, d));
+
+    if (q_out)
+        mp_copy_into(q_out, quotient);
+    if (r_out)
+        mp_copy_into(r_out, remainder);
+
+    mp_free(r_approx);
+    mp_free(two_R);
+    mp_free(quotient_full);
+    mp_free(quotient);
+    mp_free(remainder);
+}
+
+/*
+ * Quotient of n divided by d, returned as a new mp_int with as many
+ * words as n. The remainder is not computed into any output.
+ */
+mp_int *mp_div(mp_int *n, mp_int *d)
+{
+    mp_int *quotient = mp_make_sized(n->nw);
+    mp_divmod_into(n, d, quotient, NULL);
+    return quotient;
+}
+
+/*
+ * Remainder of n divided by d, returned as a new mp_int with as many
+ * words as d. The quotient is not computed into any output.
+ */
+mp_int *mp_mod(mp_int *n, mp_int *d)
+{
+    mp_int *remainder = mp_make_sized(d->nw);
+    mp_divmod_into(n, d, NULL, remainder);
+    return remainder;
+}
+
+/*
+ * x*y mod modulus, returned as a new mp_int. Done the
+ * straightforward way: form the full product, then reduce it with
+ * mp_mod.
+ */
+mp_int *mp_modmul(mp_int *x, mp_int *y, mp_int *modulus)
+{
+    mp_int *full_product = mp_mul(x, y);
+    mp_int *residue = mp_mod(full_product, modulus);
+
+    mp_free(full_product);
+    return residue;
+}
+
+/*
+ * (x+y) mod modulus, returned as a new mp_int. Forms the sum with
+ * mp_add and then reduces it with mp_mod.
+ */
+mp_int *mp_modadd(mp_int *x, mp_int *y, mp_int *modulus)
+{
+    mp_int *full_sum = mp_add(x, y);
+    mp_int *residue = mp_mod(full_sum, modulus);
+
+    mp_free(full_sum);
+    return residue;
+}
+
+/*
+ * (x-y) mod modulus, returned as a new mp_int in the range
+ * [0,modulus).
+ *
+ * Since mp_ints are unsigned, we work with |x-y|: reduce that mod
+ * the modulus, and if the true difference was negative, replace the
+ * residue r with modulus - r.
+ */
+mp_int *mp_modsub(mp_int *x, mp_int *y, mp_int *modulus)
+{
+    mp_int *diff = mp_make_sized(size_t_max(x->nw, y->nw));
+    mp_sub_into(diff, x, y);
+
+    /* If y >= x, the subtraction wrapped round (or gave zero), so
+     * negate it to get the absolute value of the difference. */
+    unsigned negate = mp_cmp_hs(y, x);
+    mp_cond_negate(diff, diff, negate);
+
+    mp_int *reduced = mp_mod(diff, modulus);
+
+    /*
+     * If the residue is zero, then it must be left alone: mapping it
+     * to modulus - 0 would return the modulus itself, which is out
+     * of range. This case arises whenever y-x is a multiple of the
+     * modulus, including the simple case x == y (for which negate is
+     * set, because mp_cmp_hs tests 'higher or same').
+     */
+    negate &= ~mp_eq_integer(reduced, 0);
+
+    /* Otherwise, turn r into modulus - r, by negating and then
+     * adding the modulus. */
+    mp_cond_negate(reduced, reduced, negate);
+    mp_cond_add_into(reduced, reduced, modulus, negate);
+
+    mp_free(diff);
+    return reduced;
+}
+
+/*
+ * Modular addition by a single conditional subtraction instead of a
+ * general division: the modulus is subtracted once if the addition
+ * carried out of the top or left a value not less than the modulus.
+ * Returns a new mp_int with as many words as the modulus.
+ */
+static mp_int *mp_modadd_in_range(mp_int *x, mp_int *y, mp_int *modulus)
+{
+    mp_int *result = mp_make_sized(modulus->nw);
+
+    unsigned carried = mp_add_into_internal(result, x, y);
+    unsigned too_big = mp_cmp_hs(result, modulus);
+
+    mp_cond_sub_into(result, result, modulus, carried | too_big);
+    return result;
+}
+
+/*
+ * Modular subtraction by a single conditional addition instead of a
+ * general division: the modulus is added back once if x was smaller
+ * than y. Returns a new mp_int with as many words as the modulus.
+ */
+static mp_int *mp_modsub_in_range(mp_int *x, mp_int *y, mp_int *modulus)
+{
+    mp_int *result = mp_make_sized(modulus->nw);
+    mp_sub_into(result, x, y);
+
+    /* mp_cmp_hs(x, y) means x >= y, so its inverse means the
+     * subtraction went below zero */
+    unsigned went_negative = 1 ^ mp_cmp_hs(x, y);
+
+    mp_cond_add_into(result, result, modulus, went_negative);
+    return result;
+}
+
+/*
+ * Add two numbers mod the context's modulus, via
+ * mp_modadd_in_range. Returns a new mp_int.
+ */
+mp_int *monty_add(MontyContext *mc, mp_int *x, mp_int *y)
+{
+    mp_int *sum = mp_modadd_in_range(x, y, mc->m);
+    return sum;
+}
+
+/*
+ * Subtract y from x mod the context's modulus, via
+ * mp_modsub_in_range. Returns a new mp_int.
+ */
+mp_int *monty_sub(MontyContext *mc, mp_int *x, mp_int *y)
+{
+    mp_int *difference = mp_modsub_in_range(x, y, mc->m);
+    return difference;
+}
+
+/*
+ * Write the smaller of x and y into r, chosen by mp_select_into
+ * rather than by a branch.
+ */
+void mp_min_into(mp_int *r, mp_int *x, mp_int *y)
+{
+    /* Set if x >= y, in which case y is the one to take */
+    unsigned take_y = mp_cmp_hs(x, y);
+    mp_select_into(r, x, y, take_y);
+}
+
+/*
+ * Allocating form of mp_min_into. The result is sized to the shorter
+ * of the two inputs.
+ */
+mp_int *mp_min(mp_int *x, mp_int *y)
+{
+    size_t words = size_t_min(x->nw, y->nw);
+    mp_int *smaller = mp_make_sized(words);
+    mp_min_into(smaller, x, y);
+    return smaller;
+}
+
+/*
+ * Return a new mp_int with bit number 'power' set.
+ */
+mp_int *mp_power_2(size_t power)
+{
+    /* Bit indices start at 0, so bit 'power' needs power+1 bits of
+     * room (presumably mp_new takes its size in bits). */
+    mp_int *result = mp_new(power + 1);
+    mp_set_bit(result, power, 1);
+    return result;
+}
+
+/*
+ * State for taking square roots modulo a fixed prime p.
+ */
+struct ModsqrtContext {
+    /* The prime modulus */
+    mp_int *p;
+    /* Montgomery context for doing arithmetic mod p */
+    MontyContext *mc;
+
+    /*
+     * p-1 decomposed as 2^e k, for positive integer e and odd k.
+     */
+    size_t e;
+    mp_int *k;
+    /* (k-1)/2 */
+    mp_int *km1o2;
+
+    /*
+     * z is the user-provided value which is not a quadratic residue
+     * mod p, and zk is its kth power. Both are in Montgomery form.
+     */
+    mp_int *z;
+    mp_int *zk;
+};
+
+/*
+ * Set up a context for taking square roots mod p, given some value
+ * that is not a square mod p. Both inputs are copied or converted
+ * into the context rather than stored by pointer.
+ */
+ModsqrtContext *modsqrt_new(mp_int *p, mp_int *any_nonsquare_mod_p)
+{
+    ModsqrtContext *sc = snew(ModsqrtContext);
+    memset(sc, 0, sizeof(*sc));
+
+    sc->p = mp_copy(p);
+    sc->mc = monty_new(sc->p);
+    sc->z = monty_import(sc->mc, any_nonsquare_mod_p);
+
+    /*
+     * Locate the lowest set bit of p-1, by scanning the bits of p
+     * upwards starting from bit 1. A plain loop with an early exit
+     * is fine here, because p is expected to be non-secret
+     * (typically a well-known standard elliptic curve parameter).
+     */
+    size_t total_bits = BIGNUM_INT_BITS * p->nw;
+    size_t lowest_set = 1;
+    while (lowest_set < total_bits && !mp_get_bit(p, lowest_set))
+        lowest_set++;
+    sc->e = lowest_set;
+
+    /* k is p shifted right by e bits, and then (k-1)/2 is k shifted
+     * right by one more */
+    sc->k = mp_rshift_fixed(p, sc->e);
+    sc->km1o2 = mp_rshift_fixed(sc->k, 1);
+
+    /*
+     * zk is deliberately left NULL here and computed on first use,
+     * because it's the expensive part of the setup. A context that
+     * is never actually used then costs very little to create.
+     */
+
+    return sc;
+}
+
+/*
+ * Fill in sc->zk (z raised to the power k) if that hasn't been done
+ * yet. Does nothing on later calls.
+ */
+static void modsqrt_lazy_setup(ModsqrtContext *sc)
+{
+    if (sc->zk)
+        return;                        /* already computed */
+
+    sc->zk = monty_pow(sc->mc, sc->z, sc->k);
+}
+
+/*
+ * Free a ModsqrtContext and everything it owns.
+ */
+void modsqrt_free(ModsqrtContext *sc)
+{
+    monty_free(sc->mc);
+
+    /* These are always set up by modsqrt_new */
+    mp_free(sc->p);
+    mp_free(sc->z);
+    mp_free(sc->k);
+    mp_free(sc->km1o2);
+
+    /* zk is only present if the lazy setup ever ran */
+    if (sc->zk != NULL)
+        mp_free(sc->zk);
+
+    sfree(sc);
+}
+
+/*
+ * Square root of x mod the context's prime, for an x in ordinary
+ * (non-Montgomery) representation. Wraps monty_modsqrt, converting
+ * into Montgomery form on the way in and back out again on the way
+ * out. *success is filled in by monty_modsqrt. Returns a new mp_int.
+ */
+mp_int *mp_modsqrt(ModsqrtContext *sc, mp_int *x, unsigned *success)
+{
+    mp_int *x_monty = monty_import(sc->mc, x);
+    mp_int *root_monty = monty_modsqrt(sc, x_monty, success);
+    mp_int *root = monty_export(sc->mc, root_monty);
+
+    mp_free(x_monty);
+    mp_free(root_monty);
+
+    return root;
+}
+
+/*
+ * Modular square root, using an algorithm more or less similar to
+ * Tonelli-Shanks but adapted for constant time.
+ *
+ * The basic idea is to write p-1 = k 2^e, where k is odd and e > 0.
+ * Then the multiplicative group mod p (call it G) has a sequence of
+ * e+1 nested subgroups G = G_0 > G_1 > G_2 > ... > G_e, where each
+ * G_i is exactly half the size of G_{i-1} and consists of all the
+ * squares of elements in G_{i-1}. So the innermost group G_e has
+ * order k, which is odd, and hence within that group you can take a
+ * square root by raising to the power (k+1)/2.
+ *
+ * Our strategy is to iterate over these groups one by one and make
+ * sure the number x we're trying to take the square root of is inside
+ * each one, by adjusting it if it isn't.
+ *
+ * Suppose g is a primitive root of p, i.e. a generator of G_0. (We
+ * don't actually need to know what g _is_; we just imagine it for the
+ * sake of understanding.) Then G_i consists of precisely the (2^i)th
+ * powers of g, and hence, you can tell if a number is in G_i if
+ * raising it to the power k 2^{e-i} gives 1. So the conceptual
+ * algorithm goes: for each i, test whether x is in G_i by that
+ * method. If it isn't, then the previous iteration ensured it's in
+ * G_{i-1}, so it will be an odd power of g^{2^{i-1}}, and hence
+ * multiplying by any other odd power of g^{2^{i-1}} will give x' in
+ * G_i. And we have one of those, because our non-square z is an odd
+ * power of g, so z^{2^{i-1}} is an odd power of g^{2^{i-1}}.
+ *
+ * (There's a special case in the very first iteration, where we don't
+ * have a G_{i-1}. If it turns out that x is not even in G_1, that
+ * means it's not a square, so we set *success to 0. We still run the
+ * rest of the algorithm anyway, for the sake of constant time, but we
+ * don't give a hoot what it returns.)
+ *
+ * When we get to the end and have x in G_e, then we can take its
+ * square root by raising to (k+1)/2. But of course that's not the
+ * square root of the original input - it's only the square root of
+ * the adjusted version we produced during the algorithm. To get the
+ * true output answer we also have to multiply by a power of z,
+ * namely, z to the power of _half_ whatever we've been multiplying in
+ * as we go along. (The power of z we multiplied in must have been
+ * even, because the case in which we would have multiplied in an odd
+ * power of z is the i=0 case, in which we instead set the failure
+ * flag.)
+ *
+ * The code below is an optimised version of that basic idea, in which
+ * we _start_ by computing x^k so as to be able to test membership in
+ * G_i by only a few squarings rather than a full from-scratch modpow
+ * every time; we also start by computing our candidate output value
+ * x^{(k+1)/2}. So when the above description says 'adjust x by z^i'
+ * for some i, we have to adjust our running values of x^k and
+ * x^{(k+1)/2} by z^{ik} and z^{ik/2} respectively (the latter is safe
+ * because, as above, i is always even). And it turns out that we
+ * don't actually have to store the adjusted version of x itself at
+ * all - we _only_ keep those two powers of it.
+ */
+mp_int *monty_modsqrt(ModsqrtContext *sc, mp_int *x, unsigned *success)
+{
+    /*
+     * x is taken, and the result returned, in Montgomery
+     * representation with respect to sc->mc. *success is set to 1 if
+     * x has a square root mod p and to 0 otherwise; in the failure
+     * case the returned value is meaningless. The returned mp_int is
+     * newly allocated and owned by the caller.
+     */
+    modsqrt_lazy_setup(sc);
+
+    /*
+     * One allocation provides room for the three rw-word temporaries
+     * used below (xk, tmp, power_of_zk). 'scratch' is a by-value copy
+     * of the descriptor that mp_alloc_from_scratch carves pieces out
+     * of; scratch_to_free keeps the original so it can be released.
+     */
+    mp_int *scratch_to_free = mp_make_sized(3 * sc->mc->rw);
+    mp_int scratch = *scratch_to_free;
+
+    /*
+     * Compute toret = x^{(k+1)/2}, our starting point for the output
+     * square root, and also xk = x^k which we'll use as we go along
+     * for knowing when to apply correction factors. We do this by
+     * first computing x^{(k-1)/2}, then multiplying it by x, then
+     * multiplying the two together.
+     */
+    mp_int *toret = monty_pow(sc->mc, x, sc->km1o2);
+    mp_int xk = mp_alloc_from_scratch(&scratch, sc->mc->rw);
+    mp_copy_into(&xk, toret);
+    monty_mul_into(sc->mc, toret, toret, x);
+    monty_mul_into(sc->mc, &xk, toret, &xk);
+
+    mp_int tmp = mp_alloc_from_scratch(&scratch, sc->mc->rw);
+
+    mp_int power_of_zk = mp_alloc_from_scratch(&scratch, sc->mc->rw);
+    mp_copy_into(&power_of_zk, sc->zk);
+
+    for (size_t i = 0; i < sc->e; i++) {
+        /*
+         * Square the running x^k a total of e-1-i times and compare
+         * against 1: this is the subgroup membership test for the
+         * current level.
+         */
+        mp_copy_into(&tmp, &xk);
+        for (size_t j = i+1; j < sc->e; j++)
+            monty_mul_into(sc->mc, &tmp, &tmp, &tmp);
+        unsigned eq1 = mp_cmp_eq(&tmp, monty_identity(sc->mc));
+
+        if (i == 0) {
+            /*
+             * The first test decides whether x is a square at all.
+             * One special case: if x = 0 then no power of x will ever
+             * equal 1, but 0 does have a square root mod p (namely
+             * 0), so that must still count as success. The OR keeps
+             * this branch-free. Every later iteration then multiplies
+             * toret (already 0) by a correction factor, so the
+             * returned root is 0 as required.
+             */
+            *success = eq1 | mp_eq_integer(x, 0);
+        } else {
+            /*
+             * Both the corrected and the uncorrected values are
+             * always computed; mp_select_into keeps the existing one
+             * when eq1 == 1 and takes the corrected one otherwise, so
+             * control flow never depends on the secret test result.
+             */
+            monty_mul_into(sc->mc, &tmp, toret, &power_of_zk);
+            mp_select_into(toret, &tmp, toret, eq1);
+
+            /* xk is corrected by the square of the factor applied to
+             * toret, so advance power_of_zk in between. */
+            monty_mul_into(sc->mc, &power_of_zk,
+                           &power_of_zk, &power_of_zk);
+
+            monty_mul_into(sc->mc, &tmp, &xk, &power_of_zk);
+            mp_select_into(&xk, &tmp, &xk, eq1);
+        }
+    }
+
+    mp_free(scratch_to_free);
+
+    return toret;
+}
+
+mp_int *mp_random_bits_fn(size_t bits, int (*gen_byte)(void))
+{
+    /*
+     * Build a value of exactly 'bits' random bits, one byte from
+     * gen_byte() at a time, lowest-order byte first. The final byte
+     * is masked so that nothing at or above bit position 'bits' is
+     * ever set.
+     */
+    size_t nbytes = (bits + 7) / 8;
+    size_t nwords = (bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
+    mp_int *result = mp_make_sized(nwords);
+
+    for (size_t pos = 0; pos < nbytes; pos++) {
+        BignumInt rnd = gen_byte();
+
+        /* Number of bits still wanted from this byte: 8, except
+         * possibly fewer for the last one. */
+        unsigned keep = (1 << size_t_min(8, bits - pos*8)) - 1;
+
+        size_t word_index = pos / BIGNUM_INT_BYTES;
+        size_t shift = 8 * (pos % BIGNUM_INT_BYTES);
+        result->w[word_index] |= (rnd & keep) << shift;
+    }
+
+    return result;
+}
+
+mp_int *mp_random_in_range_fn(mp_int *lo, mp_int *hi, int (*gen_byte)(void))
+{
+    /* Number of distinct values in the half-open interval [lo, hi). */
+    mp_int *span = mp_sub(hi, lo);
+
+    /*
+     * Making every outcome exactly equiprobable can't be done in
+     * constant time, so settle for a very close approximation:
+     * generate 128 more random bits than the range needs and reduce
+     * mod the range size. That bounds the ratio between the
+     * probabilities of any two outcomes by (1+2^-128), which would
+     * take an infeasible number of samples even to detect.
+     */
+    size_t raw_bits = mp_max_bits(span) + 128;
+    mp_int *raw = mp_random_bits_fn(raw_bits, gen_byte);
+    mp_int *result = mp_mod(raw, span);
+
+    mp_free(raw);
+    mp_free(span);
+
+    /* Translate from [0, hi-lo) up to [lo, hi). */
+    mp_add_into(result, result, lo);
+    return result;
+}

+ 386 - 0
source/putty/mpint.h

@@ -0,0 +1,386 @@
+#ifndef PUTTY_MPINT_H
+#define PUTTY_MPINT_H
+
+/*
+ * PuTTY's multiprecision integer library.
+ *
+ * This library is written with the aim of avoiding leaking the input
+ * numbers via timing and cache side channels. This means avoiding
+ * making any control flow change, or deciding the address of any
+ * memory access, based on the value of potentially secret input data.
+ *
+ * But in a library that has to handle numbers of arbitrary size, you
+ * can't avoid your control flow depending on the _size_ of the input!
+ * So the rule is that an mp_int has a nominal size that need not be
+ * its mathematical size: i.e. if you call (say) mp_from_bytes_be to
+ * turn an array of 256 bytes into an integer, and all but the last of
+ * those bytes is zero, then you get an mp_int which has space for 256
+ * bytes of data but just happens to store the value 1. So the
+ * _nominal_ sizes of input data - e.g. the size in bits of some
+ * public-key modulus - are not considered secret, and control flow is
+ * allowed to do what it likes based on those sizes. But the same
+ * function, called with the same _nominally sized_ arguments
+ * containing different values, should run in the same length of time.
+ *
+ * When a function returns an 'mp_int *', it is newly allocated to an
+ * appropriate nominal size (which, again, depends only on the nominal
+ * sizes of the inputs). Other functions have 'into' in their name,
+ * and they instead overwrite the contents of an existing mp_int.
+ *
+ * Functions in this API which return values that are logically
+ * boolean return them as 'unsigned' rather than the C99 bool type.
+ * That's because C99 bool does an implicit test for non-zero-ness
+ * when converting any other integer type to it, which compilers might
+ * well implement using data-dependent control flow.
+ */
+
+/*
+ * Create and destroy mp_ints. A newly created one is initialised to
+ * zero. mp_clear also resets an existing number to zero.
+ */
+mp_int *mp_new(size_t maxbits);
+void mp_free(mp_int *);
+void mp_clear(mp_int *x);
+
+/*
+ * Create mp_ints from various sources: little- and big-endian binary
+ * data, an ordinary C unsigned integer type, a decimal or hex string
+ * (given either as a ptrlen or a C NUL-terminated string), and
+ * another mp_int.
+ *
+ * The decimal and hex conversion functions have running time
+ * dependent on the length of the input data, of course.
+ */
+mp_int *mp_from_bytes_le(ptrlen bytes);
+mp_int *mp_from_bytes_be(ptrlen bytes);
+mp_int *mp_from_integer(uintmax_t n);
+mp_int *mp_from_decimal_pl(ptrlen decimal);
+mp_int *mp_from_decimal(const char *decimal);
+mp_int *mp_from_hex_pl(ptrlen hex);
+mp_int *mp_from_hex(const char *hex);
+mp_int *mp_copy(mp_int *x);
+
+/*
+ * A macro for declaring large fixed numbers in source code (such as
+ * elliptic curve parameters, or standard Diffie-Hellman moduli). The
+ * idea is that you just write something like
+ *
+ *   mp_int *value = MP_LITERAL(0x19284376283754638745693467245);
+ *
+ * and it newly allocates you an mp_int containing that number.
+ *
+ * Internally, the macro argument is stringified and passed to
+ * mp_from_hex. That's not as fast as it could be if I had instead set
+ * up some kind of mp_from_array_of_uint64_t() function, but I think
+ * this system is valuable for the fact that the literal integers
+ * appear in a very natural syntax that can be pasted directly out
+ * into, say, Python if you want to cross-check a calculation.
+ */
+static inline mp_int *mp__from_string_literal(const char *lit)
+{
+    /*
+     * Internal helper for MP_LITERAL only: the argument is a
+     * stringified source-code token, so no validation is attempted
+     * and this must never be handed hostile data.
+     *
+     * A second character of 'x' or 'X' is taken to mean a hex
+     * literal with a two-character prefix to skip; anything else is
+     * parsed as decimal.
+     */
+    int has_hex_prefix = lit[0] && (lit[1] == 'x' || lit[1] == 'X');
+
+    return has_hex_prefix ? mp_from_hex(lit+2) : mp_from_decimal(lit);
+}
+#define MP_LITERAL(number) mp__from_string_literal(#number)
+
+/*
+ * Create an mp_int with the value 2^power.
+ */
+mp_int *mp_power_2(size_t power);
+
+/*
+ * Retrieve the value of a particular bit or byte of an mp_int. The
+ * byte / bit index is not considered to be secret data. Out-of-range
+ * byte/bit indices are handled cleanly and return zero.
+ */
+uint8_t mp_get_byte(mp_int *x, size_t byte);
+unsigned mp_get_bit(mp_int *x, size_t bit);
+
+/*
+ * Set an mp_int bit. Again, the bit index is not considered secret.
+ * Do not pass an out-of-range index, on pain of assertion failure.
+ */
+void mp_set_bit(mp_int *x, size_t bit, unsigned val);
+
+/*
+ * Return the nominal size of an mp_int, in terms of the maximum
+ * number of bytes or bits that can fit in it.
+ */
+size_t mp_max_bytes(mp_int *x);
+size_t mp_max_bits(mp_int *x);
+
+/*
+ * Return the _mathematical_ bit count of an mp_int (not its nominal
+ * size), i.e. a value n such that 2^{n-1} <= x < 2^n.
+ *
+ * This function is supposed to run in constant time for a given
+ * nominal input size. Of course it's likely that clients of this
+ * function will promptly need to use the result as the limit of some
+ * loop (e.g. marshalling an mp_int into an SSH packet, which doesn't
+ * permit extra prefix zero bytes). But that's up to the caller to
+ * decide the safety of.
+ */
+size_t mp_get_nbits(mp_int *x);
+
+/*
+ * Return the value of an mp_int as a decimal or hex string. The
+ * result is dynamically allocated, and the caller is responsible for
+ * freeing it.
+ *
+ * These functions should run in constant time for a given nominal
+ * input size, even though the exact number of digits returned is
+ * variable. They always allocate enough space for the largest output
+ * that might be needed, but they don't always fill it.
+ */
+char *mp_get_decimal(mp_int *x);
+char *mp_get_hex(mp_int *x);
+char *mp_get_hex_uppercase(mp_int *x);
+
+/*
+ * Compare two mp_ints, or compare one mp_int against a C integer. The
+ * 'eq' functions return 1 if the two inputs are equal, or 0
+ * otherwise; the 'hs' functions return 1 if the first input is >= the
+ * second, and 0 otherwise.
+ */
+unsigned mp_cmp_hs(mp_int *a, mp_int *b);
+unsigned mp_cmp_eq(mp_int *a, mp_int *b);
+unsigned mp_hs_integer(mp_int *x, uintmax_t n);
+unsigned mp_eq_integer(mp_int *x, uintmax_t n);
+
+/*
+ * Take the minimum of two mp_ints, without using a conditional branch.
+ */
+void mp_min_into(mp_int *r, mp_int *x, mp_int *y);
+mp_int *mp_min(mp_int *x, mp_int *y);
+
+/*
+ * Diagnostic function. Writes out x in hex to the supplied stdio
+ * stream, preceded by the string 'prefix' and followed by 'suffix'.
+ *
+ * This is useful to put temporarily into code, but it's also
+ * potentially useful to call from a debugger.
+ */
+void mp_dump(FILE *fp, const char *prefix, mp_int *x, const char *suffix);
+
+/*
+ * Overwrite one mp_int with another.
+ */
+void mp_copy_into(mp_int *dest, mp_int *src);
+
+/*
+ * Conditional selection. Overwrites dest with either src0 or src1,
+ * according to the value of 'choose_src1'. choose_src1 should be 0 or
+ * 1; if it's 1, then dest is set to src1, otherwise src0.
+ *
+ * The value of choose_src1 is considered to be secret data, so
+ * control flow and memory access should not depend on it.
+ */
+void mp_select_into(mp_int *dest, mp_int *src0, mp_int *src1,
+                    unsigned choose_src1);
+
+/*
+ * Addition, subtraction and multiplication, either targeting an
+ * existing mp_int or making a new one large enough to hold whatever
+ * the output might be..
+ */
+void mp_add_into(mp_int *r, mp_int *a, mp_int *b);
+void mp_sub_into(mp_int *r, mp_int *a, mp_int *b);
+void mp_mul_into(mp_int *r, mp_int *a, mp_int *b);
+mp_int *mp_add(mp_int *x, mp_int *y);
+mp_int *mp_sub(mp_int *x, mp_int *y);
+mp_int *mp_mul(mp_int *x, mp_int *y);
+
+/*
+ * Addition, subtraction and multiplication with one argument small
+ * enough to fit in a C integer. For mp_mul_integer_into, it has to be
+ * even smaller than that.
+ */
+void mp_add_integer_into(mp_int *r, mp_int *a, uintmax_t n);
+void mp_sub_integer_into(mp_int *r, mp_int *a, uintmax_t n);
+void mp_mul_integer_into(mp_int *r, mp_int *a, uint16_t n);
+
+/*
+ * Conditional addition/subtraction. If yes == 1, sets r to a+b or a-b
+ * (respectively). If yes == 0, sets r to just a. 'yes' is considered
+ * secret data.
+ */
+void mp_cond_add_into(mp_int *r, mp_int *a, mp_int *b, unsigned yes);
+void mp_cond_sub_into(mp_int *r, mp_int *a, mp_int *b, unsigned yes);
+
+/*
+ * Swap x0 and x1 if swap == 1, and not if swap == 0. 'swap' is
+ * considered secret.
+ */
+void mp_cond_swap(mp_int *x0, mp_int *x1, unsigned swap);
+
+/*
+ * Set x to 0 if clear == 1, and otherwise leave it unchanged. 'clear'
+ * is considered secret.
+ */
+void mp_cond_clear(mp_int *x, unsigned clear);
+
+/*
+ * Division. mp_divmod_into divides n by d, and writes the quotient
+ * into q and the remainder into r. You can pass either of q and r as
+ * NULL if you don't need one of the outputs.
+ *
+ * mp_div and mp_mod are wrappers that return one or other of those
+ * outputs as a freshly allocated mp_int of the appropriate size.
+ *
+ * Division by zero gives no error, and returns a quotient of 0 and a
+ * remainder of n (so as to still satisfy the division identity that
+ * n=qd+r).
+ */
+void mp_divmod_into(mp_int *n, mp_int *d, mp_int *q, mp_int *r);
+mp_int *mp_div(mp_int *n, mp_int *d);
+mp_int *mp_mod(mp_int *x, mp_int *modulus);
+
+/*
+ * Trivially easy special case of mp_mod: reduce a number mod a power
+ * of two.
+ */
+void mp_reduce_mod_2to(mp_int *x, size_t p);
+
+/*
+ * Modular inverses. mp_invert computes the inverse of x mod modulus
+ * (and will expect the two to be coprime). mp_invert_mod_2to computes
+ * the inverse of x mod 2^p, and is a great deal faster.
+ */
+mp_int *mp_invert_mod_2to(mp_int *x, size_t p);
+mp_int *mp_invert(mp_int *x, mp_int *modulus);
+
+/*
+ * System for taking square roots modulo an odd prime.
+ *
+ * In order to do this efficiently, you need to provide an extra piece
+ * of information at setup time, namely a number which is not
+ * congruent mod p to any square. Given p and that non-square, you can
+ * use modsqrt_new to make a context containing all the necessary
+ * equipment for actually calculating the square roots, and then you
+ * can call mp_modsqrt as many times as you like on that context
+ * before freeing it.
+ *
+ * The output parameter '*success' will be filled in with 1 if the
+ * operation was successful, or 0 if the input number doesn't have a
+ * square root mod p at all. In the latter case, the returned mp_int
+ * will be nonsense and you shouldn't depend on it.
+ *
+ * ==== WARNING ====
+ *
+ * This function DOES NOT TREAT THE PRIME MODULUS AS SECRET DATA! It
+ * will protect the number you're taking the square root _of_, but not
+ * the number you're taking the root of it _mod_.
+ *
+ * (This is because the algorithm requires a number of loop iterations
+ * equal to the number of factors of 2 in p-1. And the expected use of
+ * this function is for elliptic-curve point decompression, in which
+ * the modulus is always a well-known one written down in standards
+ * documents.)
+ */
+typedef struct ModsqrtContext ModsqrtContext;
+ModsqrtContext *modsqrt_new(mp_int *p, mp_int *any_nonsquare_mod_p);
+void modsqrt_free(ModsqrtContext *);
+mp_int *mp_modsqrt(ModsqrtContext *sc, mp_int *x, unsigned *success);
+
+/*
+ * Functions for Montgomery multiplication, a fast technique for doing
+ * a long series of modular multiplications all with the same modulus
+ * (which has to be odd).
+ *
+ * You start by calling monty_new to set up a context structure
+ * containing all the precomputed bits and pieces needed by the
+ * algorithm. Then, any numbers you want to work with must first be
+ * transformed into the internal Montgomery representation using
+ * monty_import; having done that, you can use monty_mul and monty_pow
+ * to operate on them efficiently; and finally, monty_export will
+ * convert numbers back out of Montgomery representation to give their
+ * ordinary values.
+ *
+ * Addition and subtraction are not optimised by the Montgomery trick,
+ * but monty_add and monty_sub are provided anyway for convenience.
+ *
+ * There are also monty_invert and monty_modsqrt, which are analogues
+ * of mp_invert and mp_modsqrt which take their inputs in Montgomery
+ * representation. For mp_modsqrt, the prime modulus of the
+ * ModsqrtContext must be the same as the modulus of the MontyContext.
+ *
+ * The query functions monty_modulus and monty_identity return numbers
+ * stored inside the MontyContext, without copying them. The returned
+ * pointers are still owned by the MontyContext, so don't free them!
+ */
+MontyContext *monty_new(mp_int *modulus);
+MontyContext *monty_copy(MontyContext *mc);
+void monty_free(MontyContext *mc);
+mp_int *monty_modulus(MontyContext *mc); /* doesn't transfer ownership */
+mp_int *monty_identity(MontyContext *mc); /* doesn't transfer ownership */
+void monty_import_into(MontyContext *mc, mp_int *r, mp_int *x);
+mp_int *monty_import(MontyContext *mc, mp_int *x);
+void monty_export_into(MontyContext *mc, mp_int *r, mp_int *x);
+mp_int *monty_export(MontyContext *mc, mp_int *x);
+void monty_mul_into(MontyContext *, mp_int *r, mp_int *, mp_int *);
+mp_int *monty_add(MontyContext *, mp_int *, mp_int *);
+mp_int *monty_sub(MontyContext *, mp_int *, mp_int *);
+mp_int *monty_mul(MontyContext *, mp_int *, mp_int *);
+mp_int *monty_pow(MontyContext *, mp_int *base, mp_int *exponent);
+mp_int *monty_invert(MontyContext *, mp_int *);
+mp_int *monty_modsqrt(ModsqrtContext *sc, mp_int *mx, unsigned *success);
+
+/*
+ * Modular arithmetic functions which don't use an explicit
+ * MontyContext. mp_modpow will use one internally (on the assumption
+ * that the exponent is likely to be large enough to make it
+ * worthwhile); the other three will just do ordinary non-Montgomery-
+ * optimised modular reduction. Use mp_modmul if you only have one
+ * product to compute; if you have a lot, consider using a
+ * MontyContext in the client code.
+ */
+mp_int *mp_modpow(mp_int *base, mp_int *exponent, mp_int *modulus);
+mp_int *mp_modmul(mp_int *x, mp_int *y, mp_int *modulus);
+mp_int *mp_modadd(mp_int *x, mp_int *y, mp_int *modulus);
+mp_int *mp_modsub(mp_int *x, mp_int *y, mp_int *modulus);
+
+/*
+ * Shift an mp_int right by a given number of bits. The shift count is
+ * considered to be secret data, and as a result, the algorithm takes
+ * O(n log n) time instead of the obvious O(n).
+ */
+mp_int *mp_rshift_safe(mp_int *x, size_t shift);
+
+/*
+ * Shift an mp_int left or right by a fixed number of bits. The shift
+ * count is NOT considered to be secret data! Use this if you're
+ * always dividing by 2, for example, but don't use it to shift by a
+ * variable amount derived from another secret number.
+ *
+ * The upside is that these functions run in sensible linear time.
+ */
+void mp_lshift_fixed_into(mp_int *r, mp_int *a, size_t shift);
+void mp_rshift_fixed_into(mp_int *r, mp_int *x, size_t shift);
+mp_int *mp_rshift_fixed(mp_int *x, size_t shift);
+
+/*
+ * Generate a random mp_int.
+ *
+ * The _function_ definitions here will expect to be given a gen_byte
+ * function that provides random data. Normally you'd pass
+ * random_byte() from random.c, and the macro wrappers automate that.
+ *
+ * (This is a bit of a dodge to avoid mpint.c having a link-time
+ * dependency on random.c, so that programs can link against one but
+ * not the other: if a client of this header uses one of these macros
+ * then _they_ have link-time dependencies on both modules.)
+ *
+ * mp_random_bits[_fn] returns an integer 0 <= n < 2^bits.
+ * mp_random_in_range[_fn](lo,hi) returns an integer lo <= n < hi.
+ */
+mp_int *mp_random_bits_fn(size_t bits, int (*gen_byte)(void));
+mp_int *mp_random_in_range_fn(
+    mp_int *lo_inclusive, mp_int *hi_exclusive, int (*gen_byte)(void));
+#define mp_random_bits(bits) mp_random_bits_fn(bits, random_byte)
+#define mp_random_in_range(lo, hi) mp_random_in_range_fn(lo, hi, random_byte)
+
+#endif /* PUTTY_MPINT_H */

+ 73 - 12
source/putty/sshbn.h → source/putty/mpint_i.h

@@ -1,10 +1,15 @@
 /*
- * sshbn.h: the assorted conditional definitions of BignumInt and
- * multiply macros used throughout the bignum code to treat numbers as
- * arrays of the most conveniently sized word for the target machine.
+ * mpint_i.h: definitions used internally by the bignum code, and
+ * also a few other vaguely-bignum-like places.
+ */
+
+/* ----------------------------------------------------------------------
+ * The assorted conditional definitions of BignumInt and multiply
+ * macros used throughout the bignum code to treat numbers as arrays
+ * of the most conveniently sized word for the target machine.
  * Exported so that other code (e.g. poly1305) can use it too.
  *
- * This file must export, in whatever ifdef branch it ends up in:
+ * This code must export, in whatever ifdef branch it ends up in:
  *
  *  - two types: 'BignumInt' and 'BignumCarry'. BignumInt is an
  *    unsigned integer type which will be used as the base word size
@@ -64,7 +69,7 @@
    */
 
   typedef unsigned long long BignumInt;
-  #define BIGNUM_INT_BITS 64
+  #define BIGNUM_INT_BITS_BITS 6
   #define DEFINE_BIGNUMDBLINT typedef __uint128_t BignumDblInt
 
 #elif defined _MSC_VER && defined _M_AMD64
@@ -85,7 +90,7 @@
   #include <intrin.h>
   typedef unsigned char BignumCarry; /* the type _addcarry_u64 likes to use */
   typedef unsigned __int64 BignumInt;
-  #define BIGNUM_INT_BITS 64
+  #define BIGNUM_INT_BITS_BITS 6
   #define BignumADC(ret, retc, a, b, c) do                \
       {                                                   \
           BignumInt ADC_tmp;                              \
@@ -119,7 +124,7 @@
   /* 32-bit BignumInt, using C99 unsigned long long as BignumDblInt */
 
   typedef unsigned int BignumInt;
-  #define BIGNUM_INT_BITS 32
+  #define BIGNUM_INT_BITS_BITS 5
   #define DEFINE_BIGNUMDBLINT typedef unsigned long long BignumDblInt
 
 #elif defined _MSC_VER && defined _M_IX86
@@ -127,7 +132,7 @@
   /* 32-bit BignumInt, using Visual Studio __int64 as BignumDblInt */
 
   typedef unsigned int BignumInt;
-  #define BIGNUM_INT_BITS  32
+  #define BIGNUM_INT_BITS_BITS 5
   #define DEFINE_BIGNUMDBLINT typedef unsigned __int64 BignumDblInt
 
 #elif defined _LP64
@@ -139,7 +144,7 @@
    */
 
   typedef unsigned int BignumInt;
-  #define BIGNUM_INT_BITS  32
+  #define BIGNUM_INT_BITS_BITS 5
   #define DEFINE_BIGNUMDBLINT typedef unsigned long BignumDblInt
 
 #else
@@ -155,15 +160,16 @@
    */
 
   typedef unsigned short BignumInt;
-  #define BIGNUM_INT_BITS  16
+  #define BIGNUM_INT_BITS_BITS 4
   #define DEFINE_BIGNUMDBLINT typedef unsigned long BignumDblInt
 
 #endif
 
 /*
- * Common code across all branches of that ifdef: define the three
- * easy constant macros in terms of BIGNUM_INT_BITS.
+ * Common code across all branches of that ifdef: define all the
+ * easy constant macros in terms of BIGNUM_INT_BITS_BITS.
  */
+#define BIGNUM_INT_BITS (1 << BIGNUM_INT_BITS_BITS)
 #define BIGNUM_INT_BYTES (BIGNUM_INT_BITS / 8)
 #define BIGNUM_TOP_BIT (((BignumInt)1) << (BIGNUM_INT_BITS-1))
 #define BIGNUM_INT_MASK (BIGNUM_TOP_BIT | (BIGNUM_TOP_BIT-1))
@@ -218,3 +224,58 @@
       } while (0)
 
 #endif /* DEFINE_BIGNUMDBLINT */
+
+/* ----------------------------------------------------------------------
+ * Data structures used inside bignum.c.
+ */
+
+/*
+ * A multiprecision integer: a block of BignumInt words plus its
+ * length.
+ */
+struct mp_int {
+    /* Presumably the number of BignumInt words in w, i.e. the nominal
+     * size of the integer - TODO confirm against the allocator. */
+    size_t nw;
+    /* The words themselves. The random-number code fills lower
+     * indices with lower-order bytes, so this appears to be stored
+     * least significant word first. */
+    BignumInt *w;
+};
+
+/*
+ * Precomputed state for doing Montgomery multiplication with respect
+ * to one fixed modulus m.
+ */
+struct MontyContext {
+    /*
+     * The actual modulus.
+     */
+    mp_int *m;
+
+    /*
+     * Montgomery multiplication works by selecting a value r > m,
+     * coprime to m, which is really easy to divide by. In binary
+     * arithmetic, that means making it a power of 2; in fact we make
+     * it a whole number of BignumInt.
+     *
+     * We don't store r directly as an mp_int (there's no need). But
+     * its value is 2^rbits; we also store rw = rbits/BIGNUM_INT_BITS
+     * (the corresponding word offset within an mp_int).
+     *
+     * pw is the number of words needed to store an mp_int you're
+     * doing reduction on: it has to be big enough to hold the sum of
+     * an input value up to m^2 plus an extra addend up to m*r.
+     */
+    size_t rbits, rw, pw;
+
+    /*
+     * The key step in Montgomery reduction requires the inverse of -m
+     * mod r.
+     */
+    mp_int *minus_minv_mod_r;
+
+    /*
+     * r^1, r^2 and r^3 mod m, which are used for various purposes.
+     *
+     * (Annoyingly, this is one of the rare cases where it would have
+     * been nicer to have a Pascal-style 1-indexed array. I couldn't
+     * _quite_ bring myself to put a gratuitous zero element in here.
+     * So you just have to live with getting r^k by taking the [k-1]th
+     * element of this array.)
+     */
+    mp_int *powers_of_r_mod_m[3];
+
+    /*
+     * Persistent scratch space from which monty_* functions can
+     * allocate storage for intermediate values.
+     */
+    mp_int *scratch;
+};

+ 68 - 101
source/putty/ssh.h

@@ -390,10 +390,6 @@ void ssh_user_close(Ssh *ssh, const char *fmt, ...);
 #define SSH_CIPHER_3DES		3
 #define SSH_CIPHER_BLOWFISH	6
 
-#ifndef BIGNUM_INTERNAL
-typedef void *Bignum;
-#endif
-
 typedef struct ssh_keyalg ssh_keyalg;
 typedef struct ssh_key {
     const struct ssh_keyalg *vt;
@@ -402,57 +398,52 @@ typedef struct ssh_key {
 struct RSAKey {
     int bits;
     int bytes;
-    Bignum modulus;
-    Bignum exponent;
-    Bignum private_exponent;
-    Bignum p;
-    Bignum q;
-    Bignum iqmp;
+    mp_int *modulus;
+    mp_int *exponent;
+    mp_int *private_exponent;
+    mp_int *p;
+    mp_int *q;
+    mp_int *iqmp;
     char *comment;
     ssh_key sshk;
 };
 
 struct dss_key {
-    Bignum p, q, g, y, x;
+    mp_int *p, *q, *g, *y, *x;
     ssh_key sshk;
 };
 
 struct ec_curve;
 
-struct ec_point {
-    const struct ec_curve *curve;
-    Bignum x, y;
-    Bignum z;  /* Jacobian denominator */
-    bool infinity;
-};
-
-/* A couple of ECC functions exported for use outside sshecc.c */
-struct ec_point *ecp_mul(const struct ec_point *a, const Bignum b);
-void ec_point_free(struct ec_point *point);
-
 /* Weierstrass form curve */
 struct ec_wcurve
 {
-    Bignum a, b, n;
-    struct ec_point G;
+    WeierstrassCurve *wc;
+    WeierstrassPoint *G;
+    mp_int *G_order;
 };
 
 /* Montgomery form curve */
 struct ec_mcurve
 {
-    Bignum a, b;
-    struct ec_point G;
+    MontgomeryCurve *mc;
+    MontgomeryPoint *G;
 };
 
 /* Edwards form curve */
 struct ec_ecurve
 {
-    Bignum l, d;
-    struct ec_point B;
+    EdwardsCurve *ec;
+    EdwardsPoint *G;
+    mp_int *G_order;
 };
 
+typedef enum EllipticCurveType {
+    EC_WEIERSTRASS, EC_MONTGOMERY, EC_EDWARDS
+} EllipticCurveType;
+
 struct ec_curve {
-    enum { EC_WEIERSTRASS, EC_MONTGOMERY, EC_EDWARDS } type;
+    EllipticCurveType type;
     /* 'name' is the identifier of the curve when it has to appear in
      * wire protocol encodings, as it does in e.g. the public key and
      * signature formats for NIST curves. Curves which do not format
@@ -461,8 +452,8 @@ struct ec_curve {
      * 'textname' is non-NULL for all curves, and is a human-readable
      * identification suitable for putting in log messages. */
     const char *name, *textname;
-    unsigned int fieldBits;
-    Bignum p;
+    size_t fieldBits, fieldBytes;
+    mp_int *p;
     union {
         struct ec_wcurve w;
         struct ec_mcurve m;
@@ -481,13 +472,21 @@ bool ec_ed_alg_and_curve_by_bits(int bits,
                                  const struct ec_curve **curve,
                                  const ssh_keyalg **alg);
 
-struct ec_key {
-    struct ec_point publicKey;
-    Bignum privateKey;
+struct ecdsa_key {
+    const struct ec_curve *curve;
+    WeierstrassPoint *publicKey;
+    mp_int *privateKey;
+    ssh_key sshk;
+};
+struct eddsa_key {
+    const struct ec_curve *curve;
+    EdwardsPoint *publicKey;
+    mp_int *privateKey;
     ssh_key sshk;
 };
 
-struct ec_point *ec_public(const Bignum privateKey, const struct ec_curve *curve);
+WeierstrassPoint *ecdsa_public(mp_int *private_key, const ssh_keyalg *alg);
+EdwardsPoint *eddsa_public(mp_int *private_key, const ssh_keyalg *alg);
 
 /*
  * SSH-1 never quite decided which order to store the two components
@@ -504,8 +503,9 @@ void BinarySource_get_rsa_ssh1_pub(
 void BinarySource_get_rsa_ssh1_priv(
     BinarySource *src, struct RSAKey *rsa);
 bool rsa_ssh1_encrypt(unsigned char *data, int length, struct RSAKey *key);
-Bignum rsa_ssh1_decrypt(Bignum input, struct RSAKey *key);
-bool rsa_ssh1_decrypt_pkcs1(Bignum input, struct RSAKey *key, strbuf *outbuf);
+mp_int *rsa_ssh1_decrypt(mp_int *input, struct RSAKey *key);
+bool rsa_ssh1_decrypt_pkcs1(mp_int *input, struct RSAKey *key,
+                            strbuf *outbuf);
 char *rsastr_fmt(struct RSAKey *key);
 char *rsa_ssh1_fingerprint(struct RSAKey *key);
 bool rsa_verify(struct RSAKey *key);
@@ -538,25 +538,26 @@ int ssh_rsakex_klen(struct RSAKey *key);
 void ssh_rsakex_encrypt(const struct ssh_hashalg *h,
                         unsigned char *in, int inlen,
                         unsigned char *out, int outlen, struct RSAKey *key);
-Bignum ssh_rsakex_decrypt(const struct ssh_hashalg *h, ptrlen ciphertext,
-                          struct RSAKey *rsa);
+mp_int *ssh_rsakex_decrypt(const struct ssh_hashalg *h, ptrlen ciphertext,
+                              struct RSAKey *rsa);
 
 /*
  * SSH2 ECDH key exchange functions
  */
 struct ssh_kex;
+typedef struct ecdh_key ecdh_key;
 const char *ssh_ecdhkex_curve_textname(const struct ssh_kex *kex);
-struct ec_key *ssh_ecdhkex_newkey(const struct ssh_kex *kex);
-void ssh_ecdhkex_freekey(struct ec_key *key);
-void ssh_ecdhkex_getpublic(struct ec_key *key, BinarySink *bs);
-Bignum ssh_ecdhkex_getkey(struct ec_key *key,
-                          const void *remoteKey, int remoteKeyLen);
+ecdh_key *ssh_ecdhkex_newkey(const struct ssh_kex *kex);
+void ssh_ecdhkex_freekey(ecdh_key *key);
+void ssh_ecdhkex_getpublic(ecdh_key *key, BinarySink *bs);
+mp_int *ssh_ecdhkex_getkey(ecdh_key *key, ptrlen remoteKey);
 
 /*
  * Helper function for k generation in DSA, reused in ECDSA
  */
-Bignum *dss_gen_k(const char *id_string, Bignum modulus, Bignum private_key,
-                  unsigned char *digest, int digest_len);
+mp_int *dss_gen_k(const char *id_string,
+                     mp_int *modulus, mp_int *private_key,
+                     unsigned char *digest, int digest_len);
 
 struct ssh2_cipheralg;
 typedef struct ssh2_cipher {
@@ -740,14 +741,14 @@ typedef struct ssh_hash {
     BinarySink_DELEGATE_IMPLEMENTATION;
 } ssh_hash;
 
-struct ssh_hashalg {
+typedef struct ssh_hashalg {
     ssh_hash *(*new)(const struct ssh_hashalg *alg);
     ssh_hash *(*copy)(ssh_hash *);
     void (*final)(ssh_hash *, unsigned char *); /* ALSO FREES THE ssh_hash! */
     void (*free)(ssh_hash *);
     int hlen; /* output length in bytes */
     const char *text_name;
-};   
+} ssh_hashalg;
 
 #define ssh_hash_new(alg) ((alg)->new(alg))
 #define ssh_hash_copy(ctx) ((ctx)->vt->copy(ctx))
@@ -1053,58 +1054,15 @@ void *x11_dehexify(ptrlen hex, int *outlen);
 
 Channel *agentf_new(SshChannel *c);
 
-Bignum copybn(Bignum b);
-Bignum bn_power_2(int n);
-void bn_restore_invariant(Bignum b);
-Bignum bignum_from_long(unsigned long n);
-void freebn(Bignum b);
-Bignum modpow(Bignum base, Bignum exp, Bignum mod);
-Bignum modmul(Bignum a, Bignum b, Bignum mod);
-Bignum modsub(const Bignum a, const Bignum b, const Bignum n);
-void decbn(Bignum n);
-extern Bignum Zero, One;
-Bignum bignum_from_bytes(const void *data, int nbytes);
-Bignum bignum_from_bytes_le(const void *data, int nbytes);
-Bignum bignum_random_in_range(const Bignum lower, const Bignum upper);
-int bignum_bitcount(Bignum bn);
-int bignum_byte(Bignum bn, int i);
-int bignum_bit(Bignum bn, int i);
-void bignum_set_bit(Bignum bn, int i, int value);
-Bignum biggcd(Bignum a, Bignum b);
-unsigned short bignum_mod_short(Bignum number, unsigned short modulus);
-Bignum bignum_add_long(Bignum number, unsigned long addend);
-Bignum bigadd(Bignum a, Bignum b);
-Bignum bigsub(Bignum a, Bignum b);
-Bignum bigmul(Bignum a, Bignum b);
-Bignum bigmuladd(Bignum a, Bignum b, Bignum addend);
-Bignum bigdiv(Bignum a, Bignum b);
-Bignum bigmod(Bignum a, Bignum b);
-Bignum modinv(Bignum number, Bignum modulus);
-Bignum bignum_bitmask(Bignum number);
-Bignum bignum_rshift(Bignum number, int shift);
-Bignum bignum_lshift(Bignum number, int shift);
-int bignum_cmp(Bignum a, Bignum b);
-char *bignum_decimal(Bignum x);
-Bignum bignum_from_decimal(const char *decimal);
-
-void BinarySink_put_mp_ssh1(BinarySink *, Bignum);
-void BinarySink_put_mp_ssh2(BinarySink *, Bignum);
-Bignum BinarySource_get_mp_ssh1(BinarySource *);
-Bignum BinarySource_get_mp_ssh2(BinarySource *);
-
-#ifdef DEBUG
-void diagbn(char *prefix, Bignum md);
-#endif
-
 bool dh_is_gex(const struct ssh_kex *kex);
 struct dh_ctx;
 struct dh_ctx *dh_setup_group(const struct ssh_kex *kex);
-struct dh_ctx *dh_setup_gex(Bignum pval, Bignum gval);
+struct dh_ctx *dh_setup_gex(mp_int *pval, mp_int *gval);
 int dh_modulus_bit_size(const struct dh_ctx *ctx);
 void dh_cleanup(struct dh_ctx *);
-Bignum dh_create_e(struct dh_ctx *, int nbits);
-const char *dh_validate_f(struct dh_ctx *, Bignum f);
-Bignum dh_find_K(struct dh_ctx *, Bignum f);
+mp_int *dh_create_e(struct dh_ctx *, int nbits);
+const char *dh_validate_f(struct dh_ctx *, mp_int *f);
+mp_int *dh_find_K(struct dh_ctx *, mp_int *f);
 
 bool rsa_ssh1_encrypted(const Filename *filename, char **comment);
 int rsa_ssh1_loadpub(const Filename *filename, BinarySink *bs,
@@ -1114,6 +1072,14 @@ int rsa_ssh1_loadkey(const Filename *filename, struct RSAKey *key,
 bool rsa_ssh1_savekey(const Filename *filename, struct RSAKey *key,
                       char *passphrase);
 
+static inline bool is_base64_char(char c)
+{
+    return ((c >= '0' && c <= '9') ||
+            (c >= 'a' && c <= 'z') ||
+            (c >= 'A' && c <= 'Z') ||
+            c == '+' || c == '/' || c == '=');
+}
+
 extern int base64_decode_atom(const char *atom, unsigned char *out);
 extern int base64_lines(int datalen);
 extern void base64_encode_atom(const unsigned char *data, int n, char *out);
@@ -1233,12 +1199,13 @@ int rsa_generate(struct RSAKey *key, int bits, progfn_t pfn,
 		 void *pfnparam);
 int dsa_generate(struct dss_key *key, int bits, progfn_t pfn,
 		 void *pfnparam);
-int ec_generate(struct ec_key *key, int bits, progfn_t pfn,
-                void *pfnparam);
-int ec_edgenerate(struct ec_key *key, int bits, progfn_t pfn,
-                  void *pfnparam);
-Bignum primegen(int bits, int modulus, int residue, Bignum factor,
-		int phase, progfn_t pfn, void *pfnparam, unsigned firstbits);
+int ecdsa_generate(struct ecdsa_key *key, int bits, progfn_t pfn,
+                   void *pfnparam);
+int eddsa_generate(struct eddsa_key *key, int bits, progfn_t pfn,
+                   void *pfnparam);
+mp_int *primegen(
+    int bits, int modulus, int residue, mp_int *factor,
+    int phase, progfn_t pfn, void *pfnparam, unsigned firstbits);
 void invent_firstbits(unsigned *one, unsigned *two);
 
 /*

+ 12 - 11
source/putty/ssh1login.c

@@ -7,6 +7,7 @@
 
 #include "putty.h"
 #include "ssh.h"
+#include "mpint.h"
 #include "sshbpp.h"
 #include "sshppl.h"
 #include "sshcr.h"
@@ -49,7 +50,7 @@ struct ssh1_login_state {
     int keyi, nkeys;
     bool authed;
     struct RSAKey key;
-    Bignum challenge;
+    mp_int *challenge;
     ptrlen comment;
     int dlgret;
     Filename *keyfile;
@@ -537,7 +538,7 @@ static void ssh1_login_process_queue(PacketProtocolLayer *ppl)
                     ppl_logevent("Received RSA challenge");
                     s->challenge = get_mp_ssh1(pktin);
                     if (get_err(pktin)) {
-                        freebn(s->challenge);
+                        mp_free(s->challenge);
                         ssh_proto_error(s->ppl.ssh, "Server's RSA challenge "
                                         "was badly formatted");
                         return;
@@ -549,7 +550,7 @@ static void ssh1_login_process_queue(PacketProtocolLayer *ppl)
 
                         agentreq = strbuf_new_for_agent_query();
                         put_byte(agentreq, SSH1_AGENTC_RSA_CHALLENGE);
-                        put_uint32(agentreq, bignum_bitcount(s->key.modulus));
+                        put_uint32(agentreq, mp_get_nbits(s->key.modulus));
                         put_mp_ssh1(agentreq, s->key.exponent);
                         put_mp_ssh1(agentreq, s->key.modulus);
                         put_mp_ssh1(agentreq, s->challenge);
@@ -594,9 +595,9 @@ static void ssh1_login_process_queue(PacketProtocolLayer *ppl)
                             ppl_logevent("No reply received from Pageant");
                         }
                     }
-                    freebn(s->key.exponent);
-                    freebn(s->key.modulus);
-                    freebn(s->challenge);
+                    mp_free(s->key.exponent);
+                    mp_free(s->key.modulus);
+                    mp_free(s->challenge);
                     if (s->authed)
                         break;
                 }
@@ -719,11 +720,11 @@ static void ssh1_login_process_queue(PacketProtocolLayer *ppl)
                 {
                     int i;
                     unsigned char buffer[32];
-                    Bignum challenge, response;
+                    mp_int *challenge, *response;
 
                     challenge = get_mp_ssh1(pktin);
                     if (get_err(pktin)) {
-                        freebn(challenge);
+                        mp_free(challenge);
                         ssh_proto_error(s->ppl.ssh, "Server's RSA challenge "
                                         "was badly formatted");
                         return;
@@ -732,7 +733,7 @@ static void ssh1_login_process_queue(PacketProtocolLayer *ppl)
                     freersapriv(&s->key);   /* burn the evidence */
 
                     for (i = 0; i < 32; i++) {
-                        buffer[i] = bignum_byte(response, 31 - i);
+                        buffer[i] = mp_get_byte(response, 31 - i);
                     }
 
                     {
@@ -748,8 +749,8 @@ static void ssh1_login_process_queue(PacketProtocolLayer *ppl)
                     put_data(pkt, buffer, 16);
                     pq_push(s->ppl.out_pq, pkt);
 
-                    freebn(challenge);
-                    freebn(response);
+                    mp_free(challenge);
+                    mp_free(response);
                 }
 
                 crMaybeWaitUntilV((pktin = ssh1_login_pop(s))

+ 10 - 9
source/putty/ssh2kex-client.c

@@ -11,6 +11,7 @@
 #include "sshcr.h"
 #include "storage.h"
 #include "ssh2transport.h"
+#include "mpint.h"
 
 void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted)
 {
@@ -170,10 +171,10 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted)
 
         dh_cleanup(s->dh_ctx);
         s->dh_ctx = NULL;
-        freebn(s->f); s->f = NULL;
+        mp_free(s->f); s->f = NULL;
         if (dh_is_gex(s->kex_alg)) {
-            freebn(s->g); s->g = NULL;
-            freebn(s->p); s->p = NULL;
+            mp_free(s->g); s->g = NULL;
+            mp_free(s->p); s->p = NULL;
         }
     } else if (s->kex_alg->main_type == KEXTYPE_ECDH) {
 
@@ -223,7 +224,7 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted)
         {
             ptrlen keydata = get_string(pktin);
             put_stringpl(s->exhash, keydata);
-            s->K = ssh_ecdhkex_getkey(s->ecdh_key, keydata.ptr, keydata.len);
+            s->K = ssh_ecdhkex_getkey(s->ecdh_key, keydata);
             if (!get_err(pktin) && !s->K) {
                 ssh_proto_error(s->ppl.ssh, "Received invalid elliptic curve "
                                 "point in ECDH reply");
@@ -501,10 +502,10 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted)
 
         dh_cleanup(s->dh_ctx);
         s->dh_ctx = NULL;
-        freebn(s->f); s->f = NULL;
+        mp_free(s->f); s->f = NULL;
         if (dh_is_gex(s->kex_alg)) {
-            freebn(s->g); s->g = NULL;
-            freebn(s->p); s->p = NULL;
+            mp_free(s->g); s->g = NULL;
+            mp_free(s->p); s->p = NULL;
         }
 #endif
     } else {
@@ -560,13 +561,13 @@ void ssh2kex_coroutine(struct ssh2_transport_state *s, bool *aborted)
             unsigned char *outstr;
             int outstrlen;
 
-            s->K = bn_power_2(nbits - 1);
+            s->K = mp_power_2(nbits - 1);
 
             for (i = 0; i < nbits; i++) {
                 if ((i & 7) == 0) {
                     byte = random_byte();
                 }
-                bignum_set_bit(s->K, i, (byte >> (i & 7)) & 1);
+                mp_set_bit(s->K, i, (byte >> (i & 7)) & 1);
             }
 
             /*

+ 7 - 6
source/putty/ssh2transport.c

@@ -11,6 +11,7 @@
 #include "sshcr.h"
 #include "storage.h"
 #include "ssh2transport.h"
+#include "mpint.h"
 
 const struct ssh_signkey_with_user_pref_id ssh2_hostkey_algs[] = {
     #define ARRAYENT_HOSTKEY_ALGORITHM(type, alg) { &alg, type },
@@ -200,10 +201,10 @@ static void ssh2_transport_free(PacketProtocolLayer *ppl)
         ssh_key_free(s->hkey);
         s->hkey = NULL;
     }
-    if (s->f) freebn(s->f);
-    if (s->p) freebn(s->p);
-    if (s->g) freebn(s->g);
-    if (s->K) freebn(s->K);
+    if (s->f) mp_free(s->f);
+    if (s->p) mp_free(s->p);
+    if (s->g) mp_free(s->g);
+    if (s->K) mp_free(s->K);
     if (s->dh_ctx)
         dh_cleanup(s->dh_ctx);
     if (s->rsa_kex_key)
@@ -225,7 +226,7 @@ static void ssh2_transport_free(PacketProtocolLayer *ppl)
  */
 static void ssh2_mkkey(
     struct ssh2_transport_state *s, strbuf *out,
-    Bignum K, unsigned char *H, char chr, int keylen)
+    mp_int *K, unsigned char *H, char chr, int keylen)
 {
     int hlen = s->kex_alg->hash->hlen;
     int keylen_padded;
@@ -1365,7 +1366,7 @@ static void ssh2_transport_process_queue(PacketProtocolLayer *ppl)
     /*
      * Free shared secret.
      */
-    freebn(s->K); s->K = NULL;
+    mp_free(s->K); s->K = NULL;
 
     /*
      * Update the specials menu to list the remaining uncertified host

+ 2 - 2
source/putty/ssh2transport.h

@@ -166,7 +166,7 @@ struct ssh2_transport_state {
 
     int nbits, pbits;
     bool warn_kex, warn_hk, warn_cscipher, warn_sccipher;
-    Bignum p, g, e, f, K;
+    mp_int *p, *g, *e, *f, *K;
     strbuf *outgoing_kexinit, *incoming_kexinit;
     strbuf *client_kexinit, *server_kexinit; /* aliases to the above */
     int kex_init_value, kex_reply_value;
@@ -176,7 +176,7 @@ struct ssh2_transport_state {
     char *keystr, *fingerprint;
     ssh_key *hkey;                     /* actual host key */
     struct RSAKey *rsa_kex_key;             /* for RSA kex */
-    struct ec_key *ecdh_key;              /* for ECDH kex */
+    ecdh_key *ecdh_key;                     /* for ECDH kex */
     unsigned char exchange_hash[SSH2_KEX_MAX_HASH_LEN];
     bool can_gssapi_keyex;
     bool need_gss_transient_hostkey;

+ 0 - 2180
source/putty/sshbn.c

@@ -1,2180 +0,0 @@
-/*
- * Bignum routines for RSA and DH and stuff.
- */
-
-#include <stdio.h>
-#include <assert.h>
-#include <stdlib.h>
-#include <string.h>
-#include <limits.h>
-#include <ctype.h>
-
-#include "misc.h"
-
-#include "sshbn.h"
-
-#define BIGNUM_INTERNAL
-typedef BignumInt *Bignum;
-
-#include "ssh.h"
-#include "marshal.h"
-
-BignumInt bnZero[1] = { 0 };
-BignumInt bnOne[2] = { 1, 1 };
-BignumInt bnTen[2] = { 1, 10 };
-
-/*
- * The Bignum format is an array of `BignumInt'. The first
- * element of the array counts the remaining elements. The
- * remaining elements express the actual number, base 2^BIGNUM_INT_BITS, _least_
- * significant digit first. (So it's trivial to extract the bit
- * with value 2^n for any n.)
- *
- * All Bignums in this module are positive. Negative numbers must
- * be dealt with outside it.
- *
- * INVARIANT: the most significant word of any Bignum must be
- * nonzero.
- */
-
-Bignum Zero = bnZero, One = bnOne, Ten = bnTen;
-
-static Bignum newbn(int length)
-{
-    Bignum b;
-
-    assert(length >= 0 && length < INT_MAX / BIGNUM_INT_BITS);
-
-    b = snewn(length + 1, BignumInt);
-    memset(b, 0, (length + 1) * sizeof(*b));
-    b[0] = length;
-    return b;
-}
-
-void bn_restore_invariant(Bignum b)
-{
-    while (b[0] > 1 && b[b[0]] == 0)
-	b[0]--;
-}
-
-Bignum copybn(Bignum orig)
-{
-    Bignum b = snewn(orig[0] + 1, BignumInt);
-    if (!b)
-	abort();		       /* FIXME */
-    memcpy(b, orig, (orig[0] + 1) * sizeof(*b));
-    return b;
-}
-
-void freebn(Bignum b)
-{
-    /*
-     * Burn the evidence, just in case.
-     */
-    smemclr(b, sizeof(b[0]) * (b[0] + 1));
-    sfree(b);
-}
-
-Bignum bn_power_2(int n)
-{
-    Bignum ret;
-
-    assert(n >= 0);
-
-    ret = newbn(n / BIGNUM_INT_BITS + 1);
-    bignum_set_bit(ret, n, 1);
-    return ret;
-}
-
-/*
- * Internal addition. Sets c = a + b, where 'a', 'b' and 'c' are all
- * big-endian arrays of 'len' BignumInts. Returns the carry off the
- * top.
- */
-static BignumCarry internal_add(const BignumInt *a, const BignumInt *b,
-                                BignumInt *c, int len)
-{
-    int i;
-    BignumCarry carry = 0;
-
-    for (i = len-1; i >= 0; i--)
-        BignumADC(c[i], carry, a[i], b[i], carry);
-
-    return (BignumInt)carry;
-}
-
-/*
- * Internal subtraction. Sets c = a - b, where 'a', 'b' and 'c' are
- * all big-endian arrays of 'len' BignumInts. Any borrow from the top
- * is ignored.
- */
-static void internal_sub(const BignumInt *a, const BignumInt *b,
-                         BignumInt *c, int len)
-{
-    int i;
-    BignumCarry carry = 1;
-
-    for (i = len-1; i >= 0; i--)
-        BignumADC(c[i], carry, a[i], ~b[i], carry);
-}
-
-/*
- * Compute c = a * b.
- * Input is in the first len words of a and b.
- * Result is returned in the first 2*len words of c.
- *
- * 'scratch' must point to an array of BignumInt of size at least
- * mul_compute_scratch(len). (This covers the needs of internal_mul
- * and all its recursive calls to itself.)
- */
-#define KARATSUBA_THRESHOLD 50
-static int mul_compute_scratch(int len)
-{
-    int ret = 0;
-    while (len > KARATSUBA_THRESHOLD) {
-        int toplen = len/2, botlen = len - toplen; /* botlen is the bigger */
-        int midlen = botlen + 1;
-        ret += 4*midlen;
-        len = midlen;
-    }
-    return ret;
-}
-static void internal_mul(const BignumInt *a, const BignumInt *b,
-			 BignumInt *c, int len, BignumInt *scratch)
-{
-    if (len > KARATSUBA_THRESHOLD) {
-        int i;
-
-        /*
-         * Karatsuba divide-and-conquer algorithm. Cut each input in
-         * half, so that it's expressed as two big 'digits' in a giant
-         * base D:
-         *
-         *   a = a_1 D + a_0
-         *   b = b_1 D + b_0
-         *
-         * Then the product is of course
-         *
-         *  ab = a_1 b_1 D^2 + (a_1 b_0 + a_0 b_1) D + a_0 b_0
-         *
-         * and we compute the three coefficients by recursively
-         * calling ourself to do half-length multiplications.
-         *
-         * The clever bit that makes this worth doing is that we only
-         * need _one_ half-length multiplication for the central
-         * coefficient rather than the two that it obviously looks
-         * like, because we can use a single multiplication to compute
-         *
-         *   (a_1 + a_0) (b_1 + b_0) = a_1 b_1 + a_1 b_0 + a_0 b_1 + a_0 b_0
-         *
-         * and then we subtract the other two coefficients (a_1 b_1
-         * and a_0 b_0) which we were computing anyway.
-         *
-         * Hence we get to multiply two numbers of length N in about
-         * three times as much work as it takes to multiply numbers of
-         * length N/2, which is obviously better than the four times
-         * as much work it would take if we just did a long
-         * conventional multiply.
-         */
-
-        int toplen = len/2, botlen = len - toplen; /* botlen is the bigger */
-        int midlen = botlen + 1;
-        BignumCarry carry;
-#ifdef KARA_DEBUG
-        int i;
-#endif
-
-        /*
-         * The coefficients a_1 b_1 and a_0 b_0 just avoid overlapping
-         * in the output array, so we can compute them immediately in
-         * place.
-         */
-
-#ifdef KARA_DEBUG
-        printf("a1,a0 = 0x");
-        for (i = 0; i < len; i++) {
-            if (i == toplen) printf(", 0x");
-            printf("%0*x", BIGNUM_INT_BITS/4, a[i]);
-        }
-        printf("\n");
-        printf("b1,b0 = 0x");
-        for (i = 0; i < len; i++) {
-            if (i == toplen) printf(", 0x");
-            printf("%0*x", BIGNUM_INT_BITS/4, b[i]);
-        }
-        printf("\n");
-#endif
-
-        /* a_1 b_1 */
-        internal_mul(a, b, c, toplen, scratch);
-#ifdef KARA_DEBUG
-        printf("a1b1 = 0x");
-        for (i = 0; i < 2*toplen; i++) {
-            printf("%0*x", BIGNUM_INT_BITS/4, c[i]);
-        }
-        printf("\n");
-#endif
-
-        /* a_0 b_0 */
-        internal_mul(a + toplen, b + toplen, c + 2*toplen, botlen, scratch);
-#ifdef KARA_DEBUG
-        printf("a0b0 = 0x");
-        for (i = 0; i < 2*botlen; i++) {
-            printf("%0*x", BIGNUM_INT_BITS/4, c[2*toplen+i]);
-        }
-        printf("\n");
-#endif
-
-        /* Zero padding. midlen exceeds toplen by at most 2, so just
-         * zero the first two words of each input and the rest will be
-         * copied over. */
-        scratch[0] = scratch[1] = scratch[midlen] = scratch[midlen+1] = 0;
-
-        for (i = 0; i < toplen; i++) {
-            scratch[midlen - toplen + i] = a[i]; /* a_1 */
-            scratch[2*midlen - toplen + i] = b[i]; /* b_1 */
-        }
-
-        /* compute a_1 + a_0 */
-        scratch[0] = internal_add(scratch+1, a+toplen, scratch+1, botlen);
-#ifdef KARA_DEBUG
-        printf("a1plusa0 = 0x");
-        for (i = 0; i < midlen; i++) {
-            printf("%0*x", BIGNUM_INT_BITS/4, scratch[i]);
-        }
-        printf("\n");
-#endif
-        /* compute b_1 + b_0 */
-        scratch[midlen] = internal_add(scratch+midlen+1, b+toplen,
-                                       scratch+midlen+1, botlen);
-#ifdef KARA_DEBUG
-        printf("b1plusb0 = 0x");
-        for (i = 0; i < midlen; i++) {
-            printf("%0*x", BIGNUM_INT_BITS/4, scratch[midlen+i]);
-        }
-        printf("\n");
-#endif
-
-        /*
-         * Now we can do the third multiplication.
-         */
-        internal_mul(scratch, scratch + midlen, scratch + 2*midlen, midlen,
-                     scratch + 4*midlen);
-#ifdef KARA_DEBUG
-        printf("a1plusa0timesb1plusb0 = 0x");
-        for (i = 0; i < 2*midlen; i++) {
-            printf("%0*x", BIGNUM_INT_BITS/4, scratch[2*midlen+i]);
-        }
-        printf("\n");
-#endif
-
-        /*
-         * Now we can reuse the first half of 'scratch' to compute the
-         * sum of the outer two coefficients, to subtract from that
-         * product to obtain the middle one.
-         */
-        scratch[0] = scratch[1] = scratch[2] = scratch[3] = 0;
-        for (i = 0; i < 2*toplen; i++)
-            scratch[2*midlen - 2*toplen + i] = c[i];
-        scratch[1] = internal_add(scratch+2, c + 2*toplen,
-                                  scratch+2, 2*botlen);
-#ifdef KARA_DEBUG
-        printf("a1b1plusa0b0 = 0x");
-        for (i = 0; i < 2*midlen; i++) {
-            printf("%0*x", BIGNUM_INT_BITS/4, scratch[i]);
-        }
-        printf("\n");
-#endif
-
-        internal_sub(scratch + 2*midlen, scratch,
-                     scratch + 2*midlen, 2*midlen);
-#ifdef KARA_DEBUG
-        printf("a1b0plusa0b1 = 0x");
-        for (i = 0; i < 2*midlen; i++) {
-            printf("%0*x", BIGNUM_INT_BITS/4, scratch[2*midlen+i]);
-        }
-        printf("\n");
-#endif
-
-        /*
-         * And now all we need to do is to add that middle coefficient
-         * back into the output. We may have to propagate a carry
-         * further up the output, but we can be sure it won't
-         * propagate right the way off the top.
-         */
-        carry = internal_add(c + 2*len - botlen - 2*midlen,
-                             scratch + 2*midlen,
-                             c + 2*len - botlen - 2*midlen, 2*midlen);
-        i = 2*len - botlen - 2*midlen - 1;
-        while (carry) {
-            assert(i >= 0);
-            BignumADC(c[i], carry, c[i], 0, carry);
-            i--;
-        }
-#ifdef KARA_DEBUG
-        printf("ab = 0x");
-        for (i = 0; i < 2*len; i++) {
-            printf("%0*x", BIGNUM_INT_BITS/4, c[i]);
-        }
-        printf("\n");
-#endif
-
-    } else {
-        int i;
-        BignumInt carry;
-        const BignumInt *ap, *bp;
-        BignumInt *cp, *cps;
-
-        /*
-         * Multiply in the ordinary O(N^2) way.
-         */
-
-        for (i = 0; i < 2 * len; i++)
-            c[i] = 0;
-
-        for (cps = c + 2*len, ap = a + len; ap-- > a; cps--) {
-            carry = 0;
-            for (cp = cps, bp = b + len; cp--, bp-- > b ;)
-                BignumMULADD2(carry, *cp, *ap, *bp, *cp, carry);
-            *cp = carry;
-        }
-    }
-}
-
-/*
- * Variant form of internal_mul used for the initial step of
- * Montgomery reduction. Only bothers outputting 'len' words
- * (everything above that is thrown away).
- */
-static void internal_mul_low(const BignumInt *a, const BignumInt *b,
-                             BignumInt *c, int len, BignumInt *scratch)
-{
-    if (len > KARATSUBA_THRESHOLD) {
-        int i;
-
-        /*
-         * Karatsuba-aware version of internal_mul_low. As before, we
-         * express each input value as a shifted combination of two
-         * halves:
-         *
-         *   a = a_1 D + a_0
-         *   b = b_1 D + b_0
-         *
-         * Then the full product is, as before,
-         *
-         *  ab = a_1 b_1 D^2 + (a_1 b_0 + a_0 b_1) D + a_0 b_0
-         *
-         * Provided we choose D on the large side (so that a_0 and b_0
-         * are _at least_ as long as a_1 and b_1), we don't need the
-         * topmost term at all, and we only need half of the middle
-         * term. So there's no point in doing the proper Karatsuba
-         * optimisation which computes the middle term using the top
-         * one, because we'd take as long computing the top one as
-         * just computing the middle one directly.
-         *
-         * So instead, we do a much more obvious thing: we call the
-         * fully optimised internal_mul to compute a_0 b_0, and we
-         * recursively call ourself to compute the _bottom halves_ of
-         * a_1 b_0 and a_0 b_1, each of which we add into the result
-         * in the obvious way.
-         *
-         * In other words, there's no actual Karatsuba _optimisation_
-         * in this function; the only benefit in doing it this way is
-         * that we call internal_mul proper for a large part of the
-         * work, and _that_ can optimise its operation.
-         */
-
-        int toplen = len/2, botlen = len - toplen; /* botlen is the bigger */
-
-        /*
-         * Scratch space for the various bits and pieces we're going
-         * to be adding together: we need botlen*2 words for a_0 b_0
-         * (though we may end up throwing away its topmost word), and
-         * toplen words for each of a_1 b_0 and a_0 b_1. That adds up
-         * to exactly 2*len.
-         */
-
-        /* a_0 b_0 */
-        internal_mul(a + toplen, b + toplen, scratch + 2*toplen, botlen,
-                     scratch + 2*len);
-
-        /* a_1 b_0 */
-        internal_mul_low(a, b + len - toplen, scratch + toplen, toplen,
-                         scratch + 2*len);
-
-        /* a_0 b_1 */
-        internal_mul_low(a + len - toplen, b, scratch, toplen,
-                         scratch + 2*len);
-
-        /* Copy the bottom half of the big coefficient into place */
-        for (i = 0; i < botlen; i++)
-            c[toplen + i] = scratch[2*toplen + botlen + i];
-
-        /* Add the two small coefficients, throwing away the returned carry */
-        internal_add(scratch, scratch + toplen, scratch, toplen);
-
-        /* And add that to the large coefficient, leaving the result in c. */
-        internal_add(scratch, scratch + 2*toplen + botlen - toplen,
-                     c, toplen);
-
-    } else {
-        int i;
-        BignumInt carry;
-        const BignumInt *ap, *bp;
-        BignumInt *cp, *cps;
-
-        /*
-         * Multiply in the ordinary O(N^2) way.
-         */
-
-        for (i = 0; i < len; i++)
-            c[i] = 0;
-
-        for (cps = c + len, ap = a + len; ap-- > a; cps--) {
-            carry = 0;
-            for (cp = cps, bp = b + len; bp--, cp-- > c ;)
-                BignumMULADD2(carry, *cp, *ap, *bp, *cp, carry);
-        }
-    }
-}
-
-/*
- * Montgomery reduction. Expects x to be a big-endian array of 2*len
- * BignumInts whose value satisfies 0 <= x < rn (where r = 2^(len *
- * BIGNUM_INT_BITS) is the Montgomery base). Returns in the same array
- * a value x' which is congruent to xr^{-1} mod n, and satisfies 0 <=
- * x' < n.
- *
- * 'n' and 'mninv' should be big-endian arrays of 'len' BignumInts
- * each, containing respectively n and the multiplicative inverse of
- * -n mod r.
- *
- * 'tmp' is an array of BignumInt used as scratch space, of length at
- * least 3*len + mul_compute_scratch(len).
- */
-static void monty_reduce(BignumInt *x, const BignumInt *n,
-                         const BignumInt *mninv, BignumInt *tmp, int len)
-{
-    int i;
-    BignumInt carry;
-
-    /*
-     * Multiply x by (-n)^{-1} mod r. This gives us a value m such
-     * that mn is congruent to -x mod r. Hence, mn+x is an exact
-     * multiple of r, and is also (obviously) congruent to x mod n.
-     */
-    internal_mul_low(x + len, mninv, tmp, len, tmp + 3*len);
-
-    /*
-     * Compute t = (mn+x)/r in ordinary, non-modular, integer
-     * arithmetic. By construction this is exact, and is congruent mod
-     * n to x * r^{-1}, i.e. the answer we want.
-     *
-     * The following multiply leaves that answer in the _most_
-     * significant half of the 'x' array, so then we must shift it
-     * down.
-     */
-    internal_mul(tmp, n, tmp+len, len, tmp + 3*len);
-    carry = internal_add(x, tmp+len, x, 2*len);
-    for (i = 0; i < len; i++)
-        x[len + i] = x[i], x[i] = 0;
-
-    /*
-     * Reduce t mod n. This doesn't require a full-on division by n,
-     * but merely a test and single optional subtraction, since we can
-     * show that 0 <= t < 2n.
-     *
-     * Proof:
-     *  + we computed m mod r, so 0 <= m < r.
-     *  + so 0 <= mn < rn, obviously
-     *  + hence we only need 0 <= x < rn to guarantee that 0 <= mn+x < 2rn
-     *  + yielding 0 <= (mn+x)/r < 2n as required.
-     */
-    if (!carry) {
-        for (i = 0; i < len; i++)
-            if (x[len + i] != n[i])
-                break;
-    }
-    if (carry || i >= len || x[len + i] > n[i])
-        internal_sub(x+len, n, x+len, len);
-}
-
-static void internal_add_shifted(BignumInt *number,
-				 BignumInt n, int shift)
-{
-    int word = 1 + (shift / BIGNUM_INT_BITS);
-    int bshift = shift % BIGNUM_INT_BITS;
-    BignumInt addendh, addendl;
-    BignumCarry carry;
-
-    addendl = n << bshift;
-    addendh = (bshift == 0 ? 0 : n >> (BIGNUM_INT_BITS - bshift));
-
-    assert(word <= number[0]);
-    BignumADC(number[word], carry, number[word], addendl, 0);
-    word++;
-    if (!addendh && !carry)
-        return;
-    assert(word <= number[0]);
-    BignumADC(number[word], carry, number[word], addendh, carry);
-    word++;
-    while (carry) {
-        assert(word <= number[0]);
-        BignumADC(number[word], carry, number[word], 0, carry);
-	word++;
-    }
-}
-
-static int bn_clz(BignumInt x)
-{
-    /*
-     * Count the leading zero bits in x. Equivalently, how far left
-     * would we need to shift x to make its top bit set?
-     *
-     * Precondition: x != 0.
-     */
-
-    /* FIXME: would be nice to put in some compiler intrinsics under
-     * ifdef here */
-    int i, ret = 0;
-    for (i = BIGNUM_INT_BITS / 2; i != 0; i >>= 1) {
-        if ((x >> (BIGNUM_INT_BITS-i)) == 0) {
-            x <<= i;
-            ret += i;
-        }
-    }
-    return ret;
-}
-
-static BignumInt reciprocal_word(BignumInt d)
-{
-    BignumInt dshort, recip, prodh, prodl;
-    int corrections;
-
-    /*
-     * Input: a BignumInt value d, with its top bit set.
-     */
-    assert(d >> (BIGNUM_INT_BITS-1) == 1);
-
-    /*
-     * Output: a value, shifted to fill a BignumInt, which is strictly
-     * less than 1/(d+1), i.e. is an *under*-estimate (but by as
-     * little as possible within the constraints) of the reciprocal of
-     * any number whose first BIGNUM_INT_BITS bits match d.
-     *
-     * Ideally we'd like to _totally_ fill BignumInt, i.e. always
-     * return a value with the top bit set. Unfortunately we can't
-     * quite guarantee that for all inputs and also return a fixed
-     * exponent. So instead we take our reciprocal to be
-     * 2^(BIGNUM_INT_BITS*2-1) / d, so that it has the top bit clear
-     * only in the exceptional case where d takes exactly the maximum
-     * value BIGNUM_INT_MASK; in that case, the top bit is clear and
-     * the next bit down is set.
-     */
-
-    /*
-     * Start by computing a half-length version of the answer, by
-     * straightforward division within a BignumInt.
-     */
-    dshort = (d >> (BIGNUM_INT_BITS/2)) + 1;
-    recip = (BIGNUM_TOP_BIT + dshort - 1) / dshort;
-    recip <<= BIGNUM_INT_BITS - BIGNUM_INT_BITS/2;
-
-    /*
-     * Newton-Raphson iteration to improve that starting reciprocal
-     * estimate: take f(x) = d - 1/x, and then the N-R formula gives
-     * x_new = x - f(x)/f'(x) = x - (d-1/x)/(1/x^2) = x(2-d*x). Or,
-     * taking our fixed-point representation into account, take f(x)
-     * to be d - K/x (where K = 2^(BIGNUM_INT_BITS*2-1) as discussed
-     * above) and then we get (2K - d*x) * x/K.
-     *
-     * Newton-Raphson doubles the number of correct bits at every
-     * iteration, and the initial division above already gave us half
-     * the output word, so it's only worth doing one iteration.
-     */
-    BignumMULADD(prodh, prodl, recip, d, recip);
-    prodl = ~prodl;
-    prodh = ~prodh;
-    {
-        BignumCarry c;
-        BignumADC(prodl, c, prodl, 1, 0);
-        prodh += c;
-    }
-    BignumMUL(prodh, prodl, prodh, recip);
-    recip = (prodh << 1) | (prodl >> (BIGNUM_INT_BITS-1));
-
-    /*
-     * Now make sure we have the best possible reciprocal estimate,
-     * before we return it. We might have been off by a handful either
-     * way - not enough to bother with any better-thought-out kind of
-     * correction loop.
-     */
-    BignumMULADD(prodh, prodl, recip, d, recip);
-    corrections = 0;
-    if (prodh >= BIGNUM_TOP_BIT) {
-        do {
-            BignumCarry c = 1;
-            BignumADC(prodl, c, prodl, ~d, c); prodh += BIGNUM_INT_MASK + c;
-            recip--;
-            corrections++;
-        } while (prodh >= ((BignumInt)1 << (BIGNUM_INT_BITS-1)));
-    } else {
-        while (1) {
-            BignumInt newprodh, newprodl;
-            BignumCarry c = 0;
-            BignumADC(newprodl, c, prodl, d, c); newprodh = prodh + c;
-            if (newprodh >= BIGNUM_TOP_BIT)
-                break;
-            prodh = newprodh;
-            prodl = newprodl;
-            recip++;
-            corrections++;
-        }
-    }
-
-    return recip;
-}
-
-/*
- * Compute a = a % m.
- * Input in first alen words of a and first mlen words of m.
- * Output in first alen words of a
- * (of which first alen-mlen words will be zero).
- * If 'quot' is non-NULL, the quotient is accumulated into it via
- * internal_add_shifted(); 'quot' is a Bignum rather than the internal
- * bigendian format. Pass NULL if the quotient is not wanted.
- *
- * 'recip' must be the result of calling reciprocal_word() on the top
- * BIGNUM_INT_BITS of the modulus (denoted m0 in comments below), with
- * the topmost set bit normalised to the MSB of the input to
- * reciprocal_word. 'rshift' is how far left the top nonzero word of
- * the modulus had to be shifted to set that top bit.
- */
-static void internal_mod(BignumInt *a, int alen,
-                         BignumInt *m, int mlen,
-                         BignumInt *quot, BignumInt recip, int rshift)
-{
-    int i, k;
-
-#ifdef DIVISION_DEBUG
-    {
-        int d;
-        printf("start division, m=0x");
-        for (d = 0; d < mlen; d++)
-            printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)m[d]);
-        printf(", recip=%#0*llx, rshift=%d\n",
-               BIGNUM_INT_BITS/4, (unsigned long long)recip, rshift);
-    }
-#endif
-
-    /*
-     * Repeatedly use that reciprocal estimate to get a decent number
-     * of quotient bits, and subtract off the resulting multiple of m.
-     *
-     * Normally we expect to terminate this loop by means of finding
-     * out q=0 part way through, but one way in which we might not get
-     * that far in the first place is if the input a is actually zero,
-     * in which case we'll discard zero words from the front of a
-     * until we reach the termination condition in the for statement
-     * here.
-     */
-    for (i = 0; i <= alen - mlen ;) {
-        BignumInt product;
-        BignumInt aword, q;
-        int shift, full_bitoffset, bitoffset, wordoffset;
-
-#ifdef DIVISION_DEBUG
-        {
-            int d;
-            printf("main loop, a=0x");
-            for (d = 0; d < alen; d++)
-                printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]);
-            printf("\n");
-        }
-#endif
-
-        if (a[i] == 0) {
-#ifdef DIVISION_DEBUG
-            printf("zero word at i=%d\n", i);
-#endif
-            i++;
-            continue;
-        }
-
-        /* Form a full word of a's leading bits, normalised so that
-         * its top bit is set. */
-        aword = a[i];
-        shift = bn_clz(aword);
-        aword <<= shift;
-        if (shift > 0 && i+1 < alen)
-            aword |= a[i+1] >> (BIGNUM_INT_BITS - shift);
-
-        /* q is the high word of recip * aword; the low word is not
-         * needed. */
-        {
-            BignumInt unused;
-            BignumMUL(q, unused, recip, aword);
-            (void)unused;
-        }
-
-#ifdef DIVISION_DEBUG
-        printf("i=%d, aword=%#0*llx, shift=%d, q=%#0*llx\n",
-               i, BIGNUM_INT_BITS/4, (unsigned long long)aword,
-               shift, BIGNUM_INT_BITS/4, (unsigned long long)q);
-#endif
-
-        /*
-         * Work out the right bit and word offsets to use when
-         * subtracting q*m from a.
-         *
-         * aword was taken from a[i], which means its LSB was at bit
-         * position (alen-1-i) * BIGNUM_INT_BITS. But then we shifted
-         * it left by 'shift', so now the low bit of aword corresponds
-         * to bit position (alen-1-i) * BIGNUM_INT_BITS - shift, i.e.
-         * aword is approximately equal to a / 2^(that).
-         *
-         * m0 comes from the top word of mod, so its LSB is at bit
-         * position (mlen-1) * BIGNUM_INT_BITS - rshift, i.e. it can
-         * be considered to be m / 2^(that power). 'recip' is the
-         * reciprocal of m0, times 2^(BIGNUM_INT_BITS*2-1), i.e. it's
-         * about 2^((mlen+1) * BIGNUM_INT_BITS - rshift - 1) / m.
-         *
-         * Hence, recip * aword is approximately equal to the product
-         * of those, which simplifies to
-         *
-         * a/m * 2^((mlen+2+i-alen)*BIGNUM_INT_BITS + shift - rshift - 1)
-         *
-         * But we've also shifted recip*aword down by BIGNUM_INT_BITS
-         * to form q, so we have
-         *
-         * q ~= a/m * 2^((mlen+1+i-alen)*BIGNUM_INT_BITS + shift - rshift - 1)
-         *
-         * and hence, when we now compute q*m, it will be about
-         * a*2^(all that lot), i.e. the negation of that expression is
-         * how far left we have to shift the product q*m to make it
-         * approximately equal to a.
-         */
-        full_bitoffset = -((mlen+1+i-alen)*BIGNUM_INT_BITS + shift-rshift-1);
-#ifdef DIVISION_DEBUG
-        printf("full_bitoffset=%d\n", full_bitoffset);
-#endif
-
-        if (full_bitoffset < 0) {
-            /*
-             * If we find ourselves needing to shift q*m _right_, that
-             * means we've reached the bottom of the quotient. Clip q
-             * so that its right shift becomes zero, and if that means
-             * q becomes _actually_ zero, this loop is done.
-             */
-            if (full_bitoffset <= -BIGNUM_INT_BITS)
-                break;
-            q >>= -full_bitoffset;
-            full_bitoffset = 0;
-            if (!q)
-                break;
-#ifdef DIVISION_DEBUG
-            printf("now full_bitoffset=%d, q=%#0*llx\n",
-                   full_bitoffset, BIGNUM_INT_BITS/4, (unsigned long long)q);
-#endif
-        }
-
-        wordoffset = full_bitoffset / BIGNUM_INT_BITS;
-        bitoffset = full_bitoffset % BIGNUM_INT_BITS;
-#ifdef DIVISION_DEBUG
-        printf("wordoffset=%d, bitoffset=%d\n", wordoffset, bitoffset);
-#endif
-
-        /* wordoffset as computed above is the offset between the LSWs
-         * of m and a. But in fact m and a are stored MSW-first, so we
-         * need to adjust it to be the offset between the actual array
-         * indices, and flip the sign too. */
-        wordoffset = alen - mlen - wordoffset;
-
-        /*
-         * Subtract (q*m) << full_bitoffset from a, computed as
-         * a + ~subtrahend + 1: the carry starts at 1 and each word
-         * of the subtrahend is complemented before being added.
-         */
-        if (bitoffset == 0) {
-            BignumCarry c = 1;
-            BignumInt prev_hi_word = 0;
-            for (k = mlen - 1; wordoffset+k >= i; k--) {
-                /* Once k runs off the top of m, keep going with zero
-                 * words so the carries propagate up to a[i]. */
-                BignumInt mword = k<0 ? 0 : m[k];
-                BignumMULADD(prev_hi_word, product, q, mword, prev_hi_word);
-#ifdef DIVISION_DEBUG
-                printf("  aligned sub: product word for m[%d] = %#0*llx\n",
-                       k, BIGNUM_INT_BITS/4,
-                       (unsigned long long)product);
-                printf("  aligned sub: subtrahend for a[%d] = %#0*llx\n",
-                       wordoffset+k, BIGNUM_INT_BITS/4,
-                       (unsigned long long)product);
-#endif
-                BignumADC(a[wordoffset+k], c, a[wordoffset+k], ~product, c);
-            }
-        } else {
-            /* add_word carries the bits of each product word that
-             * spill into the next word up after the left shift. */
-            BignumInt add_word = 0;
-            BignumCarry c = 1;         /* same carry type as the aligned case */
-            BignumInt prev_hi_word = 0;
-            for (k = mlen - 1; wordoffset+k >= i; k--) {
-                BignumInt mword = k<0 ? 0 : m[k];
-                BignumMULADD(prev_hi_word, product, q, mword, prev_hi_word);
-#ifdef DIVISION_DEBUG
-                printf("  unaligned sub: product word for m[%d] = %#0*llx\n",
-                       k, BIGNUM_INT_BITS/4,
-                       (unsigned long long)product);
-#endif
-
-                add_word |= product << bitoffset;
-
-#ifdef DIVISION_DEBUG
-                printf("  unaligned sub: subtrahend for a[%d] = %#0*llx\n",
-                       wordoffset+k,
-                       BIGNUM_INT_BITS/4, (unsigned long long)add_word);
-#endif
-                BignumADC(a[wordoffset+k], c, a[wordoffset+k], ~add_word, c);
-
-                add_word = product >> (BIGNUM_INT_BITS - bitoffset);
-            }
-        }
-
-        if (quot) {
-#ifdef DIVISION_DEBUG
-            printf("adding quotient word %#0*llx << %d\n",
-                   BIGNUM_INT_BITS/4, (unsigned long long)q, full_bitoffset);
-#endif
-            internal_add_shifted(quot, q, full_bitoffset);
-#ifdef DIVISION_DEBUG
-            {
-                int d;
-                printf("now quot=0x");
-                for (d = quot[0]; d > 0; d--)
-                    printf("%0*llx", BIGNUM_INT_BITS/4,
-                           (unsigned long long)quot[d]);
-                printf("\n");
-            }
-#endif
-        }
-    }
-
-#ifdef DIVISION_DEBUG
-    {
-        int d;
-        printf("end main loop, a=0x");
-        for (d = 0; d < alen; d++)
-            printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]);
-        if (quot) {
-            printf(", quot=0x");
-            for (d = quot[0]; d > 0; d--)
-                printf("%0*llx", BIGNUM_INT_BITS/4,
-                       (unsigned long long)quot[d]);
-        }
-        printf("\n");
-    }
-#endif
-
-    /*
-     * The above loop should terminate with the remaining value in a
-     * being strictly less than 2*m (if a >= 2*m then we should always
-     * have managed to get a nonzero q word), but we can't guarantee
-     * that it will be strictly less than m: consider a case where the
-     * remainder is 1, and another where the remainder is m-1. By the
-     * time a contains a value that's _about m_, you clearly can't
-     * distinguish those cases by looking at only the top word of a -
-     * you have to go all the way down to the bottom before you find
-     * out whether it's just less or just more than m.
-     *
-     * Hence, we now do a final fixup in which we subtract one last
-     * copy of m, or don't, accordingly. We should never have to
-     * subtract more than one copy of m here.
-     */
-    for (i = 0; i < alen; i++) {
-        /* Compare a with m, word by word, from the MSW down. As soon
-         * as we encounter a difference, we know whether we need the
-         * fixup. m is treated as zero-extended to alen words. */
-        int mindex = mlen-alen+i;
-        BignumInt mword = mindex < 0 ? 0 : m[mindex];
-        if (a[i] < mword) {
-#ifdef DIVISION_DEBUG
-            printf("final fixup not needed, a < m\n");
-#endif
-            return;
-        } else if (a[i] > mword) {
-#ifdef DIVISION_DEBUG
-            printf("final fixup is needed, a > m\n");
-#endif
-            break;
-        }
-        /* If neither of those cases happened, the words are the same,
-         * so keep going and look at the next one. */
-    }
-#ifdef DIVISION_DEBUG
-    /* The comparison loop runs over alen words, so it only falls off
-     * the end (having printed neither diagnostic) when i == alen. */
-    if (i == alen)
-        printf("final fixup is needed, a == m\n");
-#endif
-
-    /*
-     * If we got here without returning, then a >= m, so we must
-     * subtract m, and increment the quotient.
-     */
-    {
-        BignumCarry c = 1;
-        for (i = alen - 1; i >= 0; i--) {
-            int mindex = mlen-alen+i;
-            BignumInt mword = mindex < 0 ? 0 : m[mindex];
-            BignumADC(a[i], c, a[i], ~mword, c);
-        }
-    }
-    if (quot)
-        internal_add_shifted(quot, 1, 0);
-
-#ifdef DIVISION_DEBUG
-    {
-        int d;
-        printf("after final fixup, a=0x");
-        for (d = 0; d < alen; d++)
-            printf("%0*llx", BIGNUM_INT_BITS/4, (unsigned long long)a[d]);
-        if (quot) {
-            printf(", quot=0x");
-            for (d = quot[0]; d > 0; d--)
-                printf("%0*llx", BIGNUM_INT_BITS/4,
-                       (unsigned long long)quot[d]);
-        }
-        printf("\n");
-    }
-#endif
-}
-
-/*
- * Compute (base ^ exp) % mod, the pedestrian way.
- *
- * Walks the bits of exp from the most significant downwards, doing a
- * square for every bit and a multiply by the base for every set bit,
- * with an internal_mod() reduction after each multiplication.
- *
- * Returns a Bignum obtained from newbn(). Every temporary array is
- * wiped with smemclr() and freed before returning, as is the reduced
- * copy of the base.
- */
-Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod)
-{
-    BignumInt *a, *b, *n, *m, *scratch;
-    BignumInt recip;
-    int rshift;
-    int mlen, scratchlen, i, j;
-    Bignum base, result;
-
-    /*
-     * The most significant word of mod needs to be non-zero. It
-     * should already be, but let's make sure.
-     */
-    assert(mod[mod[0]] != 0);
-
-    /*
-     * Make sure the base is smaller than the modulus, by reducing
-     * it modulo the modulus if not.
-     */
-    base = bigmod(base_in, mod);
-
-    /* Allocate m of size mlen, copy mod to m */
-    /* We use big endian internally: m[0] is the most significant word */
-    mlen = mod[0];
-    m = snewn(mlen, BignumInt);
-    for (j = 0; j < mlen; j++)
-	m[j] = mod[mod[0] - j];
-
-    /* Allocate n of size mlen, copy base to n, zero-padded at the
-     * most significant end */
-    n = snewn(mlen, BignumInt);
-    i = mlen - base[0];
-    for (j = 0; j < i; j++)
-	n[j] = 0;
-    for (j = 0; j < (int)base[0]; j++)
-	n[i + j] = base[base[0] - j];
-
-    /* Allocate a and b of size 2*mlen. Set a = 1 (the least
-     * significant word is the last array element). */
-    a = snewn(2 * mlen, BignumInt);
-    b = snewn(2 * mlen, BignumInt);
-    for (i = 0; i < 2 * mlen; i++)
-	a[i] = 0;
-    a[2 * mlen - 1] = 1;
-
-    /* Scratch space for multiplies */
-    scratchlen = mul_compute_scratch(mlen);
-    scratch = snewn(scratchlen, BignumInt);
-
-    /* Skip leading zero bits of exp. i counts words down from the
-     * most significant one and j is the bit index within that word;
-     * both values carry over into the main loop below. */
-    i = 0;
-    j = BIGNUM_INT_BITS-1;
-    while (i < (int)exp[0] && (exp[exp[0] - i] & ((BignumInt)1 << j)) == 0) {
-	j--;
-	if (j < 0) {
-	    i++;
-	    j = BIGNUM_INT_BITS-1;
-	}
-    }
-
-    /* Compute reciprocal of the top full word of the modulus, after
-     * shifting it left by rshift so that its top bit is set */
-    {
-        BignumInt m0 = m[0];
-        rshift = bn_clz(m0);
-        if (rshift) {
-            m0 <<= rshift;
-            if (mlen > 1)
-                m0 |= m[1] >> (BIGNUM_INT_BITS - rshift);
-        }
-        recip = reciprocal_word(m0);
-    }
-
-    /* Main computation: for each remaining exponent bit, square the
-     * low half of a into b and reduce it; then, if the bit is set,
-     * multiply by the base (n) back into a and reduce again, and
-     * otherwise swap the a and b pointers. Either way the running
-     * value ends up in the low mlen words of a. */
-    while (i < (int)exp[0]) {
-	while (j >= 0) {
-	    internal_mul(a + mlen, a + mlen, b, mlen, scratch);
-	    internal_mod(b, mlen * 2, m, mlen, NULL, recip, rshift);
-	    if ((exp[exp[0] - i] & ((BignumInt)1 << j)) != 0) {
-		internal_mul(b + mlen, n, a, mlen, scratch);
-		internal_mod(a, mlen * 2, m, mlen, NULL, recip, rshift);
-	    } else {
-		BignumInt *t;
-		t = a;
-		a = b;
-		b = t;
-	    }
-	    j--;
-	}
-	i++;
-	j = BIGNUM_INT_BITS-1;
-    }
-
-    /* Copy result to buffer: take the low mlen words of a, then trim
-     * leading zero words but keep at least one word */
-    result = newbn(mod[0]);
-    for (i = 0; i < mlen; i++)
-	result[result[0] - i] = a[i + mlen];
-    while (result[0] > 1 && result[result[0]] == 0)
-	result[0]--;
-
-    /* Free temporary arrays, wiping each one first */
-    smemclr(a, 2 * mlen * sizeof(*a));
-    sfree(a);
-    smemclr(scratch, scratchlen * sizeof(*scratch));
-    sfree(scratch);
-    smemclr(b, 2 * mlen * sizeof(*b));
-    sfree(b);
-    smemclr(m, mlen * sizeof(*m));
-    sfree(m);
-    smemclr(n, mlen * sizeof(*n));
-    sfree(n);
-
-    freebn(base);
-
-    return result;
-}
-
-/*
- * Compute (base ^ exp) % mod. Uses the Montgomery multiplication
- * technique where possible, falling back to modpow_simple otherwise.
- *
- * The fallback is taken when mod is even. Otherwise the result is a
- * Bignum obtained from newbn(); all internal arrays are wiped with
- * smemclr() and freed before returning.
- */
-Bignum modpow(Bignum base_in, Bignum exp, Bignum mod)
-{
-    BignumInt *a, *b, *x, *n, *mninv, *scratch;
-    int len, scratchlen, i, j;
-    Bignum base, base2, r, rn, inv, result;
-
-    /*
-     * The most significant word of mod needs to be non-zero. It
-     * should already be, but let's make sure.
-     */
-    assert(mod[mod[0]] != 0);
-
-    /*
-     * mod had better be odd, or we can't do Montgomery multiplication
-     * using a power of two at all.
-     */
-    if (!(mod[1] & 1))
-        return modpow_simple(base_in, exp, mod);
-
-    /*
-     * Make sure the base is smaller than the modulus, by reducing
-     * it modulo the modulus if not.
-     */
-    base = bigmod(base_in, mod);
-
-    /*
-     * Compute the inverse of n mod r, for monty_reduce. (In fact we
-     * want the inverse of _minus_ n mod r, but we'll sort that out
-     * below.) Here r is 2^(BIGNUM_INT_BITS * len), where len is the
-     * word length of the modulus.
-     */
-    len = mod[0];
-    r = bn_power_2(BIGNUM_INT_BITS * len);
-    inv = modinv(mod, r);
-    assert(inv); /* cannot fail, since mod is odd and r is a power of 2 */
-
-    /*
-     * Multiply the base by r mod n, to get it into Montgomery
-     * representation.
-     */
-    base2 = modmul(base, r, mod);
-    freebn(base);
-    base = base2;
-
-    rn = bigmod(r, mod);               /* r mod n, i.e. Montgomerified 1 */
-
-    freebn(r);                         /* won't need this any more */
-
-    /*
-     * Set up internal arrays of the right lengths, in big-endian
-     * format, containing the base, the modulus, and the modulus's
-     * inverse. Each is zero-padded at the most significant end to
-     * exactly len words.
-     */
-    n = snewn(len, BignumInt);
-    for (j = 0; j < len; j++)
-	n[len - 1 - j] = mod[j + 1];
-
-    mninv = snewn(len, BignumInt);
-    for (j = 0; j < len; j++)
-	mninv[len - 1 - j] = (j < (int)inv[0] ? inv[j + 1] : 0);
-    freebn(inv);         /* we don't need this copy of it any more */
-    /* Now negate mninv mod r, so it's the inverse of -n rather than +n.
-     * x is used here as an all-zero array to subtract mninv from. */
-    x = snewn(len, BignumInt);
-    for (j = 0; j < len; j++)
-        x[j] = 0;
-    internal_sub(x, mninv, mninv, len);
-
-    /* x was allocated just above; now reuse it to hold the
-     * Montgomery-form base, most significant word first. */
-    for (j = 0; j < len; j++)
-	x[len - 1 - j] = (j < (int)base[0] ? base[j + 1] : 0);
-    freebn(base);        /* we don't need this copy of it any more */
-
-    a = snewn(2*len, BignumInt);
-    b = snewn(2*len, BignumInt);
-    for (j = 0; j < len; j++)
-	a[2*len - 1 - j] = (j < (int)rn[0] ? rn[j + 1] : 0);
-    freebn(rn);
-
-    /* Scratch space for multiplies, plus 3*len extra words which are
-     * passed on to monty_reduce via the same pointer */
-    scratchlen = 3*len + mul_compute_scratch(len);
-    scratch = snewn(scratchlen, BignumInt);
-
-    /* Skip leading zero bits of exp. i counts words down from the
-     * most significant one and j is the bit index within that word;
-     * both values carry over into the main loop below. */
-    i = 0;
-    j = BIGNUM_INT_BITS-1;
-    while (i < (int)exp[0] && (exp[exp[0] - i] & ((BignumInt)1 << j)) == 0) {
-	j--;
-	if (j < 0) {
-	    i++;
-	    j = BIGNUM_INT_BITS-1;
-	}
-    }
-
-    /* Main computation: for each remaining exponent bit, square the
-     * low half of a into b and monty_reduce it; then, if the bit is
-     * set, multiply by the base (x) back into a and reduce again, and
-     * otherwise swap the a and b pointers. */
-    while (i < (int)exp[0]) {
-	while (j >= 0) {
-	    internal_mul(a + len, a + len, b, len, scratch);
-            monty_reduce(b, n, mninv, scratch, len);
-	    if ((exp[exp[0] - i] & ((BignumInt)1 << j)) != 0) {
-                internal_mul(b + len, x, a, len,  scratch);
-                monty_reduce(a, n, mninv, scratch, len);
-	    } else {
-		BignumInt *t;
-		t = a;
-		a = b;
-		b = t;
-	    }
-	    j--;
-	}
-	i++;
-	j = BIGNUM_INT_BITS-1;
-    }
-
-    /*
-     * Final monty_reduce to get back from the adjusted Montgomery
-     * representation.
-     */
-    monty_reduce(a, n, mninv, scratch, len);
-
-    /* Copy result to buffer: take the low len words of a, then trim
-     * leading zero words but keep at least one word */
-    result = newbn(mod[0]);
-    for (i = 0; i < len; i++)
-	result[result[0] - i] = a[i + len];
-    while (result[0] > 1 && result[result[0]] == 0)
-	result[0]--;
-
-    /* Free temporary arrays, wiping each one first */
-    smemclr(scratch, scratchlen * sizeof(*scratch));
-    sfree(scratch);
-    smemclr(a, 2 * len * sizeof(*a));
-    sfree(a);
-    smemclr(b, 2 * len * sizeof(*b));
-    sfree(b);
-    smemclr(mninv, len * sizeof(*mninv));
-    sfree(mninv);
-    smemclr(n, len * sizeof(*n));
-    sfree(n);
-    smemclr(x, len * sizeof(*x));
-    sfree(x);
-
-    return result;
-}
-
-/*
- * Modular multiplication: returns (p * q) % mod as a new Bignum.
- *
- * The top word of mod is required to be non-zero (this is asserted).
- * The operands are copied into equal-length MSW-first arrays,
- * multiplied with internal_mul(), reduced with internal_mod(), and
- * the low words of the remainder are copied out and trimmed of
- * leading zero words. Every temporary is wiped before being freed.
- */
-Bignum modmul(Bignum p, Bignum q, Bignum mod)
-{
-    BignumInt *prod, *pwords, *qwords, *mwords, *scratch;
-    BignumInt recip, topword;
-    int rshift, scratchlen;
-    int oplen, modlen, outlen, pad, k;
-    Bignum result;
-
-    /* Guard against a modulus with a zero most significant word. */
-    assert(mod[mod[0]] != 0);
-
-    /* Big-endian working copy of the modulus. */
-    modlen = mod[0];
-    mwords = snewn(modlen, BignumInt);
-    for (k = 0; k < modlen; k++)
-        mwords[k] = mod[mod[0] - k];
-
-    /* Both operands are padded to the length of the longer one... */
-    if (p[0] > q[0])
-        oplen = p[0];
-    else
-        oplen = q[0];
-
-    /* ...and to more than half the modulus length, so that the
-     * product array is longer than the modulus when it is reduced. */
-    if (2*oplen <= modlen)
-        oplen = modlen/2 + 1;
-
-    /* Big-endian copy of p, zero-padded at the top. */
-    pwords = snewn(oplen, BignumInt);
-    pad = oplen - p[0];
-    for (k = 0; k < pad; k++)
-        pwords[k] = 0;
-    for (k = 0; k < (int)p[0]; k++)
-        pwords[pad + k] = p[p[0] - k];
-
-    /* Big-endian copy of q, zero-padded at the top. */
-    qwords = snewn(oplen, BignumInt);
-    pad = oplen - q[0];
-    for (k = 0; k < pad; k++)
-        qwords[k] = 0;
-    for (k = 0; k < (int)q[0]; k++)
-        qwords[pad + k] = q[q[0] - k];
-
-    /* Double-length array to receive the product. */
-    prod = snewn(2 * oplen, BignumInt);
-
-    /* Workspace for internal_mul. */
-    scratchlen = mul_compute_scratch(oplen);
-    scratch = snewn(scratchlen, BignumInt);
-
-    /* Normalise the top word of the modulus and take its reciprocal. */
-    topword = mwords[0];
-    rshift = bn_clz(topword);
-    if (rshift) {
-        topword <<= rshift;
-        if (modlen > 1)
-            topword |= mwords[1] >> (BIGNUM_INT_BITS - rshift);
-    }
-    recip = reciprocal_word(topword);
-
-    /* Multiply, then reduce in place. */
-    internal_mul(pwords, qwords, prod, oplen, scratch);
-    internal_mod(prod, oplen * 2, mwords, modlen, NULL, recip, rshift);
-
-    /* Extract the low words of the remainder and normalise. */
-    outlen = modlen;
-    if (outlen >= oplen * 2)
-        outlen = oplen * 2;
-    result = newbn(outlen);
-    for (k = 0; k < outlen; k++)
-        result[result[0] - k] = prod[k + 2 * oplen - outlen];
-    while (result[0] > 1 && result[result[0]] == 0)
-        result[0]--;
-
-    /* Wipe and release the temporaries. */
-    smemclr(scratch, scratchlen * sizeof(*scratch));
-    sfree(scratch);
-    smemclr(prod, 2 * oplen * sizeof(*prod));
-    sfree(prod);
-    smemclr(mwords, modlen * sizeof(*mwords));
-    sfree(mwords);
-    smemclr(pwords, oplen * sizeof(*pwords));
-    sfree(pwords);
-    smemclr(qwords, oplen * sizeof(*qwords));
-    sfree(qwords);
-
-    return result;
-}
-
-/*
- * Modular subtraction: returns (a - b) mod n as a new Bignum.
- *
- * Either operand that is not already below n is first reduced with
- * bigmod(); the reduced copies are freed again before returning.
- */
-Bignum modsub(const Bignum a, const Bignum b, const Bignum n)
-{
-    Bignum lhs, rhs, diff;
-
-    /* Only make a reduced copy when the operand is >= n. */
-    lhs = (bignum_cmp(a, n) >= 0) ? bigmod(a, n) : a;
-    rhs = (bignum_cmp(b, n) >= 0) ? bigmod(b, n) : b;
-
-    if (bignum_cmp(lhs, rhs) < 0) {
-        /* lhs < rhs: Bignum has no negative values, so go round the
-         * modulus by computing (n - rhs) + lhs instead. */
-        Bignum gap = bigsub(n, rhs);
-        assert(gap);
-        diff = bigadd(gap, lhs);
-        freebn(gap);
-    } else {
-        /* lhs >= rhs: a plain subtraction is enough. */
-        diff = bigsub(lhs, rhs);
-    }
-
-    /* Release whichever reduced copies were made above. */
-    if (lhs != a)
-        freebn(lhs);
-    if (rhs != b)
-        freebn(rhs);
-
-    return diff;
-}
-
-/*
- * Divide p by mod.
- * The most significant word of mod MUST be non-zero (asserted).
- * If `result' is non-NULL, its result[0] words are filled in from the
- * least significant end of the remainder (zero beyond the working
- * array); pass NULL to skip that.
- * If `quotient' is non-NULL it is handed to internal_mod, which
- * accumulates the quotient into it.
- */
-static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient)
-{
-    BignumInt *num, *den;
-    BignumInt recip, topword;
-    int rshift;
-    int numlen, denlen, k;
-
-    /* Guard against a modulus with a zero most significant word. */
-    assert(mod[mod[0]] != 0);
-
-    /* Big-endian working copy of the divisor. */
-    denlen = mod[0];
-    den = snewn(denlen, BignumInt);
-    for (k = 0; k < denlen; k++)
-        den[k] = mod[mod[0] - k];
-
-    /* The dividend array is always strictly longer than the divisor. */
-    numlen = p[0];
-    if (numlen <= denlen)
-        numlen = denlen + 1;
-
-    /* Big-endian working copy of the dividend, zero-padded at the top. */
-    num = snewn(numlen, BignumInt);
-    for (k = 0; k < numlen; k++)
-        num[k] = 0;
-    for (k = 1; k <= (int)p[0]; k++)
-        num[numlen - k] = p[k];
-
-    /* Normalise the top word of the divisor and take its reciprocal. */
-    topword = den[0];
-    rshift = bn_clz(topword);
-    if (rshift) {
-        topword <<= rshift;
-        if (denlen > 1)
-            topword |= den[1] >> (BIGNUM_INT_BITS - rshift);
-    }
-    recip = reciprocal_word(topword);
-
-    /* Reduce in place; the quotient (if wanted) is accumulated too. */
-    internal_mod(num, numlen, den, denlen, quotient, recip, rshift);
-
-    /* Hand back the remainder, least significant word first. */
-    if (result) {
-        for (k = 1; k <= (int)result[0]; k++) {
-            int src = numlen - k;
-            result[k] = src >= 0 ? num[src] : 0;
-        }
-    }
-
-    /* Wipe and release the temporaries. */
-    smemclr(den, denlen * sizeof(*den));
-    sfree(den);
-    smemclr(num, numlen * sizeof(*num));
-    sfree(num);
-}
-
-/*
- * Subtract one from a bignum, in place.
- *
- * Low-order zero words are turned into all-ones as the borrow ripples
- * upwards; the first non-zero word (or the top word, if none is
- * found earlier) is then decremented.
- */
-void decbn(Bignum bn)
-{
-    int idx;
-
-    for (idx = 1; idx < (int)bn[0] && bn[idx] == 0; idx++)
-        bn[idx] = BIGNUM_INT_MASK;
-    bn[idx]--;
-}
-
-/*
- * Build a bignum from nbytes bytes stored most significant byte
- * first. The result is passed through bn_restore_invariant() before
- * being returned.
- */
-Bignum bignum_from_bytes(const void *vdata, int nbytes)
-{
-    const unsigned char *data = (const unsigned char *)vdata;
-    Bignum result;
-    int nwords, pos;
-
-    assert(nbytes >= 0 && nbytes < INT_MAX/8);
-
-    /* Round the byte count up to a whole number of words. */
-    nwords = (nbytes + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES;
-
-    result = newbn(nwords);
-    pos = 1;
-    while (pos <= nwords)
-        result[pos++] = 0;
-
-    /* data[pos] has significance (nbytes - 1 - pos), counted in bytes
-     * from the least significant end. */
-    for (pos = 0; pos < nbytes; pos++) {
-        int sig = nbytes - 1 - pos;
-        unsigned char byte = data[pos];
-        result[1 + sig / BIGNUM_INT_BYTES] |=
-            (BignumInt)byte << (8*sig % BIGNUM_INT_BITS);
-    }
-
-    bn_restore_invariant(result);
-    return result;
-}
-
-/*
- * Build a bignum from nbytes bytes stored least significant byte
- * first. The result is passed through bn_restore_invariant() before
- * being returned.
- */
-Bignum bignum_from_bytes_le(const void *vdata, int nbytes)
-{
-    const unsigned char *data = (const unsigned char *)vdata;
-    Bignum result;
-    int nwords, pos;
-
-    assert(nbytes >= 0 && nbytes < INT_MAX/8);
-
-    /* Round the byte count up to a whole number of words. */
-    nwords = (nbytes + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES;
-
-    result = newbn(nwords);
-    pos = 1;
-    while (pos <= nwords)
-        result[pos++] = 0;
-
-    /* data[pos] goes straight into byte position pos of the number. */
-    for (pos = 0; pos < nbytes; pos++) {
-        BignumInt byte = data[pos];
-        result[1 + pos / BIGNUM_INT_BYTES] |=
-            byte << (8*pos % BIGNUM_INT_BITS);
-    }
-
-    bn_restore_invariant(result);
-    return result;
-}
-
-/*
- * Parse a string of decimal digits into a new bignum.
- *
- * Returns a null pointer (after freeing the partial result) if any
- * character is not a decimal digit. An empty string yields a copy of
- * Zero.
- */
-Bignum bignum_from_decimal(const char *decimal)
-{
-    Bignum acc = copybn(Zero);
-    const char *cp;
-
-    for (cp = decimal; *cp; cp++) {
-        Bignum scaled, digit;
-
-        if (!isdigit((unsigned char)*cp)) {
-            freebn(acc);
-            return NULL;
-        }
-
-        /* acc = acc * 10 + digit, freeing each intermediate value */
-        scaled = bigmul(acc, Ten);
-        digit = bignum_from_long(*cp - '0');
-        freebn(acc);
-        acc = bigadd(scaled, digit);
-        freebn(scaled);
-        freebn(digit);
-    }
-
-    return acc;
-}
-
-/*
- * Return a random bignum r with lower <= r <= upper, by rejection
- * sampling: draw as many random bytes as upper needs, and retry until
- * the candidate lies in range.
- *
- * The candidate is built with bignum_from_bytes(), which treats its
- * input as most significant byte first, so the surplus high bits that
- * must be cleared live in bytes[0]. (Masking the last byte instead
- * would zero low-order bits of the candidate, biasing the output, and
- * would not reduce the rejection rate at all.)
- *
- * The byte buffer is wiped before being freed; rejected candidates
- * are released with freebn().
- */
-Bignum bignum_random_in_range(const Bignum lower, const Bignum upper)
-{
-    Bignum ret = NULL;
-    unsigned char *bytes;
-    int upper_len = bignum_bitcount(upper);
-    int upper_bytes = upper_len / 8;
-    int upper_bits = upper_len % 8;
-    if (upper_bits) ++upper_bytes;
-
-    bytes = snewn(upper_bytes, unsigned char);
-    do {
-        int i;
-
-        if (ret) freebn(ret);
-
-        for (i = 0; i < upper_bytes; ++i)
-        {
-            bytes[i] = (unsigned char)random_byte();
-        }
-        /* Mask the top byte so the candidate has no more bits than
-         * upper, reducing the failure rate to 50/50. upper_bits != 0
-         * implies upper_bytes >= 1, so bytes[0] exists. */
-        if (upper_bits)
-        {
-            bytes[0] &= 0xFF >> (8 - upper_bits);
-        }
-
-        ret = bignum_from_bytes(bytes, upper_bytes);
-    } while (bignum_cmp(ret, lower) < 0 || bignum_cmp(ret, upper) > 0);
-    smemclr(bytes, upper_bytes);
-    sfree(bytes);
-
-    return ret;
-}
-
-/*
- * Return the number of significant bits in a bignum, i.e. one more
- * than the index of its highest set bit, or 0 if no bit is set.
- */
-int bignum_bitcount(Bignum bn)
-{
-    int top = bn[0] * BIGNUM_INT_BITS - 1;
-
-    /* Scan downwards from the highest bit position the storage
-     * allows until a word still has something at or above 'top'. */
-    for (; top >= 0; top--) {
-        BignumInt word = bn[top / BIGNUM_INT_BITS + 1];
-        if ((word >> (top % BIGNUM_INT_BITS)) != 0)
-            break;
-    }
-    return top + 1;
-}
-
-/*
- * Fetch one byte of a bignum; index 0 is the least significant byte.
- * Indices outside the stored words read as 0.
- */
-int bignum_byte(Bignum bn, int i)
-{
-    BignumInt word;
-
-    if (i < 0 || i >= (int)(BIGNUM_INT_BYTES * bn[0]))
-        return 0;                      /* beyond the end */
-
-    word = bn[i / BIGNUM_INT_BYTES + 1];
-    return (word >> ((i % BIGNUM_INT_BYTES)*8)) & 0xFF;
-}
-
-/*
- * Fetch one bit of a bignum; index 0 is the least significant bit.
- * Indices outside the stored words read as 0.
- */
-int bignum_bit(Bignum bn, int i)
-{
-    BignumInt word;
-
-    if (i < 0 || i >= (int)(BIGNUM_INT_BITS * bn[0]))
-        return 0;                      /* beyond the end */
-
-    word = bn[i / BIGNUM_INT_BITS + 1];
-    return (word >> (i % BIGNUM_INT_BITS)) & 1;
-}
-
-/*
- * Set or clear one bit of a bignum in place; index 0 is the least
- * significant bit. Clearing a bit outside the stored words is a
- * no-op; trying to set one aborts.
- */
-void bignum_set_bit(Bignum bn, int bitnum, int value)
-{
-    int widx;
-    BignumInt mask;
-
-    if (bitnum < 0 || bitnum >= (int)(BIGNUM_INT_BITS * bn[0])) {
-        if (value)
-            abort();                   /* beyond the end */
-        return;
-    }
-
-    widx = bitnum / BIGNUM_INT_BITS + 1;
-    mask = (BignumInt)1 << (bitnum % BIGNUM_INT_BITS);
-    if (value)
-        bn[widx] |= mask;
-    else
-        bn[widx] &= ~mask;
-}
-
-/*
- * Write a bignum in SSH-1 format: a uint16 bit count followed by the
- * ceil(bits/8) bytes of the value, most significant first.
- */
-void BinarySink_put_mp_ssh1(BinarySink *bs, Bignum bn)
-{
-    int bits = bignum_bitcount(bn);
-    int nbytes = (bits + 7) / 8;
-    int idx;
-
-    put_uint16(bs, bits);
-    for (idx = nbytes - 1; idx >= 0; idx--)
-        put_byte(bs, bignum_byte(bn, idx));
-}
-
-/*
- * Write a bignum in SSH-2 format: a uint32 byte count followed by
- * that many bytes, most significant first. The byte count is
- * (bits + 8) / 8, so there is always at least one spare bit above the
- * value's top bit.
- */
-void BinarySink_put_mp_ssh2(BinarySink *bs, Bignum bn)
-{
-    int nbytes = (bignum_bitcount(bn) + 8) / 8;
-    int idx;
-
-    put_uint32(bs, nbytes);
-    for (idx = nbytes - 1; idx >= 0; idx--)
-        put_byte(bs, bignum_byte(bn, idx));
-}
-
-/*
- * Read an SSH-1 format bignum: a uint16 bit count followed by
- * ceil(bits/8) bytes. On a read error, or if the value has more bits
- * than the prefix claims (in which case src->err is set to
- * BSE_INVALID), a zero bignum is returned instead.
- */
-Bignum BinarySource_get_mp_ssh1(BinarySource *src)
-{
-    unsigned bitc = get_uint16(src);
-    ptrlen bytes = get_data(src, (bitc + 7) / 8);
-    Bignum toret;
-
-    if (get_err(src))
-        return bignum_from_long(0);
-
-    toret = bignum_from_bytes(bytes.ptr, bytes.len);
-
-    /* SSH-1.5 spec says that it's OK for the prefix uint16 to be
-     * _greater_ than the actual number of bits, so only the opposite
-     * case is rejected. */
-    if (bignum_bitcount(toret) <= bitc)
-        return toret;
-
-    src->err = BSE_INVALID;
-    freebn(toret);
-    return bignum_from_long(0);
-}
-
-/*
- * Read an SSH-2 format bignum (a length-prefixed string of bytes,
- * most significant first). Encodings whose first byte has its top bit
- * set, or which begin with a zero byte not followed by a byte with
- * its top bit set, are rejected: src->err is set to BSE_INVALID and
- * zero is returned. A read error also yields zero.
- */
-Bignum BinarySource_get_mp_ssh2(BinarySource *src)
-{
-    ptrlen bytes = get_string(src);
-    const unsigned char *p;
-
-    if (get_err(src))
-        return bignum_from_long(0);
-
-    p = bytes.ptr;
-    if (bytes.len > 0) {
-        /* Top bit set in the first byte. */
-        int topbit = (p[0] & 0x80) != 0;
-        /* Leading zero byte that was not needed to keep the next
-         * byte's top bit clear of the first position. */
-        int padded = (p[0] == 0 && (bytes.len <= 1 || !(p[1] & 0x80)));
-
-        if (topbit || padded) {
-            src->err = BSE_INVALID;
-            return bignum_from_long(0);
-        }
-    }
-
-    return bignum_from_bytes(bytes.ptr, bytes.len);
-}
-
-/*
- * Three-way comparison of two bignums: negative, zero or positive
- * according to whether a is less than, equal to or greater than b.
- * Both inputs must have a non-zero top word, except that a single
- * zero word is accepted as a representation of zero.
- */
-int bignum_cmp(Bignum a, Bignum b)
-{
-    int alen = a[0], blen = b[0];
-    int idx;
-
-    /* Zero may be stored as one zero word as well as no words at
-     * all; fold the former into the latter. */
-    if (alen == 1 && a[alen] == 0)
-        alen = 0;
-    if (blen == 1 && b[blen] == 0)
-        blen = 0;
-
-    assert(alen == 0 || a[alen] != 0);
-    assert(blen == 0 || b[blen] != 0);
-
-    /* Compare from the most significant word down, treating the
-     * shorter number as zero-extended. */
-    for (idx = (alen > blen ? alen : blen); idx > 0; idx--) {
-        BignumInt aword = (idx > alen ? 0 : a[idx]);
-        BignumInt bword = (idx > blen ? 0 : b[idx]);
-        if (aword != bword)
-            return aword < bword ? -1 : +1;
-    }
-    return 0;
-}
-
-/*
- * Right-shift one bignum to form another.
- *
- * Returns a new bignum holding a >> shift (shift must be >= 0). If
- * shift is at least the bit count of a, the result has no words.
- */
-Bignum bignum_rshift(Bignum a, int shift)
-{
-    Bignum ret;
-    int i, shiftw, shiftb, bits;
-    BignumInt ai, ai1;
-
-    assert(shift >= 0);
-
-    /* Clamp, so that shifting everything out asks newbn() for zero
-     * words rather than a negative count. */
-    bits = bignum_bitcount(a) - shift;
-    if (bits < 0)
-        bits = 0;
-    ret = newbn((bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
-
-    if (ret) {
-        shiftw = shift / BIGNUM_INT_BITS;
-        shiftb = shift % BIGNUM_INT_BITS;
-
-        /* Bounds-check the first source word just like the rest:
-         * shiftw + 1 can lie beyond the end of a. */
-        ai1 = (shiftw + 1 <= (int)a[0] ? a[shiftw + 1] : 0);
-        for (i = 1; i <= (int)ret[0]; i++) {
-            ai = ai1;
-            ai1 = (i + shiftw + 1 <= (int)a[0] ? a[i + shiftw + 1] : 0);
-            if (shiftb == 0) {
-                /* Whole-word shift. The general formula would shift
-                 * ai1 left by the full word width, which is
-                 * undefined behaviour in C. */
-                ret[i] = ai & BIGNUM_INT_MASK;
-            } else {
-                ret[i] = ((ai >> shiftb) |
-                          (ai1 << (BIGNUM_INT_BITS - shiftb)))
-                    & BIGNUM_INT_MASK;
-            }
-        }
-    }
-
-    return ret;
-}
-
-/*
- * Left-shift one bignum to form another.
- *
- * Returns a new bignum holding a << shift (shift must be >= 0). The
- * result is sized from the bit count of a, so only the significant
- * words of a are copied: copying all a[0] words would write past the
- * end of the result whenever a has zero words at the top (including
- * a == 0 stored as a single zero word).
- */
-Bignum bignum_lshift(Bignum a, int shift)
-{
-    Bignum ret;
-    int abits, awords, bits, shiftWords, shiftBits, i;
-
-    assert(shift >= 0);
-
-    abits = bignum_bitcount(a);
-    bits = abits + shift;
-    ret = newbn((bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
-
-    shiftWords = shift / BIGNUM_INT_BITS;
-    shiftBits = shift % BIGNUM_INT_BITS;
-
-    /* Number of words of a that actually carry bits. */
-    awords = (abits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;
-
-    /* Start from all-zero words, so the low shiftWords words (and
-     * any others not written below) are well defined. */
-    for (i = 1; i <= (int)ret[0]; i++)
-        ret[i] = 0;
-
-    if (shiftBits == 0)
-    {
-        /* Whole-word shift: a straight copy, moved up shiftWords. */
-        if (awords > 0)
-            memcpy(&ret[1 + shiftWords], &a[1], sizeof(BignumInt) * awords);
-    }
-    else
-    {
-        BignumInt carry = 0;
-
-        /* Remember that Bignum[0] is length, so add 1 */
-        for (i = shiftWords + 1; i < awords + shiftWords + 1; ++i)
-        {
-            BignumInt from = a[i - shiftWords];
-            ret[i] = (from << shiftBits) | carry;
-            carry = from >> (BIGNUM_INT_BITS - shiftBits);
-        }
-        /* A non-zero carry means the shifted value spills into one
-         * more word, which the bit count above has allowed for. */
-        if (carry) ret[i] = carry;
-    }
-
-    return ret;
-}
-
-/*
- * Non-modular multiplication and addition.
- */
-Bignum bigmuladd(Bignum a, Bignum b, Bignum addend)
-{
-    int alen = a[0], blen = b[0];
-    int mlen = (alen > blen ? alen : blen);
-    int rlen, i, maxspot;
-    int wslen;
-    BignumInt *workspace;
-    Bignum ret;
-
-    /* mlen space for a, mlen space for b, 2*mlen for result,
-     * plus scratch space for multiplication */
-    wslen = mlen * 4 + mul_compute_scratch(mlen);
-    workspace = snewn(wslen, BignumInt);
-    for (i = 0; i < mlen; i++) {
-	workspace[0 * mlen + i] = (mlen - i <= (int)a[0] ? a[mlen - i] : 0);
-	workspace[1 * mlen + i] = (mlen - i <= (int)b[0] ? b[mlen - i] : 0);
-    }
-
-    internal_mul(workspace + 0 * mlen, workspace + 1 * mlen,
-		 workspace + 2 * mlen, mlen, workspace + 4 * mlen);
-
-    /* now just copy the result back */
-    rlen = alen + blen + 1;
-    if (addend && rlen <= (int)addend[0])
-	rlen = addend[0] + 1;
-    ret = newbn(rlen);
-    maxspot = 0;
-    for (i = 1; i <= (int)ret[0]; i++) {
-	ret[i] = (i <= 2 * mlen ? workspace[4 * mlen - i] : 0);
-	if (ret[i] != 0)
-	    maxspot = i;
-    }
-    ret[0] = maxspot;
-
-    /* now add in the addend, if any */
-    if (addend) {
-	BignumCarry carry = 0;
-	for (i = 1; i <= rlen; i++) {
-            BignumInt retword = (i <= (int)ret[0] ? ret[i] : 0);
-            BignumInt addword = (i <= (int)addend[0] ? addend[i] : 0);
-            BignumADC(ret[i], carry, retword, addword, carry);
-	    if (ret[i] != 0 && i > maxspot)
-		maxspot = i;
-	}
-    }
-    ret[0] = maxspot;
-
-    smemclr(workspace, wslen * sizeof(*workspace));
-    sfree(workspace);
-    return ret;
-}
-
-/*
- * Non-modular multiplication.
- */
-Bignum bigmul(Bignum a, Bignum b)
-{
-    return bigmuladd(a, b, NULL);
-}
-
-/*
- * Simple addition.
- */
-Bignum bigadd(Bignum a, Bignum b)
-{
-    int alen = a[0], blen = b[0];
-    int rlen = (alen > blen ? alen : blen) + 1;
-    int i, maxspot;
-    Bignum ret;
-    BignumCarry carry;
-
-    ret = newbn(rlen);
-
-    carry = 0;
-    maxspot = 0;
-    for (i = 1; i <= rlen; i++) {
-        BignumInt aword = (i <= (int)a[0] ? a[i] : 0);
-        BignumInt bword = (i <= (int)b[0] ? b[i] : 0);
-        BignumADC(ret[i], carry, aword, bword, carry);
-        if (ret[i] != 0 && i > maxspot)
-            maxspot = i;
-    }
-    ret[0] = maxspot;
-
-    return ret;
-}
-
-/*
- * Subtraction. Returns a-b, or NULL if the result would come out
- * negative (recall that this entire bignum module only handles
- * positive numbers).
- */
-Bignum bigsub(Bignum a, Bignum b)
-{
-    int alen = a[0], blen = b[0];
-    int rlen = (alen > blen ? alen : blen);
-    int i, maxspot;
-    Bignum ret;
-    BignumCarry carry;
-
-    ret = newbn(rlen);
-
-    carry = 1;
-    maxspot = 0;
-    for (i = 1; i <= rlen; i++) {
-        BignumInt aword = (i <= (int)a[0] ? a[i] : 0);
-        BignumInt bword = (i <= (int)b[0] ? b[i] : 0);
-        BignumADC(ret[i], carry, aword, ~bword, carry);
-        if (ret[i] != 0 && i > maxspot)
-            maxspot = i;
-    }
-    ret[0] = maxspot;
-
-    if (!carry) {
-        freebn(ret);
-        return NULL;
-    }
-
-    return ret;
-}
-
-/*
- * Create a bignum which is the bitmask covering another one. That
- * is, the smallest integer which is >= N and is also one less than
- * a power of two.
- */
-Bignum bignum_bitmask(Bignum n)
-{
-    Bignum ret = copybn(n);
-    int i;
-    BignumInt j;
-
-    i = ret[0];
-    while (n[i] == 0 && i > 0)
-	i--;
-    if (i <= 0)
-	return ret;		       /* input was zero */
-    j = 1;
-    while (j < n[i])
-	j = 2 * j + 1;
-    ret[i] = j;
-    while (--i > 0)
-	ret[i] = BIGNUM_INT_MASK;
-    return ret;
-}
-
-/*
- * Convert an unsigned long into a bignum.
- */
-Bignum bignum_from_long(unsigned long n)
-{
-    const int maxwords =
-        (sizeof(unsigned long) + sizeof(BignumInt) - 1) / sizeof(BignumInt);
-    Bignum ret;
-    int i;
-
-    ret = newbn(maxwords);
-    ret[0] = 0;
-    for (i = 0; i < maxwords; i++) {
-        ret[i+1] = n >> (i * BIGNUM_INT_BITS);
-        if (ret[i+1] != 0)
-            ret[0] = i+1;
-    }
-
-    return ret;
-}
-
-/*
- * Add a long to a bignum.
- */
-Bignum bignum_add_long(Bignum number, unsigned long n)
-{
-    const int maxwords =
-        (sizeof(unsigned long) + sizeof(BignumInt) - 1) / sizeof(BignumInt);
-    Bignum ret;
-    int words, i;
-    BignumCarry carry;
-
-    words = number[0];
-    if (words < maxwords)
-        words = maxwords;
-    words++;
-    ret = newbn(words);
-
-    carry = 0;
-    ret[0] = 0;
-    for (i = 0; i < words; i++) {
-        BignumInt nword = (i < maxwords ? n >> (i * BIGNUM_INT_BITS) : 0);
-        BignumInt numword = (i < number[0] ? number[i+1] : 0);
-        BignumADC(ret[i+1], carry, numword, nword, carry);
-	if (ret[i+1] != 0)
-            ret[0] = i+1;
-    }
-    return ret;
-}
-
-/*
- * Compute the residue of a bignum, modulo a (max 16-bit) short.
- */
-unsigned short bignum_mod_short(Bignum number, unsigned short modulus)
-{
-    unsigned long mod = modulus, r = 0;
-    /* Precompute (BIGNUM_INT_MASK+1) % mod */
-    unsigned long base_r = (BIGNUM_INT_MASK - modulus + 1) % mod;
-    int i;
-
-    for (i = number[0]; i > 0; i--) {
-        /*
-         * Conceptually, ((r << BIGNUM_INT_BITS) + number[i]) % mod
-         */
-        r = ((r * base_r) + (number[i] % mod)) % mod;
-    }
-    return (unsigned short) r;
-}
-
-#ifdef DEBUG
-void diagbn(char *prefix, Bignum md)
-{
-    int i, nibbles, morenibbles;
-    static const char hex[] = "0123456789ABCDEF";
-
-    debug("%s0x", prefix ? prefix : "");
-
-    nibbles = (3 + bignum_bitcount(md)) / 4;
-    if (nibbles < 1)
-	nibbles = 1;
-    morenibbles = 4 * md[0] - nibbles;
-    for (i = 0; i < morenibbles; i++)
-	debug("-");
-    for (i = nibbles; i--;)
-        debug("%c", hex[(bignum_byte(md, i / 2) >> (4 * (i % 2))) & 0xF]);
-
-    if (prefix)
-	debug("\n");
-}
-#endif
-
-/*
- * Simple division.
- */
-Bignum bigdiv(Bignum a, Bignum b)
-{
-    Bignum q = newbn(a[0]);
-    bigdivmod(a, b, NULL, q);
-    while (q[0] > 1 && q[q[0]] == 0)
-        q[0]--;
-    return q;
-}
-
-/*
- * Simple remainder.
- */
-Bignum bigmod(Bignum a, Bignum b)
-{
-    Bignum r = newbn(b[0]);
-    bigdivmod(a, b, r, NULL);
-    while (r[0] > 1 && r[r[0]] == 0)
-        r[0]--;
-    return r;
-}
-
-/*
- * Greatest common divisor.
- */
-Bignum biggcd(Bignum av, Bignum bv)
-{
-    Bignum a = copybn(av);
-    Bignum b = copybn(bv);
-
-    while (bignum_cmp(b, Zero) != 0) {
-	Bignum t = newbn(b[0]);
-	bigdivmod(a, b, t, NULL);
-	while (t[0] > 1 && t[t[0]] == 0)
-	    t[0]--;
-	freebn(a);
-	a = b;
-	b = t;
-    }
-
-    freebn(b);
-    return a;
-}
-
-/*
- * Modular inverse, using Euclid's extended algorithm.
- */
-Bignum modinv(Bignum number, Bignum modulus)
-{
-    Bignum a = copybn(modulus);
-    Bignum b = copybn(number);
-    Bignum xp = copybn(Zero);
-    Bignum x = copybn(One);
-    int sign = +1;
-
-    assert(number[number[0]] != 0);
-    assert(modulus[modulus[0]] != 0);
-
-    while (bignum_cmp(b, One) != 0) {
-	Bignum t, q;
-
-        if (bignum_cmp(b, Zero) == 0) {
-            /*
-             * Found a common factor between the inputs, so we cannot
-             * return a modular inverse at all.
-             */
-            freebn(b);
-            freebn(a);
-            freebn(xp);
-            freebn(x);
-            return NULL;
-        }
-
-        t = newbn(b[0]);
-	q = newbn(a[0]);
-	bigdivmod(a, b, t, q);
-	while (t[0] > 1 && t[t[0]] == 0)
-	    t[0]--;
-	while (q[0] > 1 && q[q[0]] == 0)
-	    q[0]--;
-	freebn(a);
-	a = b;
-	b = t;
-	t = xp;
-	xp = x;
-	x = bigmuladd(q, xp, t);
-	sign = -sign;
-	freebn(t);
-	freebn(q);
-    }
-
-    freebn(b);
-    freebn(a);
-    freebn(xp);
-
-    /* now we know that sign * x == 1, and that x < modulus */
-    if (sign < 0) {
-	/* set a new x to be modulus - x */
-	Bignum newx = newbn(modulus[0]);
-	BignumInt carry = 0;
-	int maxspot = 1;
-	int i;
-
-	for (i = 1; i <= (int)newx[0]; i++) {
-	    BignumInt aword = (i <= (int)modulus[0] ? modulus[i] : 0);
-	    BignumInt bword = (i <= (int)x[0] ? x[i] : 0);
-	    newx[i] = aword - bword - carry;
-	    bword = ~bword;
-	    carry = carry ? (newx[i] >= bword) : (newx[i] > bword);
-	    if (newx[i] != 0)
-		maxspot = i;
-	}
-	newx[0] = maxspot;
-	freebn(x);
-	x = newx;
-    }
-
-    /* and return. */
-    return x;
-}
-
-/*
- * Render a bignum into decimal. Return a malloced string holding
- * the decimal representation.
- */
-char *bignum_decimal(Bignum x)
-{
-    int ndigits, ndigit;
-    int i;
-    bool iszero;
-    BignumInt carry;
-    char *ret;
-    BignumInt *workspace;
-
-    /*
-     * First, estimate the number of digits. Since log(10)/log(2)
-     * is just greater than 93/28 (the joys of continued fraction
-     * approximations...) we know that for every 93 bits, we need
-     * at most 28 digits. This will tell us how much to malloc.
-     *
-     * Formally: if x has i bits, that means x is strictly less
-     * than 2^i. Since 2 is less than 10^(28/93), this is less than
-     * 10^(28i/93). We need an integer power of ten, so we must
-     * round up (rounding down might make it less than x again).
-     * Therefore if we multiply the bit count by 28/93, rounding
-     * up, we will have enough digits.
-     *
-     * i=0 (i.e., x=0) is an irritating special case.
-     */
-    i = bignum_bitcount(x);
-    if (!i)
-	ndigits = 1;		       /* x = 0 */
-    else
-	ndigits = (28 * i + 92) / 93;  /* multiply by 28/93 and round up */
-    ndigits++;			       /* allow for trailing \0 */
-    ret = snewn(ndigits, char);
-
-    /*
-     * Now allocate some workspace to hold the binary form as we
-     * repeatedly divide it by ten. Initialise this to the
-     * big-endian form of the number.
-     */
-    workspace = snewn(x[0], BignumInt);
-    for (i = 0; i < (int)x[0]; i++)
-	workspace[i] = x[x[0] - i];
-
-    /*
-     * Next, write the decimal number starting with the last digit.
-     * We use ordinary short division, dividing 10 into the
-     * workspace.
-     */
-    ndigit = ndigits - 1;
-    ret[ndigit] = '\0';
-    do {
-	iszero = true;
-	carry = 0;
-	for (i = 0; i < (int)x[0]; i++) {
-            /*
-             * Conceptually, we want to compute
-             *
-             *   (carry << BIGNUM_INT_BITS) + workspace[i]
-             *   -----------------------------------------
-             *                      10
-             *
-             * but we don't have an integer type longer than BignumInt
-             * to work with. So we have to do it in pieces.
-             */
-
-            BignumInt q, r;
-            q = workspace[i] / 10;
-            r = workspace[i] % 10;
-
-            /* I want (BIGNUM_INT_MASK+1)/10 but can't say so directly! */
-            q += carry * ((BIGNUM_INT_MASK-9) / 10 + 1);
-            r += carry * ((BIGNUM_INT_MASK-9) % 10);
-
-            q += r / 10;
-            r %= 10;
-
-	    workspace[i] = q;
-	    carry = r;
-
-	    if (workspace[i])
-		iszero = false;
-	}
-	ret[--ndigit] = (char) (carry + '0');
-    } while (!iszero);
-
-    /*
-     * There's a chance we've fallen short of the start of the
-     * string. Correct if so.
-     */
-    if (ndigit > 0)
-	memmove(ret, ret + ndigit, ndigits - ndigit);
-
-    /*
-     * Done.
-     */
-    smemclr(workspace, x[0] * sizeof(*workspace));
-    sfree(workspace);
-    return ret;
-}

+ 1 - 1
source/putty/sshccp.c

@@ -30,7 +30,7 @@
  */
 
 #include "ssh.h"
-#include "sshbn.h"
+#include "mpint_i.h"
 
 #ifndef INLINE
 #define INLINE

+ 5 - 5
source/putty/sshcommon.c

@@ -7,6 +7,7 @@
 #include <stdlib.h>
 
 #include "putty.h"
+#include "mpint.h"
 #include "ssh.h"
 #include "sshbpp.h"
 #include "sshppl.h"
@@ -1008,13 +1009,12 @@ void ssh1_compute_session_id(
     struct RSAKey *hostkey, struct RSAKey *servkey)
 {
     struct MD5Context md5c;
-    int i;
 
     MD5Init(&md5c);
-    for (i = (bignum_bitcount(hostkey->modulus) + 7) / 8; i-- ;)
-        put_byte(&md5c, bignum_byte(hostkey->modulus, i));
-    for (i = (bignum_bitcount(servkey->modulus) + 7) / 8; i-- ;)
-        put_byte(&md5c, bignum_byte(servkey->modulus, i));
+    for (size_t i = (mp_get_nbits(hostkey->modulus) + 7) / 8; i-- ;)
+        put_byte(&md5c, mp_get_byte(hostkey->modulus, i));
+    for (size_t i = (mp_get_nbits(servkey->modulus) + 7) / 8; i-- ;)
+        put_byte(&md5c, mp_get_byte(servkey->modulus, i));
     put_data(&md5c, cookie, 8);
     MD5Final(session_id, &md5c);
 }

+ 71 - 122
source/putty/sshdh.c

@@ -2,61 +2,35 @@
  * Diffie-Hellman implementation for PuTTY.
  */
 
+#include <assert.h>
+
 #include "ssh.h"
+#include "misc.h"
+#include "mpint.h"
 
-/*
- * The primes used in the group1 and group14 key exchange.
- */
-static const unsigned char P1[] = {
-    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xC9, 0x0F, 0xDA, 0xA2,
-    0x21, 0x68, 0xC2, 0x34, 0xC4, 0xC6, 0x62, 0x8B, 0x80, 0xDC, 0x1C, 0xD1,
-    0x29, 0x02, 0x4E, 0x08, 0x8A, 0x67, 0xCC, 0x74, 0x02, 0x0B, 0xBE, 0xA6,
-    0x3B, 0x13, 0x9B, 0x22, 0x51, 0x4A, 0x08, 0x79, 0x8E, 0x34, 0x04, 0xDD,
-    0xEF, 0x95, 0x19, 0xB3, 0xCD, 0x3A, 0x43, 0x1B, 0x30, 0x2B, 0x0A, 0x6D,
-    0xF2, 0x5F, 0x14, 0x37, 0x4F, 0xE1, 0x35, 0x6D, 0x6D, 0x51, 0xC2, 0x45,
-    0xE4, 0x85, 0xB5, 0x76, 0x62, 0x5E, 0x7E, 0xC6, 0xF4, 0x4C, 0x42, 0xE9,
-    0xA6, 0x37, 0xED, 0x6B, 0x0B, 0xFF, 0x5C, 0xB6, 0xF4, 0x06, 0xB7, 0xED,
-    0xEE, 0x38, 0x6B, 0xFB, 0x5A, 0x89, 0x9F, 0xA5, 0xAE, 0x9F, 0x24, 0x11,
-    0x7C, 0x4B, 0x1F, 0xE6, 0x49, 0x28, 0x66, 0x51, 0xEC, 0xE6, 0x53, 0x81,
-    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
-};
-static const unsigned char P14[] = {
-    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xC9, 0x0F, 0xDA, 0xA2,
-    0x21, 0x68, 0xC2, 0x34, 0xC4, 0xC6, 0x62, 0x8B, 0x80, 0xDC, 0x1C, 0xD1,
-    0x29, 0x02, 0x4E, 0x08, 0x8A, 0x67, 0xCC, 0x74, 0x02, 0x0B, 0xBE, 0xA6,
-    0x3B, 0x13, 0x9B, 0x22, 0x51, 0x4A, 0x08, 0x79, 0x8E, 0x34, 0x04, 0xDD,
-    0xEF, 0x95, 0x19, 0xB3, 0xCD, 0x3A, 0x43, 0x1B, 0x30, 0x2B, 0x0A, 0x6D,
-    0xF2, 0x5F, 0x14, 0x37, 0x4F, 0xE1, 0x35, 0x6D, 0x6D, 0x51, 0xC2, 0x45,
-    0xE4, 0x85, 0xB5, 0x76, 0x62, 0x5E, 0x7E, 0xC6, 0xF4, 0x4C, 0x42, 0xE9,
-    0xA6, 0x37, 0xED, 0x6B, 0x0B, 0xFF, 0x5C, 0xB6, 0xF4, 0x06, 0xB7, 0xED,
-    0xEE, 0x38, 0x6B, 0xFB, 0x5A, 0x89, 0x9F, 0xA5, 0xAE, 0x9F, 0x24, 0x11,
-    0x7C, 0x4B, 0x1F, 0xE6, 0x49, 0x28, 0x66, 0x51, 0xEC, 0xE4, 0x5B, 0x3D,
-    0xC2, 0x00, 0x7C, 0xB8, 0xA1, 0x63, 0xBF, 0x05, 0x98, 0xDA, 0x48, 0x36,
-    0x1C, 0x55, 0xD3, 0x9A, 0x69, 0x16, 0x3F, 0xA8, 0xFD, 0x24, 0xCF, 0x5F,
-    0x83, 0x65, 0x5D, 0x23, 0xDC, 0xA3, 0xAD, 0x96, 0x1C, 0x62, 0xF3, 0x56,
-    0x20, 0x85, 0x52, 0xBB, 0x9E, 0xD5, 0x29, 0x07, 0x70, 0x96, 0x96, 0x6D,
-    0x67, 0x0C, 0x35, 0x4E, 0x4A, 0xBC, 0x98, 0x04, 0xF1, 0x74, 0x6C, 0x08,
-    0xCA, 0x18, 0x21, 0x7C, 0x32, 0x90, 0x5E, 0x46, 0x2E, 0x36, 0xCE, 0x3B,
-    0xE3, 0x9E, 0x77, 0x2C, 0x18, 0x0E, 0x86, 0x03, 0x9B, 0x27, 0x83, 0xA2,
-    0xEC, 0x07, 0xA2, 0x8F, 0xB5, 0xC5, 0x5D, 0xF0, 0x6F, 0x4C, 0x52, 0xC9,
-    0xDE, 0x2B, 0xCB, 0xF6, 0x95, 0x58, 0x17, 0x18, 0x39, 0x95, 0x49, 0x7C,
-    0xEA, 0x95, 0x6A, 0xE5, 0x15, 0xD2, 0x26, 0x18, 0x98, 0xFA, 0x05, 0x10,
-    0x15, 0x72, 0x8E, 0x5A, 0x8A, 0xAC, 0xAA, 0x68, 0xFF, 0xFF, 0xFF, 0xFF,
-    0xFF, 0xFF, 0xFF, 0xFF
+struct dh_ctx {
+    mp_int *x, *e, *p, *q, *g;
 };
 
-/*
- * The generator g = 2 (used for both group1 and group14).
- */
-static const unsigned char G[] = { 2 };
-
 struct dh_extra {
-    const unsigned char *pdata, *gdata; /* NULL means group exchange */
-    int plen, glen;
+    bool gex;
+    void (*construct)(struct dh_ctx *ctx);
 };
 
+static void dh_group1_construct(struct dh_ctx *ctx)
+{
+    ctx->p = MP_LITERAL(0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF);
+    ctx->g = mp_from_integer(2);
+}
+
+static void dh_group14_construct(struct dh_ctx *ctx)
+{
+    ctx->p = MP_LITERAL(0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF);
+    ctx->g = mp_from_integer(2);
+}
+
 static const struct dh_extra extra_group1 = {
-    P1, G, lenof(P1), lenof(G),
+    false, dh_group1_construct,
 };
 
 static const struct ssh_kex ssh_diffiehellman_group1_sha1 = {
@@ -74,7 +48,7 @@ const struct ssh_kexes ssh_diffiehellman_group1 = {
 };
 
 static const struct dh_extra extra_group14 = {
-    P14, G, lenof(P14), lenof(G),
+    false, dh_group14_construct,
 };
 
 static const struct ssh_kex ssh_diffiehellman_group14_sha256 = {
@@ -97,9 +71,7 @@ const struct ssh_kexes ssh_diffiehellman_group14 = {
     group14_list
 };
 
-static const struct dh_extra extra_gex = {
-    NULL, NULL, 0, 0,
-};
+static const struct dh_extra extra_gex = { true };
 
 static const struct ssh_kex ssh_diffiehellman_gex_sha256 = {
     "diffie-hellman-group-exchange-sha256", NULL,
@@ -161,27 +133,19 @@ const struct ssh_kexes ssh_gssk5_sha1_kex = {
     gssk5_sha1_kex_list
 };
 
-/*
- * Variables.
- */
-struct dh_ctx {
-    Bignum x, e, p, q, qmask, g;
-};
-
 /*
  * Common DH initialisation.
  */
 static void dh_init(struct dh_ctx *ctx)
 {
-    ctx->q = bignum_rshift(ctx->p, 1);
-    ctx->qmask = bignum_bitmask(ctx->q);
+    ctx->q = mp_rshift_fixed(ctx->p, 1);
     ctx->x = ctx->e = NULL;
 }
 
 bool dh_is_gex(const struct ssh_kex *kex)
 {
     const struct dh_extra *extra = (const struct dh_extra *)kex->extra;
-    return extra->pdata == NULL;
+    return extra->gex;
 }
 
 /*
@@ -190,9 +154,9 @@ bool dh_is_gex(const struct ssh_kex *kex)
 struct dh_ctx *dh_setup_group(const struct ssh_kex *kex)
 {
     const struct dh_extra *extra = (const struct dh_extra *)kex->extra;
+    assert(!extra->gex);
     struct dh_ctx *ctx = snew(struct dh_ctx);
-    ctx->p = bignum_from_bytes(extra->pdata, extra->plen);
-    ctx->g = bignum_from_bytes(extra->gdata, extra->glen);
+    extra->construct(ctx);
     dh_init(ctx);
     return ctx;
 }
@@ -200,11 +164,11 @@ struct dh_ctx *dh_setup_group(const struct ssh_kex *kex)
 /*
  * Initialise DH for a server-supplied group.
  */
-struct dh_ctx *dh_setup_gex(Bignum pval, Bignum gval)
+struct dh_ctx *dh_setup_gex(mp_int *pval, mp_int *gval)
 {
     struct dh_ctx *ctx = snew(struct dh_ctx);
-    ctx->p = copybn(pval);
-    ctx->g = copybn(gval);
+    ctx->p = mp_copy(pval);
+    ctx->g = mp_copy(gval);
     dh_init(ctx);
     return ctx;
 }
@@ -214,7 +178,7 @@ struct dh_ctx *dh_setup_gex(Bignum pval, Bignum gval)
  */
 int dh_modulus_bit_size(const struct dh_ctx *ctx)
 {
-    return bignum_bitcount(ctx->p);
+    return mp_get_nbits(ctx->p);
 }
 
 /*
@@ -222,12 +186,11 @@ int dh_modulus_bit_size(const struct dh_ctx *ctx)
  */
 void dh_cleanup(struct dh_ctx *ctx)
 {
-    freebn(ctx->x);
-    freebn(ctx->e);
-    freebn(ctx->p);
-    freebn(ctx->g);
-    freebn(ctx->q);
-    freebn(ctx->qmask);
+    mp_free(ctx->x);
+    mp_free(ctx->e);
+    mp_free(ctx->p);
+    mp_free(ctx->g);
+    mp_free(ctx->q);
     sfree(ctx);
 }
 
@@ -246,49 +209,36 @@ void dh_cleanup(struct dh_ctx *ctx)
  * Advances in Cryptology: Proceedings of Eurocrypt '96
  * Springer-Verlag, May 1996.
  */
-Bignum dh_create_e(struct dh_ctx *ctx, int nbits)
+mp_int *dh_create_e(struct dh_ctx *ctx, int nbits)
 {
-    int i;
-
-    int nbytes;
-    unsigned char *buf;
-
-    nbytes = (bignum_bitcount(ctx->qmask) + 7) / 8;
-    buf = snewn(nbytes, unsigned char);
-
-    do {
-	/*
-	 * Create a potential x, by ANDing a string of random bytes
-	 * with qmask.
-	 */
-	if (ctx->x)
-	    freebn(ctx->x);
-	if (nbits == 0 || nbits > bignum_bitcount(ctx->qmask)) {
-	    for (i = 0; i < nbytes; i++)
-		buf[i] = bignum_byte(ctx->qmask, i) & random_byte();
-	    ctx->x = bignum_from_bytes(buf, nbytes);
-	} else {
-	    int b, nb;
-	    ctx->x = bn_power_2(nbits);
-	    b = nb = 0;
-	    for (i = 0; i < nbits; i++) {
-		if (nb == 0) {
-		    nb = 8;
-		    b = random_byte();
-		}
-		bignum_set_bit(ctx->x, i, b & 1);
-		b >>= 1;
-		nb--;
-	    }
-	}
-    } while (bignum_cmp(ctx->x, One) <= 0 || bignum_cmp(ctx->x, ctx->q) >= 0);
-
-    sfree(buf);
+    /*
+     * Lower limit is just 2.
+     */
+    mp_int *lo = mp_from_integer(2);
+
+    /*
+     * Upper limit.
+     */
+    mp_int *hi = mp_copy(ctx->q);
+    mp_sub_integer_into(hi, hi, 1);
+    if (nbits) {
+        mp_int *pow2 = mp_power_2(nbits+1);
+        mp_min_into(pow2, pow2, hi);
+        mp_free(hi);
+        hi = pow2;
+    }
+
+    /*
+     * Make a random number in that range.
+     */
+    ctx->x = mp_random_in_range(lo, hi);
+    mp_free(lo);
+    mp_free(hi);
 
     /*
-     * Done. Now compute e = g^x mod p.
+     * Now compute e = g^x mod p.
      */
-    ctx->e = modpow(ctx->g, ctx->x, ctx->p);
+    ctx->e = mp_modpow(ctx->g, ctx->x, ctx->p);
 
     return ctx->e;
 }
@@ -301,15 +251,16 @@ Bignum dh_create_e(struct dh_ctx *ctx, int nbits)
  * they lead to obviously weak keys that even a passive eavesdropper
  * can figure out.)
  */
-const char *dh_validate_f(struct dh_ctx *ctx, Bignum f)
+const char *dh_validate_f(struct dh_ctx *ctx, mp_int *f)
 {
-    if (bignum_cmp(f, One) <= 0) {
+    if (!mp_hs_integer(f, 2)) {
         return "f value received is too small";
     } else {
-        Bignum pm1 = bigsub(ctx->p, One);
-        int cmp = bignum_cmp(f, pm1);
-        freebn(pm1);
-        if (cmp >= 0)
+        mp_int *pm1 = mp_copy(ctx->p);
+        mp_sub_integer_into(pm1, pm1, 1);
+        unsigned cmp = mp_cmp_hs(f, pm1);
+        mp_free(pm1);
+        if (cmp)
             return "f value received is too large";
     }
     return NULL;
@@ -318,9 +269,7 @@ const char *dh_validate_f(struct dh_ctx *ctx, Bignum f)
 /*
  * DH stage 2: given a number f, compute K = f^x mod p.
  */
-Bignum dh_find_K(struct dh_ctx *ctx, Bignum f)
+mp_int *dh_find_K(struct dh_ctx *ctx, mp_int *f)
 {
-    Bignum ret;
-    ret = modpow(f, ctx->x, ctx->p);
-    return ret;
+    return mp_modpow(f, ctx->x, ctx->p);
 }

+ 88 - 98
source/putty/sshdss.c

@@ -7,6 +7,7 @@
 #include <assert.h>
 
 #include "ssh.h"
+#include "mpint.h"
 #include "misc.h"
 
 static void dss_freekey(ssh_key *key);    /* forward reference */
@@ -29,7 +30,7 @@ static ssh_key *dss_new_pub(const ssh_keyalg *self, ptrlen data)
     dss->x = NULL;
 
     if (get_err(src) ||
-        !bignum_cmp(dss->q, Zero) || !bignum_cmp(dss->p, Zero)) {
+        mp_eq_integer(dss->p, 0) || mp_eq_integer(dss->q, 0)) {
         /* Invalid key. */
         dss_freekey(&dss->sshk);
         return NULL;
@@ -42,29 +43,28 @@ static void dss_freekey(ssh_key *key)
 {
     struct dss_key *dss = container_of(key, struct dss_key, sshk);
     if (dss->p)
-        freebn(dss->p);
+        mp_free(dss->p);
     if (dss->q)
-        freebn(dss->q);
+        mp_free(dss->q);
     if (dss->g)
-        freebn(dss->g);
+        mp_free(dss->g);
     if (dss->y)
-        freebn(dss->y);
+        mp_free(dss->y);
     if (dss->x)
-        freebn(dss->x);
+        mp_free(dss->x);
     sfree(dss);
 }
 
-static void append_hex_to_strbuf(strbuf *sb, Bignum *x)
+static void append_hex_to_strbuf(strbuf *sb, mp_int *x)
 {
     if (sb->len > 0)
         put_byte(sb, ',');
     put_data(sb, "0x", 2);
-    int nibbles = (3 + bignum_bitcount(x)) / 4;
-    if (nibbles < 1)
-	nibbles = 1;
-    static const char hex[] = "0123456789abcdef";
-    for (int i = nibbles; i--;)
-	put_byte(sb, hex[(bignum_byte(x, i / 2) >> (4 * (i % 2))) & 0xF]);
+    char *hex = mp_get_hex(x);
+    size_t hexlen = strlen(hex);
+    put_data(sb, hex, hexlen);
+    smemclr(hex, hexlen);
+    sfree(hex);
 }
 
 static char *dss_cache_str(ssh_key *key)
@@ -88,7 +88,6 @@ static bool dss_verify(ssh_key *key, ptrlen sig, ptrlen data)
     struct dss_key *dss = container_of(key, struct dss_key, sshk);
     BinarySource src[1];
     unsigned char hash[20];
-    Bignum r, s, w, gu1p, yu2p, gu1yu2p, u1, u2, sha, v;
     bool toret;
 
     if (!dss->p)
@@ -117,29 +116,29 @@ static bool dss_verify(ssh_key *key, ptrlen sig, ptrlen data)
     }
 
     /* Now we're sitting on a 40-byte string for sure. */
-    r = bignum_from_bytes(sig.ptr, 20);
-    s = bignum_from_bytes((const char *)sig.ptr + 20, 20);
+    mp_int *r = mp_from_bytes_be(make_ptrlen(sig.ptr, 20));
+    mp_int *s = mp_from_bytes_be(make_ptrlen((const char *)sig.ptr + 20, 20));
     if (!r || !s) {
         if (r)
-            freebn(r);
+            mp_free(r);
         if (s)
-            freebn(s);
+            mp_free(s);
 	return false;
     }
 
-    if (!bignum_cmp(s, Zero)) {
-        freebn(r);
-        freebn(s);
+    if (mp_eq_integer(s, 0)) {
+        mp_free(r);
+        mp_free(s);
         return false;
     }
 
     /*
      * Step 1. w <- s^-1 mod q.
      */
-    w = modinv(s, dss->q);
+    mp_int *w = mp_invert(s, dss->q);
     if (!w) {
-        freebn(r);
-        freebn(s);
+        mp_free(r);
+        mp_free(s);
         return false;
     }
 
@@ -147,38 +146,38 @@ static bool dss_verify(ssh_key *key, ptrlen sig, ptrlen data)
      * Step 2. u1 <- SHA(message) * w mod q.
      */
     SHA_Simple(data.ptr, data.len, hash);
-    sha = bignum_from_bytes(hash, 20);
-    u1 = modmul(sha, w, dss->q);
+    mp_int *sha = mp_from_bytes_be(make_ptrlen(hash, 20));
+    mp_int *u1 = mp_modmul(sha, w, dss->q);
 
     /*
      * Step 3. u2 <- r * w mod q.
      */
-    u2 = modmul(r, w, dss->q);
+    mp_int *u2 = mp_modmul(r, w, dss->q);
 
     /*
      * Step 4. v <- (g^u1 * y^u2 mod p) mod q.
      */
-    gu1p = modpow(dss->g, u1, dss->p);
-    yu2p = modpow(dss->y, u2, dss->p);
-    gu1yu2p = modmul(gu1p, yu2p, dss->p);
-    v = modmul(gu1yu2p, One, dss->q);
+    mp_int *gu1p = mp_modpow(dss->g, u1, dss->p);
+    mp_int *yu2p = mp_modpow(dss->y, u2, dss->p);
+    mp_int *gu1yu2p = mp_modmul(gu1p, yu2p, dss->p);
+    mp_int *v = mp_mod(gu1yu2p, dss->q);
 
     /*
      * Step 5. v should now be equal to r.
      */
 
-    toret = !bignum_cmp(v, r);
+    toret = mp_cmp_eq(v, r);
 
-    freebn(w);
-    freebn(sha);
-    freebn(u1);
-    freebn(u2);
-    freebn(gu1p);
-    freebn(yu2p);
-    freebn(gu1yu2p);
-    freebn(v);
-    freebn(r);
-    freebn(s);
+    mp_free(w);
+    mp_free(sha);
+    mp_free(u1);
+    mp_free(u2);
+    mp_free(gu1p);
+    mp_free(yu2p);
+    mp_free(gu1yu2p);
+    mp_free(v);
+    mp_free(r);
+    mp_free(s);
 
     return toret;
 }
@@ -209,7 +208,7 @@ static ssh_key *dss_new_priv(const ssh_keyalg *self, ptrlen pub, ptrlen priv)
     ptrlen hash;
     SHA_State s;
     unsigned char digest[20];
-    Bignum ytest;
+    mp_int *ytest;
 
     sshk = dss_new_pub(self, pub);
     if (!sshk)
@@ -233,7 +232,7 @@ static ssh_key *dss_new_priv(const ssh_keyalg *self, ptrlen pub, ptrlen priv)
 	put_mp_ssh2(&s, dss->q);
 	put_mp_ssh2(&s, dss->g);
 	SHA_Final(&s, digest);
-	if (0 != memcmp(hash.ptr, digest, 20)) {
+	if (!smemeq(hash.ptr, digest, 20)) {
 	    dss_freekey(&dss->sshk);
 	    return NULL;
 	}
@@ -242,13 +241,13 @@ static ssh_key *dss_new_priv(const ssh_keyalg *self, ptrlen pub, ptrlen priv)
     /*
      * Now ensure g^x mod p really is y.
      */
-    ytest = modpow(dss->g, dss->x, dss->p);
-    if (0 != bignum_cmp(ytest, dss->y)) {
+    ytest = mp_modpow(dss->g, dss->x, dss->p);
+    if (!mp_cmp_eq(ytest, dss->y)) {
+        mp_free(ytest);
 	dss_freekey(&dss->sshk);
-        freebn(ytest);
 	return NULL;
     }
-    freebn(ytest);
+    mp_free(ytest);
 
     return &dss->sshk;
 }
@@ -268,7 +267,7 @@ static ssh_key *dss_new_priv_openssh(const ssh_keyalg *self,
     dss->x = get_mp_ssh2(src);
 
     if (get_err(src) ||
-        !bignum_cmp(dss->q, Zero) || !bignum_cmp(dss->p, Zero)) {
+        mp_eq_integer(dss->q, 0) || mp_eq_integer(dss->p, 0)) {
         /* Invalid key. */
         dss_freekey(&dss->sshk);
         return NULL;
@@ -299,14 +298,15 @@ static int dss_pubkey_bits(const ssh_keyalg *self, ptrlen pub)
         return -1;
 
     dss = container_of(sshk, struct dss_key, sshk);
-    ret = bignum_bitcount(dss->p);
+    ret = mp_get_nbits(dss->p);
     dss_freekey(&dss->sshk);
 
     return ret;
 }
 
-Bignum *dss_gen_k(const char *id_string, Bignum modulus, Bignum private_key,
-                  unsigned char *digest, int digest_len)
+mp_int *dss_gen_k(const char *id_string, mp_int *modulus,
+                     mp_int *private_key,
+                     unsigned char *digest, int digest_len)
 {
     /*
      * The basic DSS signing algorithm is:
@@ -381,7 +381,6 @@ Bignum *dss_gen_k(const char *id_string, Bignum modulus, Bignum private_key,
      */
     SHA512_State ss;
     unsigned char digest512[64];
-    Bignum proto_k, k;
 
     /*
      * Hash some identifying text plus x.
@@ -397,72 +396,63 @@ Bignum *dss_gen_k(const char *id_string, Bignum modulus, Bignum private_key,
     SHA512_Init(&ss);
     put_data(&ss, digest512, sizeof(digest512));
     put_data(&ss, digest, digest_len);
+    SHA512_Final(&ss, digest512);
 
-    while (1) {
-        SHA512_State ss2 = ss;         /* structure copy */
-        SHA512_Final(&ss2, digest512);
-
-        smemclr(&ss2, sizeof(ss2));
-
-        /*
-         * Now convert the result into a bignum, and reduce it mod q.
-         */
-        proto_k = bignum_from_bytes(digest512, 64);
-        k = bigmod(proto_k, modulus);
-        freebn(proto_k);
-
-        if (bignum_cmp(k, One) != 0 && bignum_cmp(k, Zero) != 0) {
-            smemclr(&ss, sizeof(ss));
-            smemclr(digest512, sizeof(digest512));
-            return k;
-        }
-
-        /* Very unlikely we get here, but if so, k was unsuitable. */
-        freebn(k);
-        /* Perturb the hash to think of a different k. */
-        put_byte(&ss, 'x');
-        /* Go round and try again. */
-    }
+    /*
+     * Now convert the result into a bignum, and coerce it to the
+     * range [2,q), which we do by reducing it mod q-2 and adding 2.
+     */
+    mp_int *modminus2 = mp_copy(modulus);
+    mp_sub_integer_into(modminus2, modminus2, 2);
+    mp_int *proto_k = mp_from_bytes_be(make_ptrlen(digest512, 64));
+    mp_int *k = mp_mod(proto_k, modminus2);
+    mp_free(proto_k);
+    mp_free(modminus2);
+    mp_add_integer_into(k, k, 2);
+
+    smemclr(&ss, sizeof(ss));
+    smemclr(digest512, sizeof(digest512));
+
+    return k;
 }
 
 static void dss_sign(ssh_key *key, const void *data, int datalen,
                      unsigned flags, BinarySink *bs)
 {
     struct dss_key *dss = container_of(key, struct dss_key, sshk);
-    Bignum k, gkp, hash, kinv, hxr, r, s;
     unsigned char digest[20];
     int i;
 
     SHA_Simple(data, datalen, digest);
 
-    k = dss_gen_k("DSA deterministic k generator", dss->q, dss->x,
-                  digest, sizeof(digest));
-    kinv = modinv(k, dss->q);	       /* k^-1 mod q */
-    assert(kinv);
+    mp_int *k = dss_gen_k("DSA deterministic k generator", dss->q, dss->x,
+                          digest, sizeof(digest));
+    mp_int *kinv = mp_invert(k, dss->q);       /* k^-1 mod q */
 
     /*
      * Now we have k, so just go ahead and compute the signature.
      */
-    gkp = modpow(dss->g, k, dss->p);   /* g^k mod p */
-    r = bigmod(gkp, dss->q);	       /* r = (g^k mod p) mod q */
-    freebn(gkp);
-
-    hash = bignum_from_bytes(digest, 20);
-    hxr = bigmuladd(dss->x, r, hash);  /* hash + x*r */
-    s = modmul(kinv, hxr, dss->q);     /* s = k^-1 * (hash + x*r) mod q */
-    freebn(hxr);
-    freebn(kinv);
-    freebn(k);
-    freebn(hash);
+    mp_int *gkp = mp_modpow(dss->g, k, dss->p); /* g^k mod p */
+    mp_int *r = mp_mod(gkp, dss->q);        /* r = (g^k mod p) mod q */
+    mp_free(gkp);
+
+    mp_int *hash = mp_from_bytes_be(make_ptrlen(digest, 20));
+    mp_int *hxr = mp_mul(dss->x, r);
+    mp_add_into(hxr, hxr, hash);         /* hash + x*r */
+    mp_int *s = mp_modmul(kinv, hxr, dss->q); /* s = k^-1 * (hash+x*r) mod q */
+    mp_free(hxr);
+    mp_free(kinv);
+    mp_free(k);
+    mp_free(hash);
 
     put_stringz(bs, "ssh-dss");
     put_uint32(bs, 40);
     for (i = 0; i < 20; i++)
-	put_byte(bs, bignum_byte(r, 19 - i));
+	put_byte(bs, mp_get_byte(r, 19 - i));
     for (i = 0; i < 20; i++)
-        put_byte(bs, bignum_byte(s, 19 - i));
-    freebn(r);
-    freebn(s);
+        put_byte(bs, mp_get_byte(s, 19 - i));
+    mp_free(r);
+    mp_free(s);
 }
 
 const ssh_keyalg ssh_dss = {

File diff suppressed because it is too large
+ 788 - 2062
source/putty/sshecc.c


+ 9 - 9
source/putty/sshpubk.c

@@ -10,6 +10,7 @@
 #include <assert.h>
 
 #include "putty.h"
+#include "mpint.h"
 #include "ssh.h"
 #include "misc.h"
 
@@ -276,11 +277,11 @@ int rsa_ssh1_loadpub(const Filename *filename, BinarySink *bs,
         }
 
 	memset(&key, 0, sizeof(key));
-        key.exponent = bignum_from_decimal(expp);
-        key.modulus = bignum_from_decimal(modp);
-        if (atoi(bitsp) != bignum_bitcount(key.modulus)) {
-            freebn(key.exponent);
-            freebn(key.modulus);
+        key.exponent = mp_from_decimal(expp);
+        key.modulus = mp_from_decimal(modp);
+        if (atoi(bitsp) != mp_get_nbits(key.modulus)) {
+            mp_free(key.exponent);
+            mp_free(key.modulus);
             sfree(line);
             error = "key bit count does not match in SSH-1 public key file";
             goto end;
@@ -1360,10 +1361,9 @@ char *ssh1_pubkey_str(struct RSAKey *key)
     char *buffer;
     char *dec1, *dec2;
 
-    dec1 = bignum_decimal(key->exponent);
-    dec2 = bignum_decimal(key->modulus);
-    buffer = dupprintf("%d %s %s%s%s", bignum_bitcount(key->modulus),
-		       dec1, dec2,
+    dec1 = mp_get_decimal(key->exponent);
+    dec2 = mp_get_decimal(key->modulus);
+    buffer = dupprintf("%zu %s %s%s%s", mp_get_nbits(key->modulus), dec1, dec2,
                        key->comment ? " " : "",
                        key->comment ? key->comment : "");
     sfree(dec1);

+ 137 - 274
source/putty/sshrsa.c

@@ -8,13 +8,14 @@
 #include <assert.h>
 
 #include "ssh.h"
+#include "mpint.h"
 #include "misc.h"
 
 void BinarySource_get_rsa_ssh1_pub(
     BinarySource *src, struct RSAKey *rsa, RsaSsh1Order order)
 {
     unsigned bits;
-    Bignum e, m;
+    mp_int *e, *m;
 
     bits = get_uint32(src);
     if (order == RSA_SSH1_EXPONENT_FIRST) {
@@ -29,10 +30,10 @@ void BinarySource_get_rsa_ssh1_pub(
         rsa->bits = bits;
         rsa->exponent = e;
         rsa->modulus = m;
-        rsa->bytes = (bignum_bitcount(m) + 7) / 8;
+        rsa->bytes = (mp_get_nbits(m) + 7) / 8;
     } else {
-        freebn(e);
-        freebn(m);
+        mp_free(e);
+        mp_free(m);
     }
 }
 
@@ -44,7 +45,7 @@ void BinarySource_get_rsa_ssh1_priv(
 
 bool rsa_ssh1_encrypt(unsigned char *data, int length, struct RSAKey *key)
 {
-    Bignum b1, b2;
+    mp_int *b1, *b2;
     int i;
     unsigned char *p;
 
@@ -62,17 +63,17 @@ bool rsa_ssh1_encrypt(unsigned char *data, int length, struct RSAKey *key)
     }
     data[key->bytes - length - 1] = 0;
 
-    b1 = bignum_from_bytes(data, key->bytes);
+    b1 = mp_from_bytes_be(make_ptrlen(data, key->bytes));
 
-    b2 = modpow(b1, key->exponent, key->modulus);
+    b2 = mp_modpow(b1, key->exponent, key->modulus);
 
     p = data;
     for (i = key->bytes; i--;) {
-	*p++ = bignum_byte(b2, i);
+	*p++ = mp_get_byte(b2, i);
     }
 
-    freebn(b1);
-    freebn(b2);
+    mp_free(b1);
+    mp_free(b2);
 
     return true;
 }
@@ -83,28 +84,33 @@ bool rsa_ssh1_encrypt(unsigned char *data, int length, struct RSAKey *key)
  * Uses Chinese Remainder Theorem to speed computation up over the
  * obvious implementation of a single big modpow.
  */
-Bignum crt_modpow(Bignum base, Bignum exp, Bignum mod,
-                  Bignum p, Bignum q, Bignum iqmp)
+mp_int *crt_modpow(mp_int *base, mp_int *exp, mp_int *mod,
+                      mp_int *p, mp_int *q, mp_int *iqmp)
 {
-    Bignum pm1, qm1, pexp, qexp, presult, qresult, diff, multiplier, ret0, ret;
+    mp_int *pm1, *qm1, *pexp, *qexp, *presult, *qresult;
+    mp_int *diff, *multiplier, *ret0, *ret;
 
     /*
      * Reduce the exponent mod phi(p) and phi(q), to save time when
      * exponentiating mod p and mod q respectively. Of course, since p
      * and q are prime, phi(p) == p-1 and similarly for q.
      */
-    pm1 = copybn(p);
-    decbn(pm1);
-    qm1 = copybn(q);
-    decbn(qm1);
-    pexp = bigmod(exp, pm1);
-    qexp = bigmod(exp, qm1);
+    pm1 = mp_copy(p);
+    mp_sub_integer_into(pm1, pm1, 1);
+    qm1 = mp_copy(q);
+    mp_sub_integer_into(qm1, qm1, 1);
+    pexp = mp_mod(exp, pm1);
+    qexp = mp_mod(exp, qm1);
 
     /*
      * Do the two modpows.
      */
-    presult = modpow(base, pexp, p);
-    qresult = modpow(base, qexp, q);
+    mp_int *base_mod_p = mp_mod(base, p);
+    presult = mp_modpow(base_mod_p, pexp, p);
+    mp_free(base_mod_p);
+    mp_int *base_mod_q = mp_mod(base, q);
+    qresult = mp_modpow(base_mod_q, qexp, q);
+    mp_free(base_mod_q);
 
     /*
      * Recombine the results. We want a value which is congruent to
@@ -115,189 +121,66 @@ Bignum crt_modpow(Bignum base, Bignum exp, Bignum mod,
      * (which is congruent to qresult mod both primes), and add on
      * (presult-qresult) * (iqmp * q) which adjusts it to be congruent
      * to presult mod p without affecting its value mod q.
+     *
+     * (If presult-qresult < 0, we add p to it to keep it positive.)
      */
-    if (bignum_cmp(presult, qresult) < 0) {
-        /*
-         * Can't subtract presult from qresult without first adding on
-         * p.
-         */
-        Bignum tmp = presult;
-        presult = bigadd(presult, p);
-        freebn(tmp);
-    }
-    diff = bigsub(presult, qresult);
-    multiplier = bigmul(iqmp, q);
-    ret0 = bigmuladd(multiplier, diff, qresult);
+    unsigned presult_too_small = mp_cmp_hs(qresult, presult);
+    mp_cond_add_into(presult, presult, p, presult_too_small);
+
+    diff = mp_sub(presult, qresult);
+    multiplier = mp_mul(iqmp, q);
+    ret0 = mp_mul(multiplier, diff);
+    mp_add_into(ret0, ret0, qresult);
 
     /*
      * Finally, reduce the result mod n.
      */
-    ret = bigmod(ret0, mod);
+    ret = mp_mod(ret0, mod);
 
     /*
      * Free all the intermediate results before returning.
      */
-    freebn(pm1);
-    freebn(qm1);
-    freebn(pexp);
-    freebn(qexp);
-    freebn(presult);
-    freebn(qresult);
-    freebn(diff);
-    freebn(multiplier);
-    freebn(ret0);
+    mp_free(pm1);
+    mp_free(qm1);
+    mp_free(pexp);
+    mp_free(qexp);
+    mp_free(presult);
+    mp_free(qresult);
+    mp_free(diff);
+    mp_free(multiplier);
+    mp_free(ret0);
 
     return ret;
 }
 
 /*
- * This function is a wrapper on modpow(). It has the same effect as
- * modpow(), but employs RSA blinding to protect against timing
- * attacks and also uses the Chinese Remainder Theorem (implemented
- * above, in crt_modpow()) to speed up the main operation.
+ * Wrapper on crt_modpow that looks up all the right values from an
+ * RSAKey.
  */
-static Bignum rsa_privkey_op(Bignum input, struct RSAKey *key)
+static mp_int *rsa_privkey_op(mp_int *input, struct RSAKey *key)
 {
-    Bignum random, random_encrypted, random_inverse;
-    Bignum input_blinded, ret_blinded;
-    Bignum ret;
-
-    SHA512_State ss;
-    unsigned char digest512[64];
-    int digestused = lenof(digest512);
-    int hashseq = 0;
-
-    /*
-     * Start by inventing a random number chosen uniformly from the
-     * range 2..modulus-1. (We do this by preparing a random number
-     * of the right length and retrying if it's greater than the
-     * modulus, to prevent any potential Bleichenbacher-like
-     * attacks making use of the uneven distribution within the
-     * range that would arise from just reducing our number mod n.
-     * There are timing implications to the potential retries, of
-     * course, but all they tell you is the modulus, which you
-     * already knew.)
-     * 
-     * To preserve determinism and avoid Pageant needing to share
-     * the random number pool, we actually generate this `random'
-     * number by hashing stuff with the private key.
-     */
-    while (1) {
-	int bits, byte, bitsleft, v;
-	random = copybn(key->modulus);
-	/*
-	 * Find the topmost set bit. (This function will return its
-	 * index plus one.) Then we'll set all bits from that one
-	 * downwards randomly.
-	 */
-	bits = bignum_bitcount(random);
-	byte = 0;
-	bitsleft = 0;
-	while (bits--) {
-	    if (bitsleft <= 0) {
-		bitsleft = 8;
-		/*
-		 * Conceptually the following few lines are equivalent to
-		 *    byte = random_byte();
-		 */
-		if (digestused >= lenof(digest512)) {
-		    SHA512_Init(&ss);
-		    put_data(&ss, "RSA deterministic blinding", 26);
-		    put_uint32(&ss, hashseq);
-		    put_mp_ssh2(&ss, key->private_exponent);
-		    SHA512_Final(&ss, digest512);
-		    hashseq++;
-
-		    /*
-		     * Now hash that digest plus the signature
-		     * input.
-		     */
-		    SHA512_Init(&ss);
-		    put_data(&ss, digest512, sizeof(digest512));
-		    put_mp_ssh2(&ss, input);
-		    SHA512_Final(&ss, digest512);
-
-		    digestused = 0;
-		}
-		byte = digest512[digestused++];
-	    }
-	    v = byte & 1;
-	    byte >>= 1;
-	    bitsleft--;
-	    bignum_set_bit(random, bits, v);
-	}
-        bn_restore_invariant(random);
-
-	/*
-	 * Now check that this number is strictly greater than
-	 * zero, and strictly less than modulus.
-	 */
-	if (bignum_cmp(random, Zero) <= 0 ||
-	    bignum_cmp(random, key->modulus) >= 0) {
-	    freebn(random);
-	    continue;
-	}
-
-        /*
-         * Also, make sure it has an inverse mod modulus.
-         */
-        random_inverse = modinv(random, key->modulus);
-        if (!random_inverse) {
-	    freebn(random);
-	    continue;
-        }
-
-        break;
-    }
-
-    /*
-     * RSA blinding relies on the fact that (xy)^d mod n is equal
-     * to (x^d mod n) * (y^d mod n) mod n. We invent a random pair
-     * y and y^d; then we multiply x by y, raise to the power d mod
-     * n as usual, and divide by y^d to recover x^d. Thus an
-     * attacker can't correlate the timing of the modpow with the
-     * input, because they don't know anything about the number
-     * that was input to the actual modpow.
-     * 
-     * The clever bit is that we don't have to do a huge modpow to
-     * get y and y^d; we will use the number we just invented as
-     * _y^d_, and use the _public_ exponent to compute (y^d)^e = y
-     * from it, which is much faster to do.
-     */
-    random_encrypted = crt_modpow(random, key->exponent,
-                                  key->modulus, key->p, key->q, key->iqmp);
-    input_blinded = modmul(input, random_encrypted, key->modulus);
-    ret_blinded = crt_modpow(input_blinded, key->private_exponent,
-                             key->modulus, key->p, key->q, key->iqmp);
-    ret = modmul(ret_blinded, random_inverse, key->modulus);
-
-    freebn(ret_blinded);
-    freebn(input_blinded);
-    freebn(random_inverse);
-    freebn(random_encrypted);
-    freebn(random);
-
-    return ret;
+    return crt_modpow(input, key->private_exponent,
+                      key->modulus, key->p, key->q, key->iqmp);
 }
 
-Bignum rsa_ssh1_decrypt(Bignum input, struct RSAKey *key)
+mp_int *rsa_ssh1_decrypt(mp_int *input, struct RSAKey *key)
 {
     return rsa_privkey_op(input, key);
 }
 
-bool rsa_ssh1_decrypt_pkcs1(Bignum input, struct RSAKey *key, strbuf *outbuf)
+bool rsa_ssh1_decrypt_pkcs1(mp_int *input, struct RSAKey *key,
+                            strbuf *outbuf)
 {
     strbuf *data = strbuf_new();
     bool success = false;
     BinarySource src[1];
 
     {
-        Bignum *b = rsa_ssh1_decrypt(input, key);
-        int i;
-        for (i = (bignum_bitcount(key->modulus) + 7) / 8; i-- > 0 ;) {
-            put_byte(data, bignum_byte(b, i));
+        mp_int *b = rsa_ssh1_decrypt(input, key);
+        for (size_t i = (mp_get_nbits(key->modulus) + 7) / 8; i-- > 0 ;) {
+            put_byte(data, mp_get_byte(b, i));
         }
-        freebn(b);
+        mp_free(b);
     }
 
     BinarySource_BARE_INIT(src, data->u, data->len);
@@ -321,17 +204,16 @@ bool rsa_ssh1_decrypt_pkcs1(Bignum input, struct RSAKey *key, strbuf *outbuf)
     return success;
 }
 
-static void append_hex_to_strbuf(strbuf *sb, Bignum *x)
+static void append_hex_to_strbuf(strbuf *sb, mp_int *x)
 {
     if (sb->len > 0)
         put_byte(sb, ',');
     put_data(sb, "0x", 2);
-    int nibbles = (3 + bignum_bitcount(x)) / 4;
-    if (nibbles < 1)
-	nibbles = 1;
-    static const char hex[] = "0123456789abcdef";
-    for (int i = nibbles; i--;)
-	put_byte(sb, hex[(bignum_byte(x, i / 2) >> (4 * (i % 2))) & 0xF]);
+    char *hex = mp_get_hex(x);
+    size_t hexlen = strlen(hex);
+    put_data(sb, hex, hexlen);
+    smemclr(hex, hexlen);
+    sfree(hex);
 }
 
 char *rsastr_fmt(struct RSAKey *key)
@@ -361,7 +243,7 @@ char *rsa_ssh1_fingerprint(struct RSAKey *key)
     MD5Final(digest, &md5c);
 
     out = strbuf_new();
-    strbuf_catf(out, "%d ", bignum_bitcount(key->modulus));
+    strbuf_catf(out, "%zu ", mp_get_nbits(key->modulus)); /* size_t count */
     for (i = 0; i < 16; i++)
 	strbuf_catf(out, "%s%02x", i ? ":" : "", digest[i]);
     if (key->comment)
@@ -376,34 +258,32 @@ char *rsa_ssh1_fingerprint(struct RSAKey *key)
  */
 bool rsa_verify(struct RSAKey *key)
 {
-    Bignum n, ed, pm1, qm1;
-    int cmp;
+    mp_int *n, *ed, *pm1, *qm1;
+    unsigned ok = 1;
+
+    /* Preliminary checks: p,q must not be 0 or 1, because p-1 and q-1
+     * are used as moduli for the e*d congruence checks below. */
+    if (mp_eq_integer(key->p, 0) | mp_eq_integer(key->q, 0) |
+        mp_eq_integer(key->p, 1) | mp_eq_integer(key->q, 1))
+        return false;
 
     /* n must equal pq. */
-    n = bigmul(key->p, key->q);
-    cmp = bignum_cmp(n, key->modulus);
-    freebn(n);
-    if (cmp != 0)
-	return false;
+    n = mp_mul(key->p, key->q);
+    ok &= mp_cmp_eq(n, key->modulus);
+    mp_free(n);
 
     /* e * d must be congruent to 1, modulo (p-1) and modulo (q-1). */
-    pm1 = copybn(key->p);
-    decbn(pm1);
-    ed = modmul(key->exponent, key->private_exponent, pm1);
-    freebn(pm1);
-    cmp = bignum_cmp(ed, One);
-    freebn(ed);
-    if (cmp != 0)
-	return false;
-
-    qm1 = copybn(key->q);
-    decbn(qm1);
-    ed = modmul(key->exponent, key->private_exponent, qm1);
-    freebn(qm1);
-    cmp = bignum_cmp(ed, One);
-    freebn(ed);
-    if (cmp != 0)
-	return false;
+    pm1 = mp_copy(key->p);
+    mp_sub_integer_into(pm1, pm1, 1);
+    ed = mp_modmul(key->exponent, key->private_exponent, pm1);
+    mp_free(pm1);
+    ok &= mp_eq_integer(ed, 1);
+    mp_free(ed);
+
+    qm1 = mp_copy(key->q);
+    mp_sub_integer_into(qm1, qm1, 1);
+    ed = mp_modmul(key->exponent, key->private_exponent, qm1);
+    mp_free(qm1);
+    ok &= mp_eq_integer(ed, 1);
+    mp_free(ed);
 
     /*
      * Ensure p > q.
@@ -413,33 +293,18 @@ bool rsa_verify(struct RSAKey *key)
      * should instead flip them round into the canonical order of
      * p > q. This also involves regenerating iqmp.
      */
-    if (bignum_cmp(key->p, key->q) <= 0) {
-	Bignum tmp = key->p;
-	key->p = key->q;
-	key->q = tmp;
-
-	freebn(key->iqmp);
-	key->iqmp = modinv(key->q, key->p);
-        if (!key->iqmp)
-            return false;
-    }
+    unsigned swap_pq = mp_cmp_hs(key->q, key->p);
+    mp_cond_swap(key->p, key->q, swap_pq);
+    mp_free(key->iqmp);
+    key->iqmp = mp_invert(key->q, key->p);
 
-    /*
-     * Ensure iqmp * q is congruent to 1, modulo p.
-     */
-    n = modmul(key->iqmp, key->q, key->p);
-    cmp = bignum_cmp(n, One);
-    freebn(n);
-    if (cmp != 0)
-	return false;
-
-    return true;
+    return ok;
 }
 
 void rsa_ssh1_public_blob(BinarySink *bs, struct RSAKey *key,
                           RsaSsh1Order order)
 {
-    put_uint32(bs, bignum_bitcount(key->modulus));
+    put_uint32(bs, mp_get_nbits(key->modulus));
     if (order == RSA_SSH1_EXPONENT_FIRST) {
         put_mp_ssh1(bs, key->exponent);
         put_mp_ssh1(bs, key->modulus);
@@ -459,8 +324,8 @@ int rsa_ssh1_public_blob_len(void *data, int maxlen)
     /* Expect a length word, then exponent and modulus. (It doesn't
      * even matter which order.) */
     get_uint32(src);
-    freebn(get_mp_ssh1(src));
-    freebn(get_mp_ssh1(src));
+    mp_free(get_mp_ssh1(src));
+    mp_free(get_mp_ssh1(src));
 
     if (get_err(src))
 	return -1;
@@ -472,19 +337,19 @@ int rsa_ssh1_public_blob_len(void *data, int maxlen)
 void freersapriv(struct RSAKey *key)
 {
     if (key->private_exponent) {
-	freebn(key->private_exponent);
+	mp_free(key->private_exponent);
         key->private_exponent = NULL;
     }
     if (key->p) {
-	freebn(key->p);
+	mp_free(key->p);
         key->p = NULL;
     }
     if (key->q) {
-	freebn(key->q);
+	mp_free(key->q);
         key->q = NULL;
     }
     if (key->iqmp) {
-	freebn(key->iqmp);
+	mp_free(key->iqmp);
         key->iqmp = NULL;
     }
 }
@@ -493,11 +358,11 @@ void freersakey(struct RSAKey *key)
 {
     freersapriv(key);
     if (key->modulus) {
-	freebn(key->modulus);
+	mp_free(key->modulus);
         key->modulus = NULL;
     }
     if (key->exponent) {
-	freebn(key->exponent);
+	mp_free(key->exponent);
         key->exponent = NULL;
     }
     if (key->comment) {
@@ -642,7 +507,7 @@ static int rsa2_pubkey_bits(const ssh_keyalg *self, ptrlen pub)
         return -1;
 
     rsa = container_of(sshk, struct RSAKey, sshk);
-    ret = bignum_bitcount(rsa->modulus);
+    ret = mp_get_nbits(rsa->modulus);
     rsa2_freekey(&rsa->sshk);
 
     return ret;
@@ -738,8 +603,7 @@ static bool rsa2_verify(ssh_key *key, ptrlen sig, ptrlen data)
     struct RSAKey *rsa = container_of(key, struct RSAKey, sshk);
     BinarySource src[1];
     ptrlen type, in_pl;
-    Bignum in, out;
-    bool toret;
+    mp_int *in, *out;
 
     BinarySource_BARE_INIT(src, sig.ptr, sig.len);
     type = get_string(src);
@@ -751,28 +615,27 @@ static bool rsa2_verify(ssh_key *key, ptrlen sig, ptrlen data)
      * BUG_SSH2_RSA_PADDING at the other end, we tolerate it if it's
      * there.) So we can't use get_mp_ssh2, which enforces that
      * leading-byte scheme; instead we use get_string and
-     * bignum_from_bytes, which will tolerate anything.
+     * mp_from_bytes_be, which will tolerate anything.
      */
     in_pl = get_string(src);
     if (get_err(src) || !ptrlen_eq_string(type, "ssh-rsa"))
 	return false;
 
-    in = bignum_from_bytes(in_pl.ptr, in_pl.len);
-    out = modpow(in, rsa->exponent, rsa->modulus);
-    freebn(in);
+    in = mp_from_bytes_be(in_pl);
+    out = mp_modpow(in, rsa->exponent, rsa->modulus);
+    mp_free(in);
 
-    toret = true;
+    unsigned diff = 0;
 
-    size_t nbytes = (bignum_bitcount(rsa->modulus) + 7) / 8;
+    size_t nbytes = (mp_get_nbits(rsa->modulus) + 7) / 8;
     unsigned char *bytes = rsa_pkcs1_signature_string(nbytes, &ssh_sha1, data);
     for (size_t i = 0; i < nbytes; i++)
-	if (bytes[nbytes-1 - i] != bignum_byte(out, i))
-	    toret = false;
+        diff |= bytes[nbytes-1 - i] ^ mp_get_byte(out, i);
     smemclr(bytes, nbytes);
     sfree(bytes);
-    freebn(out);
+    mp_free(out);
 
-    return toret;
+    return diff == 0;
 }
 
 static void rsa2_sign(ssh_key *key, const void *data, int datalen,
@@ -780,8 +643,8 @@ static void rsa2_sign(ssh_key *key, const void *data, int datalen,
 {
     struct RSAKey *rsa = container_of(key, struct RSAKey, sshk);
     unsigned char *bytes;
-    int nbytes;
-    Bignum in, out;
+    size_t nbytes;
+    mp_int *in, *out;
     const struct ssh_hashalg *halg;
     const char *sign_alg_name;
 
@@ -796,24 +659,24 @@ static void rsa2_sign(ssh_key *key, const void *data, int datalen,
         sign_alg_name = "ssh-rsa";
     }
 
-    nbytes = (bignum_bitcount(rsa->modulus) + 7) / 8;
+    nbytes = (mp_get_nbits(rsa->modulus) + 7) / 8;
 
     bytes = rsa_pkcs1_signature_string(
         nbytes, halg, make_ptrlen(data, datalen));
-    in = bignum_from_bytes(bytes, nbytes);
+    in = mp_from_bytes_be(make_ptrlen(bytes, nbytes));
     smemclr(bytes, nbytes);
     sfree(bytes);
 
     out = rsa_privkey_op(in, rsa);
-    freebn(in);
+    mp_free(in);
 
     put_stringz(bs, sign_alg_name);
-    nbytes = (bignum_bitcount(out) + 7) / 8;
+    nbytes = (mp_get_nbits(out) + 7) / 8;
     put_uint32(bs, nbytes);
     for (size_t i = 0; i < nbytes; i++)
-	put_byte(bs, bignum_byte(out, nbytes - 1 - i));
+	put_byte(bs, mp_get_byte(out, nbytes - 1 - i));
 
-    freebn(out);
+    mp_free(out);
 }
 
 const ssh_keyalg ssh_rsa = {
@@ -852,7 +715,7 @@ void ssh_rsakex_freekey(struct RSAKey *key)
 
 int ssh_rsakex_klen(struct RSAKey *rsa)
 {
-    return bignum_bitcount(rsa->modulus);
+    return mp_get_nbits(rsa->modulus);
 }
 
 static void oaep_mask(const struct ssh_hashalg *h, void *seed, int seedlen,
@@ -885,7 +748,7 @@ void ssh_rsakex_encrypt(const struct ssh_hashalg *h,
                         unsigned char *in, int inlen,
                         unsigned char *out, int outlen, struct RSAKey *rsa)
 {
-    Bignum b1, b2;
+    mp_int *b1, *b2;
     int k, i;
     char *p;
     const int HLEN = h->hlen;
@@ -918,7 +781,7 @@ void ssh_rsakex_encrypt(const struct ssh_hashalg *h,
      */
 
     /* k denotes the length in octets of the RSA modulus. */
-    k = (7 + bignum_bitcount(rsa->modulus)) / 8;
+    k = (7 + mp_get_nbits(rsa->modulus)) / 8;
 
     /* The length of the input data must be at most k - 2hLen - 2. */
     assert(inlen > 0 && inlen <= k - 2*HLEN - 2);
@@ -961,24 +824,24 @@ void ssh_rsakex_encrypt(const struct ssh_hashalg *h,
      * Now `out' contains precisely the data we want to
      * RSA-encrypt.
      */
-    b1 = bignum_from_bytes(out, outlen);
-    b2 = modpow(b1, rsa->exponent, rsa->modulus);
+    b1 = mp_from_bytes_be(make_ptrlen(out, outlen));
+    b2 = mp_modpow(b1, rsa->exponent, rsa->modulus);
     p = (char *)out;
     for (i = outlen; i--;) {
-	*p++ = bignum_byte(b2, i);
+	*p++ = mp_get_byte(b2, i);
     }
-    freebn(b1);
-    freebn(b2);
+    mp_free(b1);
+    mp_free(b2);
 
     /*
      * And we're done.
      */
 }
 
-Bignum ssh_rsakex_decrypt(const struct ssh_hashalg *h, ptrlen ciphertext,
-                          struct RSAKey *rsa)
+mp_int *ssh_rsakex_decrypt(const struct ssh_hashalg *h, ptrlen ciphertext,
+                              struct RSAKey *rsa)
 {
-    Bignum b1, b2;
+    mp_int *b1, *b2;
     int outlen, i;
     unsigned char *out;
     unsigned char labelhash[64];
@@ -992,18 +855,18 @@ Bignum ssh_rsakex_decrypt(const struct ssh_hashalg *h, ptrlen ciphertext,
 
     /* The length of the encrypted data should be exactly the length
      * in octets of the RSA modulus.. */
-    outlen = (7 + bignum_bitcount(rsa->modulus)) / 8;
+    outlen = (7 + mp_get_nbits(rsa->modulus)) / 8;
     if (ciphertext.len != outlen)
         return NULL;
 
     /* Do the RSA decryption, and extract the result into a byte array. */
-    b1 = bignum_from_bytes(ciphertext.ptr, ciphertext.len);
+    b1 = mp_from_bytes_be(ciphertext);
     b2 = rsa_privkey_op(b1, rsa);
     out = snewn(outlen, unsigned char);
     for (i = 0; i < outlen; i++)
-        out[i] = bignum_byte(b2, outlen-1-i);
-    freebn(b1);
-    freebn(b2);
+        out[i] = mp_get_byte(b2, outlen-1-i);
+    mp_free(b1);
+    mp_free(b2);
 
     /* Do the OAEP masking operations, in the reverse order from encryption */
     oaep_mask(h, out+HLEN+1, outlen-HLEN-1, out+1, HLEN);
@@ -1038,7 +901,7 @@ Bignum ssh_rsakex_decrypt(const struct ssh_hashalg *h, ptrlen ciphertext,
     b1 = get_mp_ssh2(src);
     sfree(out);
     if (get_err(src) || get_avail(src) != 0) {
-        freebn(b1);
+        mp_free(b1);
         return NULL;
     }
 

Some files were not shown because too many files changed in this diff