Source code for quantum_safe.kem.hybrid

"""
quantum_safe.kem.hybrid
~~~~~~~~~~~~~~~~~~~~~~~~

HybridKEM: combines a classical Diffie-Hellman KEM (X25519 or P-256) with
a PQC KEM (ML-KEM) into a single hybrid operation.

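The construction is based on draft-ietf-tls-hybrid-design and pairs the same
algorithms as the TLS 1.3 hybrid group X25519MLKEM768 (specified in
draft-ietf-tls-ecdhe-mlkem and listed in the IANA TLS Supported Groups
registry). The key/ciphertext framing and the HKDF combiner described below
are specific to this library; they are not the TLS wire format.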

Why hybrid?
-----------
During the transition period, we can't be certain that ML-KEM is unbroken.
NIST standardized it, but it's young. A classical algorithm such as X25519,
which has withstood more than a decade of scrutiny from a large cryptanalysis
community, offers a kind of assurance that mathematical proofs alone cannot.
Hybrid mode means:

  - If ML-KEM is broken but X25519 isn't: the hybrid is still X25519-secure.
  - If X25519 is broken by a quantum computer but ML-KEM isn't: the hybrid
    is still ML-KEM-secure.
  - Both would have to be broken simultaneously for the hybrid to fail.

This is the position taken by NIST, CISA, BSI, NCSC, and every major TLS
library that has added PQC support.

The combiner
------------
Given (the recipient's long-term keys):
  - X25519 keypair: (pk_x, sk_x)
  - ML-KEM keypair: (pk_m, sk_m)

Encapsulate:
  1. Generate ephemeral X25519 keypair (epk_x, esk_x).
  2. Compute X25519 DH: ss_x = X25519(esk_x, pk_x_recipient)
  3. Run ML-KEM encapsulate: (ct_m, ss_m) = MLKEMEncap(pk_m)
  4. Combined ciphertext: ct = len(epk_x) || epk_x || ct_m
     (ephemeral public key replaces a traditional ciphertext for X25519)
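  5. Combined secret: ss = HKDF(ikm=ss_x||ss_m, salt=epk_x||ct_m, info=...)
     (the salt is built from the two ciphertext halves; for X25519 the
     "ciphertext" half is the ephemeral public key epk_x)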

Decapsulate:
  1. Split ct into epk_x and ct_m.
  2. Compute X25519 DH: ss_x = X25519(sk_x, epk_x)
  3. Run ML-KEM decapsulate: ss_m = MLKEMDecap(sk_m, ct_m)
  4. Derive combined secret with same HKDF call.

The public key for a hybrid KEM is (pk_x || pk_m) — both components.
The secret key is (sk_x || sk_m) — both components.

Key format
----------
We store hybrid keys in a length-prefixed format so we can split them
without out-of-band length information:

    2 bytes big-endian: classical_component_len
    N bytes:            classical component (public or secret key)
    remaining bytes:    pqc component

This is the same framing as HybridCipherText.
"""

from __future__ import annotations

import struct
from typing import TYPE_CHECKING, cast

from cryptography.hazmat.primitives.asymmetric.x25519 import (
    X25519PrivateKey,
    X25519PublicKey,
)
from cryptography.hazmat.primitives.serialization import (
    Encoding,
    NoEncryption,
    PrivateFormat,
    PublicFormat,
)

from quantum_safe.backends import get_kem_backend
from quantum_safe.exceptions import (
    DecapsulationError,
    UnsupportedAlgorithm,
)
from quantum_safe.kem.algorithms import (
    DEFAULT_HYBRID_CLASSICAL,
    DEFAULT_HYBRID_PQC,
    canonical_hybrid_name,
    validate_hybrid_combination,
)
from quantum_safe.types import (
    HybridCipherText,
    KeyPair,
    MigrationState,
    PublicKey,
    SecretKey,
    SharedSecret,
    combine_shared_secrets,
)

if TYPE_CHECKING:
    from quantum_safe.backends.base import AbstractKEMBackend


# Framing used by _pack_components / _unpack_components: the first component
# is preceded by its length, encoded as a 2-byte big-endian unsigned integer
# (struct format ">H"). _LEN_SIZE is the byte width of that prefix and must
# stay in sync with _LEN_FMT.
_LEN_FMT = ">H"
_LEN_SIZE = 2


def _pack_components(a: bytes, b: bytes) -> bytes:
    """Serialize two byte strings as ``len(a) || a || b``.

    Only the first component carries a length prefix (encoded with
    ``_LEN_FMT``); the second component is simply everything that follows.

    Args:
        a: First component; its length goes into the prefix.
        b: Second component, appended verbatim.

    Returns:
        The framed byte string.
    """
    prefix = struct.pack(_LEN_FMT, len(a))
    return b"".join((prefix, a, b))


def _unpack_components(data: bytes, context: str = "") -> tuple[bytes, bytes]:
    """Split a byte string produced by ``_pack_components`` into its two parts.

    Args:
        data: Framed bytes: a ``_LEN_SIZE``-byte length prefix, the first
            component, then the second component.
        context: Algorithm label passed to ``DecapsulationError`` as ``algo``.

    Returns:
        ``(first, second)``. ``second`` is whatever follows the first
        component and may be empty.

    Raises:
        DecapsulationError: If ``data`` is shorter than the prefix, or shorter
            than the prefix plus the length the prefix declares.
    """
    total = len(data)
    if total < _LEN_SIZE:
        raise DecapsulationError(algo=context)

    first_len = struct.unpack_from(_LEN_FMT, data, 0)[0]
    split_at = _LEN_SIZE + first_len
    if split_at > total:
        # The declared length runs past the end of the buffer.
        raise DecapsulationError(algo=context)

    return data[_LEN_SIZE:split_at], data[split_at:]


class HybridKEM:
    """Hybrid KEM: classical Diffie-Hellman + post-quantum KEM.

    Default configuration: X25519 + ML-KEM-768, the same algorithm pair as the
    TLS 1.3 hybrid group X25519MLKEM768 and a combination widely recommended
    for the current transition period.

    Args:
        classical: Classical KEM algorithm. Currently "X25519" or "P-256".
            Default: "X25519".
        pqc: PQC KEM algorithm. Default: "ML-KEM-768".
        backend: Backend for PQC operations: "auto", "liboqs", "rustcrypto".
            Default: "auto".
        validate: If True (default), validate that the classical+pqc
            combination is an approved hybrid. Set False only if you're
            testing a non-standard combination.

    Raises:
        UnsupportedAlgorithm: If the selected backend reports itself as
            available but does not list ``pqc`` among its supported
            algorithms.

    Example::

        from quantum_safe import HybridKEM

        kem = HybridKEM()  # X25519 + ML-KEM-768
        kp = kem.generate_keypair()
        ct, ss = kem.encapsulate(kp.public)
        ss2 = kem.decapsulate(kp.secret, ct)
        assert ss == ss2
    """

    # Classical algorithms implemented by the private helpers below; reported
    # as the ``available`` list when an unknown classical algorithm is used.
    _CLASSICAL_ALGORITHMS = ("X25519", "P-256")

    def __init__(
        self,
        classical: str = DEFAULT_HYBRID_CLASSICAL,
        pqc: str = DEFAULT_HYBRID_PQC,
        backend: str = "auto",
        validate: bool = True,
    ) -> None:
        if validate:
            validate_hybrid_combination(classical, pqc)

        self._classical = classical
        self._pqc = pqc
        self._algorithm = canonical_hybrid_name(classical, pqc)
        self._backend: AbstractKEMBackend = get_kem_backend(backend)

        # Pre-validate that the backend supports the PQC algorithm. The check
        # only fires when the backend says it is available.
        supported = {a.name for a in self._backend.supported_algorithms()}
        if pqc not in supported and self._backend.is_available():
            # sorted(): set iteration order is arbitrary, and the error should
            # read the same on every run.
            raise UnsupportedAlgorithm(pqc, available=sorted(supported))

    @property
    def algorithm(self) -> str:
        """Full hybrid algorithm string, e.g. 'X25519+ML-KEM-768'."""
        return self._algorithm

    @property
    def classical_algorithm(self) -> str:
        """Name of the classical component, as passed to the constructor."""
        return self._classical

    @property
    def pqc_algorithm(self) -> str:
        """Name of the PQC component, as passed to the constructor."""
        return self._pqc

    @property
    def backend_name(self) -> str:
        """Name reported by the PQC backend in use."""
        return self._backend.name

    # ------------------------------------------------------------------
    # Key generation
    # ------------------------------------------------------------------

    def generate_keypair(self) -> KeyPair:
        """Generate a hybrid key pair.

        The public key contains both the classical public key and the PQC
        public key, packed with a length prefix on the classical part. The
        secret key is structured the same way.

        The migration_state is set to HYBRID_TRANSITION: this key
        participates in the current hybrid deployment.

        Returns:
            KeyPair where both .public and .secret contain hybrid key material.

        Raises:
            UnsupportedAlgorithm: If the configured classical algorithm is not
                implemented.
        """
        # Generate classical component
        classical_priv, classical_pub = self._gen_classical_keypair()

        # Generate PQC component via backend
        pqc_pub_bytes, pqc_sec_bytes = self._backend.keygen(self._pqc)

        # Pack: length-prefix classical, append PQC
        combined_pub = _pack_components(classical_pub, pqc_pub_bytes)
        combined_sec = _pack_components(classical_priv, pqc_sec_bytes)

        pub = PublicKey(
            raw=combined_pub,
            algorithm=self._algorithm,
            migration_state=MigrationState.HYBRID_TRANSITION,
            backend_tag=self._backend.name,
        )
        sec = SecretKey(
            raw=combined_sec,
            algorithm=self._algorithm,
            migration_state=MigrationState.HYBRID_TRANSITION,
            backend_tag=self._backend.name,
        )
        return KeyPair(public=pub, secret=sec)

    # ------------------------------------------------------------------
    # Encapsulate
    # ------------------------------------------------------------------

    def encapsulate(self, public_key: PublicKey) -> tuple[HybridCipherText, SharedSecret]:
        """Encapsulate a shared secret under the recipient's hybrid public key.

        Args:
            public_key: The recipient's HybridKEM public key.

        Returns:
            (ct, ss): HybridCipherText to send to the recipient,
            SharedSecret for the sender's use.

        Raises:
            UnsupportedAlgorithm: If ``public_key.algorithm`` differs from this
                instance's hybrid algorithm.
            DecapsulationError: If the packed public key is too short to
                split, or (P-256) the classical part is not a 65-byte
                uncompressed point.

        Note:
            The HybridCipherText.to_bytes() gives you the wire-format bytes
            to transmit. The SharedSecret is derived by
            ``combine_shared_secrets`` from both the classical and the PQC
            exchange, and is bound to both ciphertext halves.
        """
        if public_key.algorithm != self._algorithm:
            raise UnsupportedAlgorithm(
                public_key.algorithm,
                available=[self._algorithm],
            )

        # Split the recipient's combined public key
        classical_pub_bytes, pqc_pub_bytes = _unpack_components(
            public_key.raw_bytes, context=self._algorithm
        )

        # --- Classical half ---
        classical_ct, classical_ss = self._encapsulate_classical(classical_pub_bytes)

        # --- PQC half ---
        pqc_ct_bytes, pqc_ss_bytes = self._backend.encapsulate(self._pqc, pqc_pub_bytes)

        # --- Combine ---
        hct = HybridCipherText(
            classical_ct=classical_ct,
            pqc_ct=pqc_ct_bytes,
            algorithm=self._algorithm,
        )
        combined_ss = combine_shared_secrets(
            classical_ss=classical_ss,
            pqc_ss=pqc_ss_bytes[:32],  # only the first 32 bytes feed the combiner
            algorithm=self._algorithm,
            classical_ct=classical_ct,
            pqc_ct=pqc_ct_bytes,
        )
        return hct, combined_ss

    # ------------------------------------------------------------------
    # Decapsulate
    # ------------------------------------------------------------------

    def decapsulate(self, secret_key: SecretKey, ciphertext: HybridCipherText) -> SharedSecret:
        """Decapsulate: recover the shared secret from a hybrid ciphertext.

        Args:
            secret_key: The recipient's HybridKEM secret key.
            ciphertext: HybridCipherText from the sender.

        Returns:
            SharedSecret matching the sender's.

        Raises:
            UnsupportedAlgorithm: If ``secret_key.algorithm`` differs from this
                instance's hybrid algorithm.
            DecapsulationError: on structural failures. Note that ML-KEM
                uses implicit rejection, so a bad ML-KEM ciphertext returns
                a pseudorandom value rather than failing — this is by design.
        """
        if secret_key.algorithm != self._algorithm:
            raise UnsupportedAlgorithm(
                secret_key.algorithm,
                available=[self._algorithm],
            )

        # Split secret key into classical + PQC components
        classical_sec_bytes, pqc_sec_bytes = _unpack_components(
            secret_key.raw_bytes, context=self._algorithm
        )

        # --- Classical half ---
        classical_ss = self._decapsulate_classical(classical_sec_bytes, ciphertext.classical_ct)

        # --- PQC half ---
        pqc_ss_bytes = self._backend.decapsulate(self._pqc, pqc_sec_bytes, ciphertext.pqc_ct)

        # --- Combine (same construction as encapsulate) ---
        return combine_shared_secrets(
            classical_ss=classical_ss,
            pqc_ss=pqc_ss_bytes[:32],
            algorithm=self._algorithm,
            classical_ct=ciphertext.classical_ct,
            pqc_ct=ciphertext.pqc_ct,
        )

    # ------------------------------------------------------------------
    # Classical KEM internals
    # ------------------------------------------------------------------

    @staticmethod
    def _p256_point_coordinates(point: bytes, algo: str) -> tuple[int, int]:
        """Check the framing of an uncompressed P-256 point and return (x, y).

        Only the encoding is checked here (65 bytes, leading 0x04). Whether
        the coordinates describe a valid curve point is left to the
        ``cryptography`` call that consumes them.

        Args:
            point: Encoded point: 0x04 || 32-byte X || 32-byte Y.
            algo: Value passed as ``algo`` to DecapsulationError on failure.

        Raises:
            DecapsulationError: If ``point`` is not 65 bytes or does not start
                with 0x04.
        """
        if len(point) != 65 or point[0] != 0x04:
            raise DecapsulationError(algo=algo)
        return int.from_bytes(point[1:33], "big"), int.from_bytes(point[33:65], "big")

    def _gen_classical_keypair(self) -> tuple[bytes, bytes]:
        """Generate the classical half of a hybrid key pair.

        Returns (private_bytes, public_bytes).

        A fresh key is generated on every call. Within the hybrid KEM this key
        is the recipient's *static* classical key: senders run DH against it
        in ``_encapsulate_classical``. How long it lives is up to the caller.

        Encodings:
            X25519: raw private key and raw public key.
            P-256: unencrypted PKCS#8 PEM private key and X9.62 uncompressed
                point public key.

        Raises:
            UnsupportedAlgorithm: If the classical algorithm is neither
                "X25519" nor "P-256".
        """
        if self._classical == "X25519":
            priv = X25519PrivateKey.generate()
            priv_bytes = priv.private_bytes(Encoding.Raw, PrivateFormat.Raw, NoEncryption())
            pub_bytes = priv.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw)
            return priv_bytes, pub_bytes

        if self._classical == "P-256":
            # Imported lazily so the EC machinery is only loaded when P-256
            # is actually selected.
            from cryptography.hazmat.backends import default_backend
            from cryptography.hazmat.primitives.asymmetric.ec import (
                SECP256R1,
                generate_private_key,
            )

            ec_priv = generate_private_key(SECP256R1(), default_backend())
            priv_bytes = ec_priv.private_bytes(Encoding.PEM, PrivateFormat.PKCS8, NoEncryption())
            pub_bytes = ec_priv.public_key().public_bytes(
                Encoding.X962, PublicFormat.UncompressedPoint
            )
            return priv_bytes, pub_bytes

        raise UnsupportedAlgorithm(
            self._classical,
            available=list(self._CLASSICAL_ALGORITHMS),
        )

    def _encapsulate_classical(self, recipient_pub_bytes: bytes) -> tuple[bytes, bytes]:
        """Perform classical key encapsulation.

        Encapsulation = generate an ephemeral keypair and compute DH against
        the recipient's public key. The 'ciphertext' is the ephemeral public
        key.

        Returns (classical_ct_bytes, shared_secret_bytes).

        Raises:
            DecapsulationError: (P-256 only) If the recipient key is not a
                65-byte uncompressed point, or the ECDH output is shorter than
                32 bytes. ``algo`` is the classical algorithm name here.
            UnsupportedAlgorithm: If the classical algorithm is neither
                "X25519" nor "P-256".
        """
        if self._classical == "X25519":
            # NOTE(review): unlike _decapsulate_classical, nothing here wraps
            # errors from `cryptography`; whatever it raises for a malformed
            # recipient key reaches the caller unchanged.
            # Generate an ephemeral X25519 keypair
            ephem_priv = X25519PrivateKey.generate()
            ephem_pub = ephem_priv.public_key()

            # Load recipient's X25519 public key
            recipient_pub = X25519PublicKey.from_public_bytes(recipient_pub_bytes)

            # DH exchange
            shared = ephem_priv.exchange(recipient_pub)

            # The "ciphertext" is the ephemeral public key
            ct = ephem_pub.public_bytes(Encoding.Raw, PublicFormat.Raw)
            return ct, shared

        if self._classical == "P-256":
            from cryptography.hazmat.backends import default_backend
            from cryptography.hazmat.primitives.asymmetric.ec import (
                ECDH,
                SECP256R1,
                EllipticCurvePublicNumbers,
                generate_private_key,
            )

            # Reject a badly framed recipient key before spending a keygen.
            # cryptography doesn't have a direct from_encoded_point in all
            # versions, so the point is rebuilt from its coordinates.
            x, y = self._p256_point_coordinates(recipient_pub_bytes, self._classical)

            ec_ephem_priv = generate_private_key(SECP256R1(), default_backend())
            ec_ephem_pub = ec_ephem_priv.public_key()

            ec_recipient_pub = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key(
                default_backend()
            )

            shared_key = ec_ephem_priv.exchange(ECDH(), ec_recipient_pub)
            if len(shared_key) < 32:
                raise DecapsulationError(algo=self._classical)

            ct = ec_ephem_pub.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint)
            return ct, shared_key[:32]

        raise UnsupportedAlgorithm(self._classical, available=list(self._CLASSICAL_ALGORITHMS))

    def _decapsulate_classical(self, secret_key_bytes: bytes, ciphertext_bytes: bytes) -> bytes:
        """Recover the classical shared secret.

        The ciphertext is the sender's ephemeral public key. We compute
        DH(our_secret_key, sender_ephemeral_pub).

        Returns shared_secret_bytes.

        Raises:
            DecapsulationError: On any failure while loading the keys or
                running the exchange; the underlying exception, if any, is
                chained as the cause. ``algo`` is the hybrid algorithm name.
            UnsupportedAlgorithm: If the classical algorithm is neither
                "X25519" nor "P-256".
        """
        if self._classical == "X25519":
            try:
                # Load our static secret key
                our_priv = X25519PrivateKey.from_private_bytes(secret_key_bytes)
                # Load sender's ephemeral public key (the "ciphertext")
                sender_ephem_pub = X25519PublicKey.from_public_bytes(ciphertext_bytes)
                return our_priv.exchange(sender_ephem_pub)
            except Exception as exc:
                raise DecapsulationError(algo=self._algorithm) from exc

        if self._classical == "P-256":
            try:
                from cryptography.hazmat.backends import default_backend
                from cryptography.hazmat.primitives.asymmetric.ec import (
                    ECDH,
                    SECP256R1,
                    EllipticCurvePrivateKey,
                    EllipticCurvePublicNumbers,
                )
                from cryptography.hazmat.primitives.serialization import load_pem_private_key

                # The secret key was stored as PKCS#8 PEM by _gen_classical_keypair.
                ec_our_priv = cast(
                    EllipticCurvePrivateKey,
                    load_pem_private_key(
                        secret_key_bytes, password=None, backend=default_backend()
                    ),
                )

                x, y = self._p256_point_coordinates(ciphertext_bytes, self._algorithm)
                sender_pub = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key(
                    default_backend()
                )

                shared = ec_our_priv.exchange(ECDH(), sender_pub)
                if len(shared) < 32:
                    raise DecapsulationError(algo=self._algorithm)
                return shared[:32]
            except DecapsulationError:
                # Already the right type; don't re-wrap it.
                raise
            except Exception as exc:
                raise DecapsulationError(algo=self._algorithm) from exc

        raise UnsupportedAlgorithm(self._classical, available=list(self._CLASSICAL_ALGORITHMS))

    def __repr__(self) -> str:
        return (
            f"HybridKEM("
            f"classical={self._classical!r}, "
            f"pqc={self._pqc!r}, "
            f"backend={self._backend.name!r})"
        )