| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391 |
- /*
- * Centralised machinery for hybridised post-quantum + classical key
- * exchange setups, using the same message structure as ECDH but the
- * strings sent each way are the concatenation of a key or ciphertext
- * of each type, and the output shared secret is obtained by hashing
- * together both of the sub-methods' outputs.
- */
- #include <stdio.h>
- #include <stdlib.h>
- #include <assert.h>
- #include "putty.h"
- #include "ssh.h"
- #include "mpint.h"
- /* ----------------------------------------------------------------------
- * Common definitions between client and server sides.
- */
- typedef struct hybrid_alg hybrid_alg;
- struct hybrid_alg {
- const ssh_hashalg *combining_hash;
- const pq_kemalg *pq_alg;
- const ssh_kex *classical_alg;
- void (*reformat)(ptrlen input, BinarySink *output);
- };
- static char *hybrid_description(const ssh_kex *kex)
- {
- const struct hybrid_alg *alg = kex->extra;
- /* Bit of a bodge, but think up a short name to describe the
- * classical algorithm */
- const char *classical_name;
- if (alg->classical_alg == &ssh_ec_kex_curve25519)
- classical_name = "Curve25519";
- else if (alg->classical_alg == &ssh_ec_kex_nistp256)
- classical_name = "NIST P256";
- else if (alg->classical_alg == &ssh_ec_kex_nistp384)
- classical_name = "NIST P384";
- else
- unreachable("don't have a name for this classical alg");
- return dupprintf("%s / %s hybrid key exchange",
- alg->pq_alg->description, classical_name);
- }
- static void reformat_mpint_be(ptrlen input, BinarySink *output, size_t bytes)
- {
- BinarySource src[1];
- BinarySource_BARE_INIT_PL(src, input);
- mp_int *mp = get_mp_ssh2(src);
- assert(!get_err(src));
- assert(get_avail(src) == 0);
- for (size_t i = bytes; i-- > 0 ;)
- put_byte(output, mp_get_byte(mp, i));
- mp_free(mp);
- }
- static void reformat_mpint_be_32(ptrlen input, BinarySink *output)
- {
- reformat_mpint_be(input, output, 32);
- }
- static void reformat_mpint_be_48(ptrlen input, BinarySink *output)
- {
- reformat_mpint_be(input, output, 48);
- }
- /* ----------------------------------------------------------------------
- * Client side.
- */
- typedef struct hybrid_client_state hybrid_client_state;
- static const ecdh_keyalg hybrid_client_vt;
- struct hybrid_client_state {
- const hybrid_alg *alg;
- strbuf *pq_ek;
- pq_kem_dk *pq_dk;
- ecdh_key *classical;
- ecdh_key ek;
- };
- static ecdh_key *hybrid_client_new(const ssh_kex *kex, bool is_server)
- {
- assert(!is_server);
- hybrid_client_state *s = snew(hybrid_client_state);
- s->alg = kex->extra;
- s->ek.vt = &hybrid_client_vt;
- s->pq_ek = strbuf_new();
- s->pq_dk = pq_kem_keygen(s->alg->pq_alg, BinarySink_UPCAST(s->pq_ek));
- s->classical = ecdh_key_new(s->alg->classical_alg, is_server);
- return &s->ek;
- }
- static void hybrid_client_free(ecdh_key *ek)
- {
- hybrid_client_state *s = container_of(ek, hybrid_client_state, ek);
- strbuf_free(s->pq_ek);
- pq_kem_free_dk(s->pq_dk);
- ecdh_key_free(s->classical);
- sfree(s);
- }
- /*
- * In the client, getpublic is called first: we make up a KEM key
- * pair, and send the public key along with a classical DH value.
- */
- static void hybrid_client_getpublic(ecdh_key *ek, BinarySink *bs)
- {
- hybrid_client_state *s = container_of(ek, hybrid_client_state, ek);
- put_datapl(bs, ptrlen_from_strbuf(s->pq_ek));
- ecdh_key_getpublic(s->classical, bs);
- }
- /*
- * In the client, getkey is called second, after the server sends its
- * response: we use our KEM private key to decapsulate the server's
- * ciphertext.
- */
- static bool hybrid_client_getkey(ecdh_key *ek, ptrlen remoteKey, BinarySink *bs)
- {
- hybrid_client_state *s = container_of(ek, hybrid_client_state, ek);
- BinarySource src[1];
- BinarySource_BARE_INIT_PL(src, remoteKey);
- ssh_hash *h = ssh_hash_new(s->alg->combining_hash);
- ptrlen pq_ciphertext = get_data(src, s->alg->pq_alg->c_len);
- if (get_err(src)) {
- ssh_hash_free(h);
- return false; /* not enough data */
- }
- if (!pq_kem_decaps(s->pq_dk, BinarySink_UPCAST(h), pq_ciphertext)) {
- ssh_hash_free(h);
- return false; /* pq ciphertext didn't validate */
- }
- ptrlen classical_data = get_data(src, get_avail(src));
- strbuf *classical_key = strbuf_new();
- if (!ecdh_key_getkey(s->classical, classical_data,
- BinarySink_UPCAST(classical_key))) {
- ssh_hash_free(h);
- return false; /* classical DH key didn't validate */
- }
- s->alg->reformat(ptrlen_from_strbuf(classical_key), BinarySink_UPCAST(h));
- strbuf_free(classical_key);
- /*
- * Finish up: compute the final output hash and return it encoded
- * as a string.
- */
- unsigned char hashdata[MAX_HASH_LEN];
- ssh_hash_final(h, hashdata);
- put_stringpl(bs, make_ptrlen(hashdata, s->alg->combining_hash->hlen));
- smemclr(hashdata, sizeof(hashdata));
- return true;
- }
- static const ecdh_keyalg hybrid_client_vt = {
- .new = hybrid_client_new, /* but normally the selector calls this */
- .free = hybrid_client_free,
- .getpublic = hybrid_client_getpublic,
- .getkey = hybrid_client_getkey,
- .description = hybrid_description,
- .packet_naming_ctx = SSH2_PKTCTX_HYBRIDKEX,
- };
- /* ----------------------------------------------------------------------
- * Server side.
- */
- typedef struct hybrid_server_state hybrid_server_state;
- static const ecdh_keyalg hybrid_server_vt;
- struct hybrid_server_state {
- const hybrid_alg *alg;
- strbuf *pq_ciphertext;
- ecdh_key *classical;
- ecdh_key ek;
- };
- static ecdh_key *hybrid_server_new(const ssh_kex *kex, bool is_server)
- {
- assert(is_server);
- hybrid_server_state *s = snew(hybrid_server_state);
- s->alg = kex->extra;
- s->ek.vt = &hybrid_server_vt;
- s->pq_ciphertext = strbuf_new_nm();
- s->classical = ecdh_key_new(s->alg->classical_alg, is_server);
- return &s->ek;
- }
- static void hybrid_server_free(ecdh_key *ek)
- {
- hybrid_server_state *s = container_of(ek, hybrid_server_state, ek);
- strbuf_free(s->pq_ciphertext);
- ecdh_key_free(s->classical);
- sfree(s);
- }
- /*
- * In the server, getkey is called first: we receive a KEM encryption
- * key from the client and encapsulate a secret with it. We write the
- * output secret to bs; the data we'll send to the client is saved to
- * return from getpublic.
- */
- static bool hybrid_server_getkey(ecdh_key *ek, ptrlen remoteKey, BinarySink *bs)
- {
- hybrid_server_state *s = container_of(ek, hybrid_server_state, ek);
- BinarySource src[1];
- BinarySource_BARE_INIT_PL(src, remoteKey);
- ssh_hash *h = ssh_hash_new(s->alg->combining_hash);
- ptrlen pq_ek = get_data(src, s->alg->pq_alg->ek_len);
- if (get_err(src)) {
- ssh_hash_free(h);
- return false; /* not enough data */
- }
- if (!pq_kem_encaps(s->alg->pq_alg,
- BinarySink_UPCAST(s->pq_ciphertext),
- BinarySink_UPCAST(h), pq_ek)) {
- ssh_hash_free(h);
- return false; /* pq encryption key didn't validate */
- }
- ptrlen classical_data = get_data(src, get_avail(src));
- strbuf *classical_key = strbuf_new();
- if (!ecdh_key_getkey(s->classical, classical_data,
- BinarySink_UPCAST(classical_key))) {
- ssh_hash_free(h);
- return false; /* classical DH key didn't validate */
- }
- s->alg->reformat(ptrlen_from_strbuf(classical_key), BinarySink_UPCAST(h));
- strbuf_free(classical_key);
- /*
- * Finish up: compute the final output hash and return it encoded
- * as a string.
- */
- unsigned char hashdata[MAX_HASH_LEN];
- ssh_hash_final(h, hashdata);
- put_stringpl(bs, make_ptrlen(hashdata, s->alg->combining_hash->hlen));
- smemclr(hashdata, sizeof(hashdata));
- return true;
- }
- static void hybrid_server_getpublic(ecdh_key *ek, BinarySink *bs)
- {
- hybrid_server_state *s = container_of(ek, hybrid_server_state, ek);
- put_datapl(bs, ptrlen_from_strbuf(s->pq_ciphertext));
- ecdh_key_getpublic(s->classical, bs);
- }
- static const ecdh_keyalg hybrid_server_vt = {
- .new = hybrid_server_new, /* but normally the selector calls this */
- .free = hybrid_server_free,
- .getkey = hybrid_server_getkey,
- .getpublic = hybrid_server_getpublic,
- .description = hybrid_description,
- .packet_naming_ctx = SSH2_PKTCTX_HYBRIDKEX,
- };
- /* ----------------------------------------------------------------------
- * Selector vtable that instantiates the appropriate one of the above,
- * depending on is_server.
- */
- static ecdh_key *hybrid_selector_new(const ssh_kex *kex, bool is_server)
- {
- if (is_server)
- return hybrid_server_new(kex, is_server);
- else
- return hybrid_client_new(kex, is_server);
- }
- static const ecdh_keyalg hybrid_selector_vt = {
- /* This is a never-instantiated vtable which only implements the
- * functions that don't require an instance. */
- .new = hybrid_selector_new,
- .description = hybrid_description,
- .packet_naming_ctx = SSH2_PKTCTX_HYBRIDKEX,
- };
- /* ----------------------------------------------------------------------
- * Actual KEX methods.
- */
- static const hybrid_alg ssh_ntru_curve25519_hybrid = {
- .combining_hash = &ssh_sha512,
- .pq_alg = &ssh_ntru,
- .classical_alg = &ssh_ec_kex_curve25519,
- .reformat = reformat_mpint_be_32,
- };
- static const ssh_kex ssh_ntru_curve25519 = {
- .name = "sntrup761x25519-sha512",
- .main_type = KEXTYPE_ECDH,
- .hash = &ssh_sha512,
- .ecdh_vt = &hybrid_selector_vt,
- .extra = &ssh_ntru_curve25519_hybrid,
- };
- static const ssh_kex ssh_ntru_curve25519_openssh = {
- .name = "[email protected]",
- .main_type = KEXTYPE_ECDH,
- .hash = &ssh_sha512,
- .ecdh_vt = &hybrid_selector_vt,
- .extra = &ssh_ntru_curve25519_hybrid,
- };
- static const ssh_kex *const ntru_hybrid_list[] = {
- &ssh_ntru_curve25519,
- &ssh_ntru_curve25519_openssh,
- };
- const ssh_kexes ssh_ntru_hybrid_kex = {
- lenof(ntru_hybrid_list), ntru_hybrid_list,
- };
- static const hybrid_alg ssh_mlkem768_curve25519_hybrid = {
- .combining_hash = &ssh_sha256,
- .pq_alg = &ssh_mlkem768,
- .classical_alg = &ssh_ec_kex_curve25519,
- .reformat = reformat_mpint_be_32,
- };
- static const ssh_kex ssh_mlkem768_curve25519 = {
- .name = "mlkem768x25519-sha256",
- .main_type = KEXTYPE_ECDH,
- .hash = &ssh_sha256,
- .ecdh_vt = &hybrid_selector_vt,
- .extra = &ssh_mlkem768_curve25519_hybrid,
- };
- static const ssh_kex *const mlkem_curve25519_hybrid_list[] = {
- &ssh_mlkem768_curve25519,
- };
- const ssh_kexes ssh_mlkem_curve25519_hybrid_kex = {
- lenof(mlkem_curve25519_hybrid_list), mlkem_curve25519_hybrid_list,
- };
- static const hybrid_alg ssh_mlkem768_p256_hybrid = {
- .combining_hash = &ssh_sha256,
- .pq_alg = &ssh_mlkem768,
- .classical_alg = &ssh_ec_kex_nistp256,
- .reformat = reformat_mpint_be_32,
- };
- static const ssh_kex ssh_mlkem768_p256 = {
- .name = "mlkem768nistp256-sha256",
- .main_type = KEXTYPE_ECDH,
- .hash = &ssh_sha256,
- .ecdh_vt = &hybrid_selector_vt,
- .extra = &ssh_mlkem768_p256_hybrid,
- };
- static const hybrid_alg ssh_mlkem1024_p384_hybrid = {
- .combining_hash = &ssh_sha384,
- .pq_alg = &ssh_mlkem1024,
- .classical_alg = &ssh_ec_kex_nistp384,
- .reformat = reformat_mpint_be_48,
- };
- static const ssh_kex ssh_mlkem1024_p384 = {
- .name = "mlkem1024nistp384-sha384",
- .main_type = KEXTYPE_ECDH,
- .hash = &ssh_sha384,
- .ecdh_vt = &hybrid_selector_vt,
- .extra = &ssh_mlkem1024_p384_hybrid,
- };
- static const ssh_kex *const mlkem_nist_hybrid_list[] = {
- &ssh_mlkem1024_p384,
- &ssh_mlkem768_p256,
- };
- const ssh_kexes ssh_mlkem_nist_hybrid_kex = {
- lenof(mlkem_nist_hybrid_list), mlkem_nist_hybrid_list,
- };
|