sshbn.c 52 KB


  1. /*
  2. * Bignum routines for RSA and DH and stuff.
  3. */
  4. #include <stdio.h>
  5. #include <assert.h>
  6. #include <stdlib.h>
  7. #include <string.h>
  8. #include <limits.h>
  9. #include "misc.h"
  10. #include "sshbn.h"
  11. #define BIGNUM_INTERNAL
  12. typedef BignumInt *Bignum;
  13. #include "ssh.h"
  14. BignumInt bnZero[1] = { 0 };
  15. BignumInt bnOne[2] = { 1, 1 };
  16. /*
  17. * The Bignum format is an array of `BignumInt'. The first
  18. * element of the array counts the remaining elements. The
  19. * remaining elements express the actual number, base 2^BIGNUM_INT_BITS, _least_
  20. * significant digit first. (So it's trivial to extract the bit
  21. * with value 2^n for any n.)
  22. *
  23. * All Bignums in this module are positive. Negative numbers must
  24. * be dealt with outside it.
  25. *
  26. * INVARIANT: the most significant word of any Bignum must be
  27. * nonzero.
  28. */
  29. Bignum Zero = bnZero, One = bnOne;
  30. static Bignum newbn(int length)
  31. {
  32. Bignum b;
  33. assert(length >= 0 && length < INT_MAX / BIGNUM_INT_BITS);
  34. b = snewn(length + 1, BignumInt);
  35. if (!b)
  36. abort(); /* FIXME */
  37. memset(b, 0, (length + 1) * sizeof(*b));
  38. b[0] = length;
  39. return b;
  40. }
  41. void bn_restore_invariant(Bignum b)
  42. {
  43. while (b[0] > 1 && b[b[0]] == 0)
  44. b[0]--;
  45. }
  46. Bignum copybn(Bignum orig)
  47. {
  48. Bignum b = snewn(orig[0] + 1, BignumInt);
  49. if (!b)
  50. abort(); /* FIXME */
  51. memcpy(b, orig, (orig[0] + 1) * sizeof(*b));
  52. return b;
  53. }
  54. void freebn(Bignum b)
  55. {
  56. /*
  57. * Burn the evidence, just in case.
  58. */
  59. smemclr(b, sizeof(b[0]) * (b[0] + 1));
  60. sfree(b);
  61. }
  62. Bignum bn_power_2(int n)
  63. {
  64. Bignum ret;
  65. assert(n >= 0);
  66. ret = newbn(n / BIGNUM_INT_BITS + 1);
  67. bignum_set_bit(ret, n, 1);
  68. return ret;
  69. }
  70. /*
  71. * Internal addition. Sets c = a - b, where 'a', 'b' and 'c' are all
  72. * big-endian arrays of 'len' BignumInts. Returns a BignumInt carried
  73. * off the top.
  74. */
  75. static BignumInt internal_add(const BignumInt *a, const BignumInt *b,
  76. BignumInt *c, int len)
  77. {
  78. int i;
  79. BignumDblInt carry = 0;
  80. for (i = len-1; i >= 0; i--) {
  81. carry += (BignumDblInt)a[i] + b[i];
  82. c[i] = (BignumInt)carry;
  83. carry >>= BIGNUM_INT_BITS;
  84. }
  85. return (BignumInt)carry;
  86. }
  87. /*
  88. * Internal subtraction. Sets c = a - b, where 'a', 'b' and 'c' are
  89. * all big-endian arrays of 'len' BignumInts. Any borrow from the top
  90. * is ignored.
  91. */
  92. static void internal_sub(const BignumInt *a, const BignumInt *b,
  93. BignumInt *c, int len)
  94. {
  95. int i;
  96. BignumDblInt carry = 1;
  97. for (i = len-1; i >= 0; i--) {
  98. carry += (BignumDblInt)a[i] + (b[i] ^ BIGNUM_INT_MASK);
  99. c[i] = (BignumInt)carry;
  100. carry >>= BIGNUM_INT_BITS;
  101. }
  102. }
  103. /*
  104. * Compute c = a * b.
  105. * Input is in the first len words of a and b.
  106. * Result is returned in the first 2*len words of c.
  107. *
  108. * 'scratch' must point to an array of BignumInt of size at least
  109. * mul_compute_scratch(len). (This covers the needs of internal_mul
  110. * and all its recursive calls to itself.)
  111. */
  112. #define KARATSUBA_THRESHOLD 50
  113. static int mul_compute_scratch(int len)
  114. {
  115. int ret = 0;
  116. while (len > KARATSUBA_THRESHOLD) {
  117. int toplen = len/2, botlen = len - toplen; /* botlen is the bigger */
  118. int midlen = botlen + 1;
  119. ret += 4*midlen;
  120. len = midlen;
  121. }
  122. return ret;
  123. }
  124. static void internal_mul(const BignumInt *a, const BignumInt *b,
  125. BignumInt *c, int len, BignumInt *scratch)
  126. {
  127. if (len > KARATSUBA_THRESHOLD) {
  128. int i;
  129. /*
  130. * Karatsuba divide-and-conquer algorithm. Cut each input in
  131. * half, so that it's expressed as two big 'digits' in a giant
  132. * base D:
  133. *
  134. * a = a_1 D + a_0
  135. * b = b_1 D + b_0
  136. *
  137. * Then the product is of course
  138. *
  139. * ab = a_1 b_1 D^2 + (a_1 b_0 + a_0 b_1) D + a_0 b_0
  140. *
  141. * and we compute the three coefficients by recursively
  142. * calling ourself to do half-length multiplications.
  143. *
  144. * The clever bit that makes this worth doing is that we only
  145. * need _one_ half-length multiplication for the central
  146. * coefficient rather than the two that it obviouly looks
  147. * like, because we can use a single multiplication to compute
  148. *
  149. * (a_1 + a_0) (b_1 + b_0) = a_1 b_1 + a_1 b_0 + a_0 b_1 + a_0 b_0
  150. *
  151. * and then we subtract the other two coefficients (a_1 b_1
  152. * and a_0 b_0) which we were computing anyway.
  153. *
  154. * Hence we get to multiply two numbers of length N in about
  155. * three times as much work as it takes to multiply numbers of
  156. * length N/2, which is obviously better than the four times
  157. * as much work it would take if we just did a long
  158. * conventional multiply.
  159. */
  160. int toplen = len/2, botlen = len - toplen; /* botlen is the bigger */
  161. int midlen = botlen + 1;
  162. BignumDblInt carry;
  163. #ifdef KARA_DEBUG
  164. int i;
  165. #endif
  166. /*
  167. * The coefficients a_1 b_1 and a_0 b_0 just avoid overlapping
  168. * in the output array, so we can compute them immediately in
  169. * place.
  170. */
  171. #ifdef KARA_DEBUG
  172. printf("a1,a0 = 0x");
  173. for (i = 0; i < len; i++) {
  174. if (i == toplen) printf(", 0x");
  175. printf("%0*x", BIGNUM_INT_BITS/4, a[i]);
  176. }
  177. printf("\n");
  178. printf("b1,b0 = 0x");
  179. for (i = 0; i < len; i++) {
  180. if (i == toplen) printf(", 0x");
  181. printf("%0*x", BIGNUM_INT_BITS/4, b[i]);
  182. }
  183. printf("\n");
  184. #endif
  185. /* a_1 b_1 */
  186. internal_mul(a, b, c, toplen, scratch);
  187. #ifdef KARA_DEBUG
  188. printf("a1b1 = 0x");
  189. for (i = 0; i < 2*toplen; i++) {
  190. printf("%0*x", BIGNUM_INT_BITS/4, c[i]);
  191. }
  192. printf("\n");
  193. #endif
  194. /* a_0 b_0 */
  195. internal_mul(a + toplen, b + toplen, c + 2*toplen, botlen, scratch);
  196. #ifdef KARA_DEBUG
  197. printf("a0b0 = 0x");
  198. for (i = 0; i < 2*botlen; i++) {
  199. printf("%0*x", BIGNUM_INT_BITS/4, c[2*toplen+i]);
  200. }
  201. printf("\n");
  202. #endif
  203. /* Zero padding. midlen exceeds toplen by at most 2, so just
  204. * zero the first two words of each input and the rest will be
  205. * copied over. */
  206. scratch[0] = scratch[1] = scratch[midlen] = scratch[midlen+1] = 0;
  207. for (i = 0; i < toplen; i++) {
  208. scratch[midlen - toplen + i] = a[i]; /* a_1 */
  209. scratch[2*midlen - toplen + i] = b[i]; /* b_1 */
  210. }
  211. /* compute a_1 + a_0 */
  212. scratch[0] = internal_add(scratch+1, a+toplen, scratch+1, botlen);
  213. #ifdef KARA_DEBUG
  214. printf("a1plusa0 = 0x");
  215. for (i = 0; i < midlen; i++) {
  216. printf("%0*x", BIGNUM_INT_BITS/4, scratch[i]);
  217. }
  218. printf("\n");
  219. #endif
  220. /* compute b_1 + b_0 */
  221. scratch[midlen] = internal_add(scratch+midlen+1, b+toplen,
  222. scratch+midlen+1, botlen);
  223. #ifdef KARA_DEBUG
  224. printf("b1plusb0 = 0x");
  225. for (i = 0; i < midlen; i++) {
  226. printf("%0*x", BIGNUM_INT_BITS/4, scratch[midlen+i]);
  227. }
  228. printf("\n");
  229. #endif
  230. /*
  231. * Now we can do the third multiplication.
  232. */
  233. internal_mul(scratch, scratch + midlen, scratch + 2*midlen, midlen,
  234. scratch + 4*midlen);
  235. #ifdef KARA_DEBUG
  236. printf("a1plusa0timesb1plusb0 = 0x");
  237. for (i = 0; i < 2*midlen; i++) {
  238. printf("%0*x", BIGNUM_INT_BITS/4, scratch[2*midlen+i]);
  239. }
  240. printf("\n");
  241. #endif
  242. /*
  243. * Now we can reuse the first half of 'scratch' to compute the
  244. * sum of the outer two coefficients, to subtract from that
  245. * product to obtain the middle one.
  246. */
  247. scratch[0] = scratch[1] = scratch[2] = scratch[3] = 0;
  248. for (i = 0; i < 2*toplen; i++)
  249. scratch[2*midlen - 2*toplen + i] = c[i];
  250. scratch[1] = internal_add(scratch+2, c + 2*toplen,
  251. scratch+2, 2*botlen);
  252. #ifdef KARA_DEBUG
  253. printf("a1b1plusa0b0 = 0x");
  254. for (i = 0; i < 2*midlen; i++) {
  255. printf("%0*x", BIGNUM_INT_BITS/4, scratch[i]);
  256. }
  257. printf("\n");
  258. #endif
  259. internal_sub(scratch + 2*midlen, scratch,
  260. scratch + 2*midlen, 2*midlen);
  261. #ifdef KARA_DEBUG
  262. printf("a1b0plusa0b1 = 0x");
  263. for (i = 0; i < 2*midlen; i++) {
  264. printf("%0*x", BIGNUM_INT_BITS/4, scratch[2*midlen+i]);
  265. }
  266. printf("\n");
  267. #endif
  268. /*
  269. * And now all we need to do is to add that middle coefficient
  270. * back into the output. We may have to propagate a carry
  271. * further up the output, but we can be sure it won't
  272. * propagate right the way off the top.
  273. */
  274. carry = internal_add(c + 2*len - botlen - 2*midlen,
  275. scratch + 2*midlen,
  276. c + 2*len - botlen - 2*midlen, 2*midlen);
  277. i = 2*len - botlen - 2*midlen - 1;
  278. while (carry) {
  279. assert(i >= 0);
  280. carry += c[i];
  281. c[i] = (BignumInt)carry;
  282. carry >>= BIGNUM_INT_BITS;
  283. i--;
  284. }
  285. #ifdef KARA_DEBUG
  286. printf("ab = 0x");
  287. for (i = 0; i < 2*len; i++) {
  288. printf("%0*x", BIGNUM_INT_BITS/4, c[i]);
  289. }
  290. printf("\n");
  291. #endif
  292. } else {
  293. int i;
  294. BignumInt carry;
  295. BignumDblInt t;
  296. const BignumInt *ap, *bp;
  297. BignumInt *cp, *cps;
  298. /*
  299. * Multiply in the ordinary O(N^2) way.
  300. */
  301. for (i = 0; i < 2 * len; i++)
  302. c[i] = 0;
  303. for (cps = c + 2*len, ap = a + len; ap-- > a; cps--) {
  304. carry = 0;
  305. for (cp = cps, bp = b + len; cp--, bp-- > b ;) {
  306. t = (MUL_WORD(*ap, *bp) + carry) + *cp;
  307. *cp = (BignumInt) t;
  308. carry = (BignumInt)(t >> BIGNUM_INT_BITS);
  309. }
  310. *cp = carry;
  311. }
  312. }
  313. }
  314. /*
  315. * Variant form of internal_mul used for the initial step of
  316. * Montgomery reduction. Only bothers outputting 'len' words
  317. * (everything above that is thrown away).
  318. */
  319. static void internal_mul_low(const BignumInt *a, const BignumInt *b,
  320. BignumInt *c, int len, BignumInt *scratch)
  321. {
  322. if (len > KARATSUBA_THRESHOLD) {
  323. int i;
  324. /*
  325. * Karatsuba-aware version of internal_mul_low. As before, we
  326. * express each input value as a shifted combination of two
  327. * halves:
  328. *
  329. * a = a_1 D + a_0
  330. * b = b_1 D + b_0
  331. *
  332. * Then the full product is, as before,
  333. *
  334. * ab = a_1 b_1 D^2 + (a_1 b_0 + a_0 b_1) D + a_0 b_0
  335. *
  336. * Provided we choose D on the large side (so that a_0 and b_0
  337. * are _at least_ as long as a_1 and b_1), we don't need the
  338. * topmost term at all, and we only need half of the middle
  339. * term. So there's no point in doing the proper Karatsuba
  340. * optimisation which computes the middle term using the top
  341. * one, because we'd take as long computing the top one as
  342. * just computing the middle one directly.
  343. *
  344. * So instead, we do a much more obvious thing: we call the
  345. * fully optimised internal_mul to compute a_0 b_0, and we
  346. * recursively call ourself to compute the _bottom halves_ of
  347. * a_1 b_0 and a_0 b_1, each of which we add into the result
  348. * in the obvious way.
  349. *
  350. * In other words, there's no actual Karatsuba _optimisation_
  351. * in this function; the only benefit in doing it this way is
  352. * that we call internal_mul proper for a large part of the
  353. * work, and _that_ can optimise its operation.
  354. */
  355. int toplen = len/2, botlen = len - toplen; /* botlen is the bigger */
  356. /*
  357. * Scratch space for the various bits and pieces we're going
  358. * to be adding together: we need botlen*2 words for a_0 b_0
  359. * (though we may end up throwing away its topmost word), and
  360. * toplen words for each of a_1 b_0 and a_0 b_1. That adds up
  361. * to exactly 2*len.
  362. */
  363. /* a_0 b_0 */
  364. internal_mul(a + toplen, b + toplen, scratch + 2*toplen, botlen,
  365. scratch + 2*len);
  366. /* a_1 b_0 */
  367. internal_mul_low(a, b + len - toplen, scratch + toplen, toplen,
  368. scratch + 2*len);
  369. /* a_0 b_1 */
  370. internal_mul_low(a + len - toplen, b, scratch, toplen,
  371. scratch + 2*len);
  372. /* Copy the bottom half of the big coefficient into place */
  373. for (i = 0; i < botlen; i++)
  374. c[toplen + i] = scratch[2*toplen + botlen + i];
  375. /* Add the two small coefficients, throwing away the returned carry */
  376. internal_add(scratch, scratch + toplen, scratch, toplen);
  377. /* And add that to the large coefficient, leaving the result in c. */
  378. internal_add(scratch, scratch + 2*toplen + botlen - toplen,
  379. c, toplen);
  380. } else {
  381. int i;
  382. BignumInt carry;
  383. BignumDblInt t;
  384. const BignumInt *ap, *bp;
  385. BignumInt *cp, *cps;
  386. /*
  387. * Multiply in the ordinary O(N^2) way.
  388. */
  389. for (i = 0; i < len; i++)
  390. c[i] = 0;
  391. for (cps = c + len, ap = a + len; ap-- > a; cps--) {
  392. carry = 0;
  393. for (cp = cps, bp = b + len; bp--, cp-- > c ;) {
  394. t = (MUL_WORD(*ap, *bp) + carry) + *cp;
  395. *cp = (BignumInt) t;
  396. carry = (BignumInt)(t >> BIGNUM_INT_BITS);
  397. }
  398. }
  399. }
  400. }
  401. /*
  402. * Montgomery reduction. Expects x to be a big-endian array of 2*len
  403. * BignumInts whose value satisfies 0 <= x < rn (where r = 2^(len *
  404. * BIGNUM_INT_BITS) is the Montgomery base). Returns in the same array
  405. * a value x' which is congruent to xr^{-1} mod n, and satisfies 0 <=
  406. * x' < n.
  407. *
  408. * 'n' and 'mninv' should be big-endian arrays of 'len' BignumInts
  409. * each, containing respectively n and the multiplicative inverse of
  410. * -n mod r.
  411. *
  412. * 'tmp' is an array of BignumInt used as scratch space, of length at
  413. * least 3*len + mul_compute_scratch(len).
  414. */
  415. static void monty_reduce(BignumInt *x, const BignumInt *n,
  416. const BignumInt *mninv, BignumInt *tmp, int len)
  417. {
  418. int i;
  419. BignumInt carry;
  420. /*
  421. * Multiply x by (-n)^{-1} mod r. This gives us a value m such
  422. * that mn is congruent to -x mod r. Hence, mn+x is an exact
  423. * multiple of r, and is also (obviously) congruent to x mod n.
  424. */
  425. internal_mul_low(x + len, mninv, tmp, len, tmp + 3*len);
  426. /*
  427. * Compute t = (mn+x)/r in ordinary, non-modular, integer
  428. * arithmetic. By construction this is exact, and is congruent mod
  429. * n to x * r^{-1}, i.e. the answer we want.
  430. *
  431. * The following multiply leaves that answer in the _most_
  432. * significant half of the 'x' array, so then we must shift it
  433. * down.
  434. */
  435. internal_mul(tmp, n, tmp+len, len, tmp + 3*len);
  436. carry = internal_add(x, tmp+len, x, 2*len);
  437. for (i = 0; i < len; i++)
  438. x[len + i] = x[i], x[i] = 0;
  439. /*
  440. * Reduce t mod n. This doesn't require a full-on division by n,
  441. * but merely a test and single optional subtraction, since we can
  442. * show that 0 <= t < 2n.
  443. *
  444. * Proof:
  445. * + we computed m mod r, so 0 <= m < r.
  446. * + so 0 <= mn < rn, obviously
  447. * + hence we only need 0 <= x < rn to guarantee that 0 <= mn+x < 2rn
  448. * + yielding 0 <= (mn+x)/r < 2n as required.
  449. */
  450. if (!carry) {
  451. for (i = 0; i < len; i++)
  452. if (x[len + i] != n[i])
  453. break;
  454. }
  455. if (carry || i >= len || x[len + i] > n[i])
  456. internal_sub(x+len, n, x+len, len);
  457. }
  458. static void internal_add_shifted(BignumInt *number,
  459. BignumInt n, int shift)
  460. {
  461. int word = 1 + (shift / BIGNUM_INT_BITS);
  462. int bshift = shift % BIGNUM_INT_BITS;
  463. BignumDblInt addend;
  464. addend = (BignumDblInt)n << bshift;
  465. while (addend) {
  466. assert(word <= number[0]);
  467. addend += number[word];
  468. number[word] = (BignumInt) addend & BIGNUM_INT_MASK;
  469. addend >>= BIGNUM_INT_BITS;
  470. word++;
  471. }
  472. }
  473. /*
  474. * Compute a = a % m.
  475. * Input in first alen words of a and first mlen words of m.
  476. * Output in first alen words of a
  477. * (of which first alen-mlen words will be zero).
  478. * The MSW of m MUST have its high bit set.
  479. * Quotient is accumulated in the `quotient' array, which is a Bignum
  480. * rather than the internal bigendian format. Quotient parts are shifted
  481. * left by `qshift' before adding into quot.
  482. */
  483. static void internal_mod(BignumInt *a, int alen,
  484. BignumInt *m, int mlen,
  485. BignumInt *quot, int qshift)
  486. {
  487. BignumInt m0, m1, h;
  488. int i, k;
  489. m0 = m[0];
  490. assert(m0 >> (BIGNUM_INT_BITS-1) == 1);
  491. if (mlen > 1)
  492. m1 = m[1];
  493. else
  494. m1 = 0;
  495. for (i = 0; i <= alen - mlen; i++) {
  496. BignumDblInt t;
  497. BignumInt q, r, c, ai1;
  498. if (i == 0) {
  499. h = 0;
  500. } else {
  501. h = a[i - 1];
  502. a[i - 1] = 0;
  503. }
  504. if (i == alen - 1)
  505. ai1 = 0;
  506. else
  507. ai1 = a[i + 1];
  508. /* Find q = h:a[i] / m0 */
  509. if (h >= m0) {
  510. /*
  511. * Special case.
  512. *
  513. * To illustrate it, suppose a BignumInt is 8 bits, and
  514. * we are dividing (say) A1:23:45:67 by A1:B2:C3. Then
  515. * our initial division will be 0xA123 / 0xA1, which
  516. * will give a quotient of 0x100 and a divide overflow.
  517. * However, the invariants in this division algorithm
  518. * are not violated, since the full number A1:23:... is
  519. * _less_ than the quotient prefix A1:B2:... and so the
  520. * following correction loop would have sorted it out.
  521. *
  522. * In this situation we set q to be the largest
  523. * quotient we _can_ stomach (0xFF, of course).
  524. */
  525. q = BIGNUM_INT_MASK;
  526. } else {
  527. /* Macro doesn't want an array subscript expression passed
  528. * into it (see definition), so use a temporary. */
  529. BignumInt tmplo = a[i];
  530. DIVMOD_WORD(q, r, h, tmplo, m0);
  531. /* Refine our estimate of q by looking at
  532. h:a[i]:a[i+1] / m0:m1 */
  533. t = MUL_WORD(m1, q);
  534. if (t > ((BignumDblInt) r << BIGNUM_INT_BITS) + ai1) {
  535. q--;
  536. t -= m1;
  537. r = (r + m0) & BIGNUM_INT_MASK; /* overflow? */
  538. if (r >= (BignumDblInt) m0 &&
  539. t > ((BignumDblInt) r << BIGNUM_INT_BITS) + ai1) q--;
  540. }
  541. }
  542. /* Subtract q * m from a[i...] */
  543. c = 0;
  544. for (k = mlen - 1; k >= 0; k--) {
  545. t = MUL_WORD(q, m[k]);
  546. t += c;
  547. c = (BignumInt)(t >> BIGNUM_INT_BITS);
  548. if ((BignumInt) t > a[i + k])
  549. c++;
  550. a[i + k] -= (BignumInt) t;
  551. }
  552. /* Add back m in case of borrow */
  553. if (c != h) {
  554. t = 0;
  555. for (k = mlen - 1; k >= 0; k--) {
  556. t += m[k];
  557. t += a[i + k];
  558. a[i + k] = (BignumInt) t;
  559. t = t >> BIGNUM_INT_BITS;
  560. }
  561. q--;
  562. }
  563. if (quot)
  564. internal_add_shifted(quot, q, qshift + BIGNUM_INT_BITS * (alen - mlen - i));
  565. }
  566. }
  567. /*
  568. * Compute (base ^ exp) % mod, the pedestrian way.
  569. */
  570. Bignum modpow_simple(Bignum base_in, Bignum exp, Bignum mod)
  571. {
  572. BignumInt *a, *b, *n, *m, *scratch;
  573. int mshift;
  574. int mlen, scratchlen, i, j;
  575. Bignum base, result;
  576. /*
  577. * The most significant word of mod needs to be non-zero. It
  578. * should already be, but let's make sure.
  579. */
  580. assert(mod[mod[0]] != 0);
  581. /*
  582. * Make sure the base is smaller than the modulus, by reducing
  583. * it modulo the modulus if not.
  584. */
  585. base = bigmod(base_in, mod);
  586. /* Allocate m of size mlen, copy mod to m */
  587. /* We use big endian internally */
  588. mlen = mod[0];
  589. m = snewn(mlen, BignumInt);
  590. for (j = 0; j < mlen; j++)
  591. m[j] = mod[mod[0] - j];
  592. /* Shift m left to make msb bit set */
  593. for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++)
  594. if ((m[0] << mshift) & BIGNUM_TOP_BIT)
  595. break;
  596. if (mshift) {
  597. for (i = 0; i < mlen - 1; i++)
  598. m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift));
  599. m[mlen - 1] = m[mlen - 1] << mshift;
  600. }
  601. /* Allocate n of size mlen, copy base to n */
  602. n = snewn(mlen, BignumInt);
  603. i = mlen - base[0];
  604. for (j = 0; j < i; j++)
  605. n[j] = 0;
  606. for (j = 0; j < (int)base[0]; j++)
  607. n[i + j] = base[base[0] - j];
  608. /* Allocate a and b of size 2*mlen. Set a = 1 */
  609. a = snewn(2 * mlen, BignumInt);
  610. b = snewn(2 * mlen, BignumInt);
  611. for (i = 0; i < 2 * mlen; i++)
  612. a[i] = 0;
  613. a[2 * mlen - 1] = 1;
  614. /* Scratch space for multiplies */
  615. scratchlen = mul_compute_scratch(mlen);
  616. scratch = snewn(scratchlen, BignumInt);
  617. /* Skip leading zero bits of exp. */
  618. i = 0;
  619. j = BIGNUM_INT_BITS-1;
  620. while (i < (int)exp[0] && (exp[exp[0] - i] & ((BignumInt)1 << j)) == 0) {
  621. j--;
  622. if (j < 0) {
  623. i++;
  624. j = BIGNUM_INT_BITS-1;
  625. }
  626. }
  627. /* Main computation */
  628. while (i < (int)exp[0]) {
  629. while (j >= 0) {
  630. internal_mul(a + mlen, a + mlen, b, mlen, scratch);
  631. internal_mod(b, mlen * 2, m, mlen, NULL, 0);
  632. if ((exp[exp[0] - i] & ((BignumInt)1 << j)) != 0) {
  633. internal_mul(b + mlen, n, a, mlen, scratch);
  634. internal_mod(a, mlen * 2, m, mlen, NULL, 0);
  635. } else {
  636. BignumInt *t;
  637. t = a;
  638. a = b;
  639. b = t;
  640. }
  641. j--;
  642. }
  643. i++;
  644. j = BIGNUM_INT_BITS-1;
  645. }
  646. /* Fixup result in case the modulus was shifted */
  647. if (mshift) {
  648. for (i = mlen - 1; i < 2 * mlen - 1; i++)
  649. a[i] = (a[i] << mshift) | (a[i + 1] >> (BIGNUM_INT_BITS - mshift));
  650. a[2 * mlen - 1] = a[2 * mlen - 1] << mshift;
  651. internal_mod(a, mlen * 2, m, mlen, NULL, 0);
  652. for (i = 2 * mlen - 1; i >= mlen; i--)
  653. a[i] = (a[i] >> mshift) | (a[i - 1] << (BIGNUM_INT_BITS - mshift));
  654. }
  655. /* Copy result to buffer */
  656. result = newbn(mod[0]);
  657. for (i = 0; i < mlen; i++)
  658. result[result[0] - i] = a[i + mlen];
  659. while (result[0] > 1 && result[result[0]] == 0)
  660. result[0]--;
  661. /* Free temporary arrays */
  662. smemclr(a, 2 * mlen * sizeof(*a));
  663. sfree(a);
  664. smemclr(scratch, scratchlen * sizeof(*scratch));
  665. sfree(scratch);
  666. smemclr(b, 2 * mlen * sizeof(*b));
  667. sfree(b);
  668. smemclr(m, mlen * sizeof(*m));
  669. sfree(m);
  670. smemclr(n, mlen * sizeof(*n));
  671. sfree(n);
  672. freebn(base);
  673. return result;
  674. }
  675. /*
  676. * Compute (base ^ exp) % mod. Uses the Montgomery multiplication
  677. * technique where possible, falling back to modpow_simple otherwise.
  678. */
  679. Bignum modpow(Bignum base_in, Bignum exp, Bignum mod)
  680. {
  681. BignumInt *a, *b, *x, *n, *mninv, *scratch;
  682. int len, scratchlen, i, j;
  683. Bignum base, base2, r, rn, inv, result;
  684. /*
  685. * The most significant word of mod needs to be non-zero. It
  686. * should already be, but let's make sure.
  687. */
  688. assert(mod[mod[0]] != 0);
  689. /*
  690. * mod had better be odd, or we can't do Montgomery multiplication
  691. * using a power of two at all.
  692. */
  693. if (!(mod[1] & 1))
  694. return modpow_simple(base_in, exp, mod);
  695. /*
  696. * Make sure the base is smaller than the modulus, by reducing
  697. * it modulo the modulus if not.
  698. */
  699. base = bigmod(base_in, mod);
  700. /*
  701. * Compute the inverse of n mod r, for monty_reduce. (In fact we
  702. * want the inverse of _minus_ n mod r, but we'll sort that out
  703. * below.)
  704. */
  705. len = mod[0];
  706. r = bn_power_2(BIGNUM_INT_BITS * len);
  707. inv = modinv(mod, r);
  708. assert(inv); /* cannot fail, since mod is odd and r is a power of 2 */
  709. /*
  710. * Multiply the base by r mod n, to get it into Montgomery
  711. * representation.
  712. */
  713. base2 = modmul(base, r, mod);
  714. freebn(base);
  715. base = base2;
  716. rn = bigmod(r, mod); /* r mod n, i.e. Montgomerified 1 */
  717. freebn(r); /* won't need this any more */
  718. /*
  719. * Set up internal arrays of the right lengths, in big-endian
  720. * format, containing the base, the modulus, and the modulus's
  721. * inverse.
  722. */
  723. n = snewn(len, BignumInt);
  724. for (j = 0; j < len; j++)
  725. n[len - 1 - j] = mod[j + 1];
  726. mninv = snewn(len, BignumInt);
  727. for (j = 0; j < len; j++)
  728. mninv[len - 1 - j] = (j < (int)inv[0] ? inv[j + 1] : 0);
  729. freebn(inv); /* we don't need this copy of it any more */
  730. /* Now negate mninv mod r, so it's the inverse of -n rather than +n. */
  731. x = snewn(len, BignumInt);
  732. for (j = 0; j < len; j++)
  733. x[j] = 0;
  734. internal_sub(x, mninv, mninv, len);
  735. /* x = snewn(len, BignumInt); */ /* already done above */
  736. for (j = 0; j < len; j++)
  737. x[len - 1 - j] = (j < (int)base[0] ? base[j + 1] : 0);
  738. freebn(base); /* we don't need this copy of it any more */
  739. a = snewn(2*len, BignumInt);
  740. b = snewn(2*len, BignumInt);
  741. for (j = 0; j < len; j++)
  742. a[2*len - 1 - j] = (j < (int)rn[0] ? rn[j + 1] : 0);
  743. freebn(rn);
  744. /* Scratch space for multiplies */
  745. scratchlen = 3*len + mul_compute_scratch(len);
  746. scratch = snewn(scratchlen, BignumInt);
  747. /* Skip leading zero bits of exp. */
  748. i = 0;
  749. j = BIGNUM_INT_BITS-1;
  750. while (i < (int)exp[0] && (exp[exp[0] - i] & ((BignumInt)1 << j)) == 0) {
  751. j--;
  752. if (j < 0) {
  753. i++;
  754. j = BIGNUM_INT_BITS-1;
  755. }
  756. }
  757. /* Main computation */
  758. while (i < (int)exp[0]) {
  759. while (j >= 0) {
  760. internal_mul(a + len, a + len, b, len, scratch);
  761. monty_reduce(b, n, mninv, scratch, len);
  762. if ((exp[exp[0] - i] & ((BignumInt)1 << j)) != 0) {
  763. internal_mul(b + len, x, a, len, scratch);
  764. monty_reduce(a, n, mninv, scratch, len);
  765. } else {
  766. BignumInt *t;
  767. t = a;
  768. a = b;
  769. b = t;
  770. }
  771. j--;
  772. }
  773. i++;
  774. j = BIGNUM_INT_BITS-1;
  775. }
  776. /*
  777. * Final monty_reduce to get back from the adjusted Montgomery
  778. * representation.
  779. */
  780. monty_reduce(a, n, mninv, scratch, len);
  781. /* Copy result to buffer */
  782. result = newbn(mod[0]);
  783. for (i = 0; i < len; i++)
  784. result[result[0] - i] = a[i + len];
  785. while (result[0] > 1 && result[result[0]] == 0)
  786. result[0]--;
  787. /* Free temporary arrays */
  788. smemclr(scratch, scratchlen * sizeof(*scratch));
  789. sfree(scratch);
  790. smemclr(a, 2 * len * sizeof(*a));
  791. sfree(a);
  792. smemclr(b, 2 * len * sizeof(*b));
  793. sfree(b);
  794. smemclr(mninv, len * sizeof(*mninv));
  795. sfree(mninv);
  796. smemclr(n, len * sizeof(*n));
  797. sfree(n);
  798. smemclr(x, len * sizeof(*x));
  799. sfree(x);
  800. return result;
  801. }
  802. /*
  803. * Compute (p * q) % mod.
  804. * The most significant word of mod MUST be non-zero.
  805. * We assume that the result array is the same size as the mod array.
  806. */
  807. Bignum modmul(Bignum p, Bignum q, Bignum mod)
  808. {
  809. BignumInt *a, *n, *m, *o, *scratch;
  810. int mshift, scratchlen;
  811. int pqlen, mlen, rlen, i, j;
  812. Bignum result;
  813. /*
  814. * The most significant word of mod needs to be non-zero. It
  815. * should already be, but let's make sure.
  816. */
  817. assert(mod[mod[0]] != 0);
  818. /* Allocate m of size mlen, copy mod to m */
  819. /* We use big endian internally */
  820. mlen = mod[0];
  821. m = snewn(mlen, BignumInt);
  822. for (j = 0; j < mlen; j++)
  823. m[j] = mod[mod[0] - j];
  824. /* Shift m left to make msb bit set */
  825. for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++)
  826. if ((m[0] << mshift) & BIGNUM_TOP_BIT)
  827. break;
  828. if (mshift) {
  829. for (i = 0; i < mlen - 1; i++)
  830. m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift));
  831. m[mlen - 1] = m[mlen - 1] << mshift;
  832. }
  833. pqlen = (p[0] > q[0] ? p[0] : q[0]);
  834. /*
  835. * Make sure that we're allowing enough space. The shifting below
  836. * will underflow the vectors we allocate if pqlen is too small.
  837. */
  838. if (2*pqlen <= mlen)
  839. pqlen = mlen/2 + 1;
  840. /* Allocate n of size pqlen, copy p to n */
  841. n = snewn(pqlen, BignumInt);
  842. i = pqlen - p[0];
  843. for (j = 0; j < i; j++)
  844. n[j] = 0;
  845. for (j = 0; j < (int)p[0]; j++)
  846. n[i + j] = p[p[0] - j];
  847. /* Allocate o of size pqlen, copy q to o */
  848. o = snewn(pqlen, BignumInt);
  849. i = pqlen - q[0];
  850. for (j = 0; j < i; j++)
  851. o[j] = 0;
  852. for (j = 0; j < (int)q[0]; j++)
  853. o[i + j] = q[q[0] - j];
  854. /* Allocate a of size 2*pqlen for result */
  855. a = snewn(2 * pqlen, BignumInt);
  856. /* Scratch space for multiplies */
  857. scratchlen = mul_compute_scratch(pqlen);
  858. scratch = snewn(scratchlen, BignumInt);
  859. /* Main computation */
  860. internal_mul(n, o, a, pqlen, scratch);
  861. internal_mod(a, pqlen * 2, m, mlen, NULL, 0);
  862. /* Fixup result in case the modulus was shifted */
  863. if (mshift) {
  864. for (i = 2 * pqlen - mlen - 1; i < 2 * pqlen - 1; i++)
  865. a[i] = (a[i] << mshift) | (a[i + 1] >> (BIGNUM_INT_BITS - mshift));
  866. a[2 * pqlen - 1] = a[2 * pqlen - 1] << mshift;
  867. internal_mod(a, pqlen * 2, m, mlen, NULL, 0);
  868. for (i = 2 * pqlen - 1; i >= 2 * pqlen - mlen; i--)
  869. a[i] = (a[i] >> mshift) | (a[i - 1] << (BIGNUM_INT_BITS - mshift));
  870. }
  871. /* Copy result to buffer */
  872. rlen = (mlen < pqlen * 2 ? mlen : pqlen * 2);
  873. result = newbn(rlen);
  874. for (i = 0; i < rlen; i++)
  875. result[result[0] - i] = a[i + 2 * pqlen - rlen];
  876. while (result[0] > 1 && result[result[0]] == 0)
  877. result[0]--;
  878. /* Free temporary arrays */
  879. smemclr(scratch, scratchlen * sizeof(*scratch));
  880. sfree(scratch);
  881. smemclr(a, 2 * pqlen * sizeof(*a));
  882. sfree(a);
  883. smemclr(m, mlen * sizeof(*m));
  884. sfree(m);
  885. smemclr(n, pqlen * sizeof(*n));
  886. sfree(n);
  887. smemclr(o, pqlen * sizeof(*o));
  888. sfree(o);
  889. return result;
  890. }
  891. /*
  892. * Compute p % mod.
  893. * The most significant word of mod MUST be non-zero.
  894. * We assume that the result array is the same size as the mod array.
  895. * We optionally write out a quotient if `quotient' is non-NULL.
  896. * We can avoid writing out the result if `result' is NULL.
  897. */
  898. static void bigdivmod(Bignum p, Bignum mod, Bignum result, Bignum quotient)
  899. {
  900. BignumInt *n, *m;
  901. int mshift;
  902. int plen, mlen, i, j;
  903. /*
  904. * The most significant word of mod needs to be non-zero. It
  905. * should already be, but let's make sure.
  906. */
  907. assert(mod[mod[0]] != 0);
  908. /* Allocate m of size mlen, copy mod to m */
  909. /* We use big endian internally */
  910. mlen = mod[0];
  911. m = snewn(mlen, BignumInt);
  912. for (j = 0; j < mlen; j++)
  913. m[j] = mod[mod[0] - j];
  914. /* Shift m left to make msb bit set */
  915. for (mshift = 0; mshift < BIGNUM_INT_BITS-1; mshift++)
  916. if ((m[0] << mshift) & BIGNUM_TOP_BIT)
  917. break;
  918. if (mshift) {
  919. for (i = 0; i < mlen - 1; i++)
  920. m[i] = (m[i] << mshift) | (m[i + 1] >> (BIGNUM_INT_BITS - mshift));
  921. m[mlen - 1] = m[mlen - 1] << mshift;
  922. }
  923. plen = p[0];
  924. /* Ensure plen > mlen */
  925. if (plen <= mlen)
  926. plen = mlen + 1;
  927. /* Allocate n of size plen, copy p to n */
  928. n = snewn(plen, BignumInt);
  929. for (j = 0; j < plen; j++)
  930. n[j] = 0;
  931. for (j = 1; j <= (int)p[0]; j++)
  932. n[plen - j] = p[j];
  933. /* Main computation */
  934. internal_mod(n, plen, m, mlen, quotient, mshift);
  935. /* Fixup result in case the modulus was shifted */
  936. if (mshift) {
  937. for (i = plen - mlen - 1; i < plen - 1; i++)
  938. n[i] = (n[i] << mshift) | (n[i + 1] >> (BIGNUM_INT_BITS - mshift));
  939. n[plen - 1] = n[plen - 1] << mshift;
  940. internal_mod(n, plen, m, mlen, quotient, 0);
  941. for (i = plen - 1; i >= plen - mlen; i--)
  942. n[i] = (n[i] >> mshift) | (n[i - 1] << (BIGNUM_INT_BITS - mshift));
  943. }
  944. /* Copy result to buffer */
  945. if (result) {
  946. for (i = 1; i <= (int)result[0]; i++) {
  947. int j = plen - i;
  948. result[i] = j >= 0 ? n[j] : 0;
  949. }
  950. }
  951. /* Free temporary arrays */
  952. smemclr(m, mlen * sizeof(*m));
  953. sfree(m);
  954. smemclr(n, plen * sizeof(*n));
  955. sfree(n);
  956. }
  957. /*
  958. * Decrement a number.
  959. */
  960. void decbn(Bignum bn)
  961. {
  962. int i = 1;
  963. while (i < (int)bn[0] && bn[i] == 0)
  964. bn[i++] = BIGNUM_INT_MASK;
  965. bn[i]--;
  966. }
  967. Bignum bignum_from_bytes(const unsigned char *data, int nbytes)
  968. {
  969. Bignum result;
  970. int w, i;
  971. assert(nbytes >= 0 && nbytes < INT_MAX/8);
  972. w = (nbytes + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES; /* bytes->words */
  973. result = newbn(w);
  974. for (i = 1; i <= w; i++)
  975. result[i] = 0;
  976. for (i = nbytes; i--;) {
  977. unsigned char byte = *data++;
  978. result[1 + i / BIGNUM_INT_BYTES] |=
  979. (BignumInt)byte << (8*i % BIGNUM_INT_BITS);
  980. }
  981. while (result[0] > 1 && result[result[0]] == 0)
  982. result[0]--;
  983. return result;
  984. }
  985. /*
  986. * Read an SSH-1-format bignum from a data buffer. Return the number
  987. * of bytes consumed, or -1 if there wasn't enough data.
  988. */
  989. int ssh1_read_bignum(const unsigned char *data, int len, Bignum * result)
  990. {
  991. const unsigned char *p = data;
  992. int i;
  993. int w, b;
  994. if (len < 2)
  995. return -1;
  996. w = 0;
  997. for (i = 0; i < 2; i++)
  998. w = (w << 8) + *p++;
  999. b = (w + 7) / 8; /* bits -> bytes */
  1000. if (len < b+2)
  1001. return -1;
  1002. if (!result) /* just return length */
  1003. return b + 2;
  1004. *result = bignum_from_bytes(p, b);
  1005. return p + b - data;
  1006. }
  1007. /*
  1008. * Return the bit count of a bignum, for SSH-1 encoding.
  1009. */
  1010. int bignum_bitcount(Bignum bn)
  1011. {
  1012. int bitcount = bn[0] * BIGNUM_INT_BITS - 1;
  1013. while (bitcount >= 0
  1014. && (bn[bitcount / BIGNUM_INT_BITS + 1] >> (bitcount % BIGNUM_INT_BITS)) == 0) bitcount--;
  1015. return bitcount + 1;
  1016. }
  1017. /*
  1018. * Return the byte length of a bignum when SSH-1 encoded.
  1019. */
  1020. int ssh1_bignum_length(Bignum bn)
  1021. {
  1022. return 2 + (bignum_bitcount(bn) + 7) / 8;
  1023. }
  1024. /*
  1025. * Return the byte length of a bignum when SSH-2 encoded.
  1026. */
  1027. int ssh2_bignum_length(Bignum bn)
  1028. {
  1029. return 4 + (bignum_bitcount(bn) + 8) / 8;
  1030. }
  1031. /*
  1032. * Return a byte from a bignum; 0 is least significant, etc.
  1033. */
  1034. int bignum_byte(Bignum bn, int i)
  1035. {
  1036. if (i < 0 || i >= (int)(BIGNUM_INT_BYTES * bn[0]))
  1037. return 0; /* beyond the end */
  1038. else
  1039. return (bn[i / BIGNUM_INT_BYTES + 1] >>
  1040. ((i % BIGNUM_INT_BYTES)*8)) & 0xFF;
  1041. }
  1042. /*
  1043. * Return a bit from a bignum; 0 is least significant, etc.
  1044. */
  1045. int bignum_bit(Bignum bn, int i)
  1046. {
  1047. if (i < 0 || i >= (int)(BIGNUM_INT_BITS * bn[0]))
  1048. return 0; /* beyond the end */
  1049. else
  1050. return (bn[i / BIGNUM_INT_BITS + 1] >> (i % BIGNUM_INT_BITS)) & 1;
  1051. }
  1052. /*
  1053. * Set a bit in a bignum; 0 is least significant, etc.
  1054. */
  1055. void bignum_set_bit(Bignum bn, int bitnum, int value)
  1056. {
  1057. if (bitnum < 0 || bitnum >= (int)(BIGNUM_INT_BITS * bn[0]))
  1058. abort(); /* beyond the end */
  1059. else {
  1060. int v = bitnum / BIGNUM_INT_BITS + 1;
  1061. BignumInt mask = (BignumInt)1 << (bitnum % BIGNUM_INT_BITS);
  1062. if (value)
  1063. bn[v] |= mask;
  1064. else
  1065. bn[v] &= ~mask;
  1066. }
  1067. }
  1068. /*
  1069. * Write a SSH-1-format bignum into a buffer. It is assumed the
  1070. * buffer is big enough. Returns the number of bytes used.
  1071. */
  1072. int ssh1_write_bignum(void *data, Bignum bn)
  1073. {
  1074. unsigned char *p = data;
  1075. int len = ssh1_bignum_length(bn);
  1076. int i;
  1077. int bitc = bignum_bitcount(bn);
  1078. *p++ = (bitc >> 8) & 0xFF;
  1079. *p++ = (bitc) & 0xFF;
  1080. for (i = len - 2; i--;)
  1081. *p++ = bignum_byte(bn, i);
  1082. return len;
  1083. }
  1084. /*
  1085. * Compare two bignums. Returns like strcmp.
  1086. */
  1087. int bignum_cmp(Bignum a, Bignum b)
  1088. {
  1089. int amax = a[0], bmax = b[0];
  1090. int i;
  1091. /* Annoyingly we have two representations of zero */
  1092. if (amax == 1 && a[amax] == 0)
  1093. amax = 0;
  1094. if (bmax == 1 && b[bmax] == 0)
  1095. bmax = 0;
  1096. assert(amax == 0 || a[amax] != 0);
  1097. assert(bmax == 0 || b[bmax] != 0);
  1098. i = (amax > bmax ? amax : bmax);
  1099. while (i) {
  1100. BignumInt aval = (i > amax ? 0 : a[i]);
  1101. BignumInt bval = (i > bmax ? 0 : b[i]);
  1102. if (aval < bval)
  1103. return -1;
  1104. if (aval > bval)
  1105. return +1;
  1106. i--;
  1107. }
  1108. return 0;
  1109. }
  1110. /*
  1111. * Right-shift one bignum to form another.
  1112. */
  1113. Bignum bignum_rshift(Bignum a, int shift)
  1114. {
  1115. Bignum ret;
  1116. int i, shiftw, shiftb, shiftbb, bits;
  1117. BignumInt ai, ai1;
  1118. assert(shift >= 0);
  1119. bits = bignum_bitcount(a) - shift;
  1120. ret = newbn((bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS);
  1121. if (ret) {
  1122. shiftw = shift / BIGNUM_INT_BITS;
  1123. shiftb = shift % BIGNUM_INT_BITS;
  1124. shiftbb = BIGNUM_INT_BITS - shiftb;
  1125. ai1 = a[shiftw + 1];
  1126. for (i = 1; i <= (int)ret[0]; i++) {
  1127. ai = ai1;
  1128. ai1 = (i + shiftw + 1 <= (int)a[0] ? a[i + shiftw + 1] : 0);
  1129. ret[i] = ((ai >> shiftb) | (ai1 << shiftbb)) & BIGNUM_INT_MASK;
  1130. }
  1131. }
  1132. return ret;
  1133. }
  1134. /*
  1135. * Non-modular multiplication and addition.
  1136. */
  1137. Bignum bigmuladd(Bignum a, Bignum b, Bignum addend)
  1138. {
  1139. int alen = a[0], blen = b[0];
  1140. int mlen = (alen > blen ? alen : blen);
  1141. int rlen, i, maxspot;
  1142. int wslen;
  1143. BignumInt *workspace;
  1144. Bignum ret;
  1145. /* mlen space for a, mlen space for b, 2*mlen for result,
  1146. * plus scratch space for multiplication */
  1147. wslen = mlen * 4 + mul_compute_scratch(mlen);
  1148. workspace = snewn(wslen, BignumInt);
  1149. for (i = 0; i < mlen; i++) {
  1150. workspace[0 * mlen + i] = (mlen - i <= (int)a[0] ? a[mlen - i] : 0);
  1151. workspace[1 * mlen + i] = (mlen - i <= (int)b[0] ? b[mlen - i] : 0);
  1152. }
  1153. internal_mul(workspace + 0 * mlen, workspace + 1 * mlen,
  1154. workspace + 2 * mlen, mlen, workspace + 4 * mlen);
  1155. /* now just copy the result back */
  1156. rlen = alen + blen + 1;
  1157. if (addend && rlen <= (int)addend[0])
  1158. rlen = addend[0] + 1;
  1159. ret = newbn(rlen);
  1160. maxspot = 0;
  1161. for (i = 1; i <= (int)ret[0]; i++) {
  1162. ret[i] = (i <= 2 * mlen ? workspace[4 * mlen - i] : 0);
  1163. if (ret[i] != 0)
  1164. maxspot = i;
  1165. }
  1166. ret[0] = maxspot;
  1167. /* now add in the addend, if any */
  1168. if (addend) {
  1169. BignumDblInt carry = 0;
  1170. for (i = 1; i <= rlen; i++) {
  1171. carry += (i <= (int)ret[0] ? ret[i] : 0);
  1172. carry += (i <= (int)addend[0] ? addend[i] : 0);
  1173. ret[i] = (BignumInt) carry & BIGNUM_INT_MASK;
  1174. carry >>= BIGNUM_INT_BITS;
  1175. if (ret[i] != 0 && i > maxspot)
  1176. maxspot = i;
  1177. }
  1178. }
  1179. ret[0] = maxspot;
  1180. smemclr(workspace, wslen * sizeof(*workspace));
  1181. sfree(workspace);
  1182. return ret;
  1183. }
  1184. /*
  1185. * Non-modular multiplication.
  1186. */
  1187. Bignum bigmul(Bignum a, Bignum b)
  1188. {
  1189. return bigmuladd(a, b, NULL);
  1190. }
  1191. /*
  1192. * Simple addition.
  1193. */
  1194. Bignum bigadd(Bignum a, Bignum b)
  1195. {
  1196. int alen = a[0], blen = b[0];
  1197. int rlen = (alen > blen ? alen : blen) + 1;
  1198. int i, maxspot;
  1199. Bignum ret;
  1200. BignumDblInt carry;
  1201. ret = newbn(rlen);
  1202. carry = 0;
  1203. maxspot = 0;
  1204. for (i = 1; i <= rlen; i++) {
  1205. carry += (i <= (int)a[0] ? a[i] : 0);
  1206. carry += (i <= (int)b[0] ? b[i] : 0);
  1207. ret[i] = (BignumInt) carry & BIGNUM_INT_MASK;
  1208. carry >>= BIGNUM_INT_BITS;
  1209. if (ret[i] != 0 && i > maxspot)
  1210. maxspot = i;
  1211. }
  1212. ret[0] = maxspot;
  1213. return ret;
  1214. }
  1215. /*
  1216. * Subtraction. Returns a-b, or NULL if the result would come out
  1217. * negative (recall that this entire bignum module only handles
  1218. * positive numbers).
  1219. */
  1220. Bignum bigsub(Bignum a, Bignum b)
  1221. {
  1222. int alen = a[0], blen = b[0];
  1223. int rlen = (alen > blen ? alen : blen);
  1224. int i, maxspot;
  1225. Bignum ret;
  1226. BignumDblInt carry;
  1227. ret = newbn(rlen);
  1228. carry = 1;
  1229. maxspot = 0;
  1230. for (i = 1; i <= rlen; i++) {
  1231. carry += (i <= (int)a[0] ? a[i] : 0);
  1232. carry += (i <= (int)b[0] ? b[i] ^ BIGNUM_INT_MASK : BIGNUM_INT_MASK);
  1233. ret[i] = (BignumInt) carry & BIGNUM_INT_MASK;
  1234. carry >>= BIGNUM_INT_BITS;
  1235. if (ret[i] != 0 && i > maxspot)
  1236. maxspot = i;
  1237. }
  1238. ret[0] = maxspot;
  1239. if (!carry) {
  1240. freebn(ret);
  1241. return NULL;
  1242. }
  1243. return ret;
  1244. }
  1245. /*
  1246. * Create a bignum which is the bitmask covering another one. That
  1247. * is, the smallest integer which is >= N and is also one less than
  1248. * a power of two.
  1249. */
  1250. Bignum bignum_bitmask(Bignum n)
  1251. {
  1252. Bignum ret = copybn(n);
  1253. int i;
  1254. BignumInt j;
  1255. i = ret[0];
  1256. while (n[i] == 0 && i > 0)
  1257. i--;
  1258. if (i <= 0)
  1259. return ret; /* input was zero */
  1260. j = 1;
  1261. while (j < n[i])
  1262. j = 2 * j + 1;
  1263. ret[i] = j;
  1264. while (--i > 0)
  1265. ret[i] = BIGNUM_INT_MASK;
  1266. return ret;
  1267. }
  1268. /*
  1269. * Convert a (max 32-bit) long into a bignum.
  1270. */
  1271. Bignum bignum_from_long(unsigned long nn)
  1272. {
  1273. Bignum ret;
  1274. BignumDblInt n = nn;
  1275. ret = newbn(3);
  1276. ret[1] = (BignumInt)(n & BIGNUM_INT_MASK);
  1277. ret[2] = (BignumInt)((n >> BIGNUM_INT_BITS) & BIGNUM_INT_MASK);
  1278. ret[3] = 0;
  1279. ret[0] = (ret[2] ? 2 : 1);
  1280. return ret;
  1281. }
  1282. /*
  1283. * Add a long to a bignum.
  1284. */
  1285. Bignum bignum_add_long(Bignum number, unsigned long addendx)
  1286. {
  1287. Bignum ret = newbn(number[0] + 1);
  1288. int i, maxspot = 0;
  1289. BignumDblInt carry = 0, addend = addendx;
  1290. for (i = 1; i <= (int)ret[0]; i++) {
  1291. carry += addend & BIGNUM_INT_MASK;
  1292. carry += (i <= (int)number[0] ? number[i] : 0);
  1293. addend >>= BIGNUM_INT_BITS;
  1294. ret[i] = (BignumInt) carry & BIGNUM_INT_MASK;
  1295. carry >>= BIGNUM_INT_BITS;
  1296. if (ret[i] != 0)
  1297. maxspot = i;
  1298. }
  1299. ret[0] = maxspot;
  1300. return ret;
  1301. }
  1302. /*
  1303. * Compute the residue of a bignum, modulo a (max 16-bit) short.
  1304. */
  1305. unsigned short bignum_mod_short(Bignum number, unsigned short modulus)
  1306. {
  1307. BignumDblInt mod, r;
  1308. int i;
  1309. r = 0;
  1310. mod = modulus;
  1311. for (i = number[0]; i > 0; i--)
  1312. r = (r * (BIGNUM_TOP_BIT % mod) * 2 + number[i] % mod) % mod;
  1313. return (unsigned short) r;
  1314. }
  1315. #ifdef DEBUG
  1316. void diagbn(char *prefix, Bignum md)
  1317. {
  1318. int i, nibbles, morenibbles;
  1319. static const char hex[] = "0123456789ABCDEF";
  1320. debug(("%s0x", prefix ? prefix : ""));
  1321. nibbles = (3 + bignum_bitcount(md)) / 4;
  1322. if (nibbles < 1)
  1323. nibbles = 1;
  1324. morenibbles = 4 * md[0] - nibbles;
  1325. for (i = 0; i < morenibbles; i++)
  1326. debug(("-"));
  1327. for (i = nibbles; i--;)
  1328. debug(("%c",
  1329. hex[(bignum_byte(md, i / 2) >> (4 * (i % 2))) & 0xF]));
  1330. if (prefix)
  1331. debug(("\n"));
  1332. }
  1333. #endif
  1334. /*
  1335. * Simple division.
  1336. */
  1337. Bignum bigdiv(Bignum a, Bignum b)
  1338. {
  1339. Bignum q = newbn(a[0]);
  1340. bigdivmod(a, b, NULL, q);
  1341. while (q[0] > 1 && q[q[0]] == 0)
  1342. q[0]--;
  1343. return q;
  1344. }
  1345. /*
  1346. * Simple remainder.
  1347. */
  1348. Bignum bigmod(Bignum a, Bignum b)
  1349. {
  1350. Bignum r = newbn(b[0]);
  1351. bigdivmod(a, b, r, NULL);
  1352. while (r[0] > 1 && r[r[0]] == 0)
  1353. r[0]--;
  1354. return r;
  1355. }
  1356. /*
  1357. * Greatest common divisor.
  1358. */
  1359. Bignum biggcd(Bignum av, Bignum bv)
  1360. {
  1361. Bignum a = copybn(av);
  1362. Bignum b = copybn(bv);
  1363. while (bignum_cmp(b, Zero) != 0) {
  1364. Bignum t = newbn(b[0]);
  1365. bigdivmod(a, b, t, NULL);
  1366. while (t[0] > 1 && t[t[0]] == 0)
  1367. t[0]--;
  1368. freebn(a);
  1369. a = b;
  1370. b = t;
  1371. }
  1372. freebn(b);
  1373. return a;
  1374. }
/*
 * Modular inverse, using Euclid's extended algorithm.
 *
 * Returns a newly allocated Bignum x with number * x == 1 (mod
 * modulus). Returns NULL if the remainder sequence reaches zero
 * before reaching one, i.e. the inputs share a common factor and no
 * inverse exists.
 *
 * Both inputs must have a non-zero top word (asserted). Neither
 * input is modified.
 *
 * Only magnitudes are tracked for the Bezout coefficient: x holds
 * |coefficient| and `sign' records its sign, which alternates on
 * every step.
 */
Bignum modinv(Bignum number, Bignum modulus)
{
    Bignum a = copybn(modulus);
    Bignum b = copybn(number);
    Bignum xp = copybn(Zero);           /* previous coefficient magnitude */
    Bignum x = copybn(One);             /* current coefficient magnitude */
    int sign = +1;
    assert(number[number[0]] != 0);
    assert(modulus[modulus[0]] != 0);
    while (bignum_cmp(b, One) != 0) {
        Bignum t, q;
        if (bignum_cmp(b, Zero) == 0) {
            /*
             * Found a common factor between the inputs, so we cannot
             * return a modular inverse at all.
             */
            freebn(b);
            freebn(a);
            freebn(xp);
            freebn(x);
            return NULL;
        }
        /* t = a mod b, q = a / b, both trimmed. */
        t = newbn(b[0]);
        q = newbn(a[0]);
        bigdivmod(a, b, t, q);
        while (t[0] > 1 && t[t[0]] == 0)
            t[0]--;
        while (q[0] > 1 && q[q[0]] == 0)
            q[0]--;
        /* (a, b) <- (b, a mod b) */
        freebn(a);
        a = b;
        b = t;
        /*
         * (xp, x) <- (x, q*x + xp). Since the true coefficients
         * alternate in sign, adding magnitudes here is the same as
         * the signed subtraction; the old xp (now in t) is freed.
         */
        t = xp;
        xp = x;
        x = bigmuladd(q, xp, t);
        sign = -sign;
        freebn(t);
        freebn(q);
    }
    freebn(b);
    freebn(a);
    freebn(xp);
    /* now we know that sign * x == 1 (mod modulus), and that x < modulus */
    if (sign < 0) {
        /* set a new x to be modulus - x */
        Bignum newx = newbn(modulus[0]);
        BignumInt carry = 0;            /* borrow from the previous word */
        int maxspot = 1;
        int i;
        for (i = 1; i <= (int)newx[0]; i++) {
            BignumInt aword = (i <= (int)modulus[0] ? modulus[i] : 0);
            BignumInt bword = (i <= (int)x[0] ? x[i] : 0);
            newx[i] = aword - bword - carry;
            /*
             * Detect borrow-out without a wider type. With no incoming
             * borrow, aword < bword exactly when the wrapped difference
             * exceeds ~bword; with an incoming borrow the test is >=.
             */
            bword = ~bword;
            carry = carry ? (newx[i] >= bword) : (newx[i] > bword);
            if (newx[i] != 0)
                maxspot = i;
        }
        newx[0] = maxspot;
        freebn(x);
        x = newx;
    }
    /* and return. */
    return x;
}
  1443. /*
  1444. * Render a bignum into decimal. Return a malloced string holding
  1445. * the decimal representation.
  1446. */
  1447. char *bignum_decimal(Bignum x)
  1448. {
  1449. int ndigits, ndigit;
  1450. int i, iszero;
  1451. BignumDblInt carry;
  1452. char *ret;
  1453. BignumInt *workspace;
  1454. /*
  1455. * First, estimate the number of digits. Since log(10)/log(2)
  1456. * is just greater than 93/28 (the joys of continued fraction
  1457. * approximations...) we know that for every 93 bits, we need
  1458. * at most 28 digits. This will tell us how much to malloc.
  1459. *
  1460. * Formally: if x has i bits, that means x is strictly less
  1461. * than 2^i. Since 2 is less than 10^(28/93), this is less than
  1462. * 10^(28i/93). We need an integer power of ten, so we must
  1463. * round up (rounding down might make it less than x again).
  1464. * Therefore if we multiply the bit count by 28/93, rounding
  1465. * up, we will have enough digits.
  1466. *
  1467. * i=0 (i.e., x=0) is an irritating special case.
  1468. */
  1469. i = bignum_bitcount(x);
  1470. if (!i)
  1471. ndigits = 1; /* x = 0 */
  1472. else
  1473. ndigits = (28 * i + 92) / 93; /* multiply by 28/93 and round up */
  1474. ndigits++; /* allow for trailing \0 */
  1475. ret = snewn(ndigits, char);
  1476. /*
  1477. * Now allocate some workspace to hold the binary form as we
  1478. * repeatedly divide it by ten. Initialise this to the
  1479. * big-endian form of the number.
  1480. */
  1481. workspace = snewn(x[0], BignumInt);
  1482. for (i = 0; i < (int)x[0]; i++)
  1483. workspace[i] = x[x[0] - i];
  1484. /*
  1485. * Next, write the decimal number starting with the last digit.
  1486. * We use ordinary short division, dividing 10 into the
  1487. * workspace.
  1488. */
  1489. ndigit = ndigits - 1;
  1490. ret[ndigit] = '\0';
  1491. do {
  1492. iszero = 1;
  1493. carry = 0;
  1494. for (i = 0; i < (int)x[0]; i++) {
  1495. carry = (carry << BIGNUM_INT_BITS) + workspace[i];
  1496. workspace[i] = (BignumInt) (carry / 10);
  1497. if (workspace[i])
  1498. iszero = 0;
  1499. carry %= 10;
  1500. }
  1501. ret[--ndigit] = (char) (carry + '0');
  1502. } while (!iszero);
  1503. /*
  1504. * There's a chance we've fallen short of the start of the
  1505. * string. Correct if so.
  1506. */
  1507. if (ndigit > 0)
  1508. memmove(ret, ret + ndigit, ndigits - ndigit);
  1509. /*
  1510. * Done.
  1511. */
  1512. smemclr(workspace, x[0] * sizeof(*workspace));
  1513. sfree(workspace);
  1514. return ret;
  1515. }
  1516. #ifdef TESTBN
  1517. #include <stdio.h>
  1518. #include <stdlib.h>
  1519. #include <ctype.h>
  1520. /*
  1521. * gcc -Wall -g -O0 -DTESTBN -o testbn sshbn.c misc.c conf.c tree234.c unix/uxmisc.c -I. -I unix -I charset
  1522. *
  1523. * Then feed to this program's standard input the output of
  1524. * testdata/bignum.py .
  1525. */
  1526. void modalfatalbox(char *p, ...)
  1527. {
  1528. va_list ap;
  1529. fprintf(stderr, "FATAL ERROR: ");
  1530. va_start(ap, p);
  1531. vfprintf(stderr, p, ap);
  1532. va_end(ap);
  1533. fputc('\n', stderr);
  1534. exit(1);
  1535. }
  1536. int random_byte(void)
  1537. {
  1538. modalfatalbox("random_byte called in testbn");
  1539. return 0;
  1540. }
  1541. #define fromxdigit(c) ( (c)>'9' ? ((c)&0xDF) - 'A' + 10 : (c) - '0' )
  1542. int main(int argc, char **argv)
  1543. {
  1544. char *buf;
  1545. int line = 0;
  1546. int passes = 0, fails = 0;
  1547. while ((buf = fgetline(stdin)) != NULL) {
  1548. int maxlen = strlen(buf);
  1549. unsigned char *data = snewn(maxlen, unsigned char);
  1550. unsigned char *ptrs[5], *q;
  1551. int ptrnum;
  1552. char *bufp = buf;
  1553. line++;
  1554. q = data;
  1555. ptrnum = 0;
  1556. while (*bufp && !isspace((unsigned char)*bufp))
  1557. bufp++;
  1558. if (bufp)
  1559. *bufp++ = '\0';
  1560. while (*bufp) {
  1561. char *start, *end;
  1562. int i;
  1563. while (*bufp && !isxdigit((unsigned char)*bufp))
  1564. bufp++;
  1565. start = bufp;
  1566. if (!*bufp)
  1567. break;
  1568. while (*bufp && isxdigit((unsigned char)*bufp))
  1569. bufp++;
  1570. end = bufp;
  1571. if (ptrnum >= lenof(ptrs))
  1572. break;
  1573. ptrs[ptrnum++] = q;
  1574. for (i = -((end - start) & 1); i < end-start; i += 2) {
  1575. unsigned char val = (i < 0 ? 0 : fromxdigit(start[i]));
  1576. val = val * 16 + fromxdigit(start[i+1]);
  1577. *q++ = val;
  1578. }
  1579. ptrs[ptrnum] = q;
  1580. }
  1581. if (!strcmp(buf, "mul")) {
  1582. Bignum a, b, c, p;
  1583. if (ptrnum != 3) {
  1584. printf("%d: mul with %d parameters, expected 3\n", line, ptrnum);
  1585. exit(1);
  1586. }
  1587. a = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
  1588. b = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
  1589. c = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
  1590. p = bigmul(a, b);
  1591. if (bignum_cmp(c, p) == 0) {
  1592. passes++;
  1593. } else {
  1594. char *as = bignum_decimal(a);
  1595. char *bs = bignum_decimal(b);
  1596. char *cs = bignum_decimal(c);
  1597. char *ps = bignum_decimal(p);
  1598. printf("%d: fail: %s * %s gave %s expected %s\n",
  1599. line, as, bs, ps, cs);
  1600. fails++;
  1601. sfree(as);
  1602. sfree(bs);
  1603. sfree(cs);
  1604. sfree(ps);
  1605. }
  1606. freebn(a);
  1607. freebn(b);
  1608. freebn(c);
  1609. freebn(p);
  1610. } else if (!strcmp(buf, "modmul")) {
  1611. Bignum a, b, m, c, p;
  1612. if (ptrnum != 4) {
  1613. printf("%d: modmul with %d parameters, expected 4\n",
  1614. line, ptrnum);
  1615. exit(1);
  1616. }
  1617. a = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
  1618. b = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
  1619. m = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
  1620. c = bignum_from_bytes(ptrs[3], ptrs[4]-ptrs[3]);
  1621. p = modmul(a, b, m);
  1622. if (bignum_cmp(c, p) == 0) {
  1623. passes++;
  1624. } else {
  1625. char *as = bignum_decimal(a);
  1626. char *bs = bignum_decimal(b);
  1627. char *ms = bignum_decimal(m);
  1628. char *cs = bignum_decimal(c);
  1629. char *ps = bignum_decimal(p);
  1630. printf("%d: fail: %s * %s mod %s gave %s expected %s\n",
  1631. line, as, bs, ms, ps, cs);
  1632. fails++;
  1633. sfree(as);
  1634. sfree(bs);
  1635. sfree(ms);
  1636. sfree(cs);
  1637. sfree(ps);
  1638. }
  1639. freebn(a);
  1640. freebn(b);
  1641. freebn(m);
  1642. freebn(c);
  1643. freebn(p);
  1644. } else if (!strcmp(buf, "pow")) {
  1645. Bignum base, expt, modulus, expected, answer;
  1646. if (ptrnum != 4) {
  1647. printf("%d: mul with %d parameters, expected 4\n", line, ptrnum);
  1648. exit(1);
  1649. }
  1650. base = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]);
  1651. expt = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]);
  1652. modulus = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]);
  1653. expected = bignum_from_bytes(ptrs[3], ptrs[4]-ptrs[3]);
  1654. answer = modpow(base, expt, modulus);
  1655. if (bignum_cmp(expected, answer) == 0) {
  1656. passes++;
  1657. } else {
  1658. char *as = bignum_decimal(base);
  1659. char *bs = bignum_decimal(expt);
  1660. char *cs = bignum_decimal(modulus);
  1661. char *ds = bignum_decimal(answer);
  1662. char *ps = bignum_decimal(expected);
  1663. printf("%d: fail: %s ^ %s mod %s gave %s expected %s\n",
  1664. line, as, bs, cs, ds, ps);
  1665. fails++;
  1666. sfree(as);
  1667. sfree(bs);
  1668. sfree(cs);
  1669. sfree(ds);
  1670. sfree(ps);
  1671. }
  1672. freebn(base);
  1673. freebn(expt);
  1674. freebn(modulus);
  1675. freebn(expected);
  1676. freebn(answer);
  1677. } else {
  1678. printf("%d: unrecognised test keyword: '%s'\n", line, buf);
  1679. exit(1);
  1680. }
  1681. sfree(buf);
  1682. sfree(data);
  1683. }
  1684. printf("passed %d failed %d total %d\n", passes, fails, passes+fails);
  1685. return fails != 0;
  1686. }
  1687. #endif