# Source code for quantum_safe.types.kem

"""
quantum_safe.types.kem
~~~~~~~~~~~~~~~~~~~~~~

Typed wrappers for KEM operation outputs.

The two core outputs of a KEM are:
  - CipherText: what the encapsulator sends to the decapsulator
  - SharedSecret: the symmetric key material both parties derive

Both are distinct types (not type aliases for bytes) to prevent the
class of bug where you accidentally pass a shared secret as a ciphertext
or vice versa. This has happened in real implementations.

For hybrid KEMs, we have HybridCipherText which carries both the classical
(X25519) and PQC (ML-KEM) ciphertexts, and derives a combined shared secret.

The combination follows the IETF hybrid KEM construction, as implemented by
combine_shared_secrets():
  combined_ss = HKDF-SHA256(
      ikm  = classical_ss || pqc_ss,
      salt = classical_ct || pqc_ct,
      info = "quantum-safe hybrid KEM v1" || 0x00 || algorithm_string
  )

References:
  - FIPS 203 §6.2 — ML-KEM.Encaps / ML-KEM.Decaps
  - draft-ietf-tls-hybrid-design §3 — combiner construction
  - RFC 5869 — HKDF
"""

from __future__ import annotations

import ctypes
import hmac
import struct
from typing import ClassVar

from quantum_safe.exceptions import DecapsulationError

# We use HKDF-SHA256 for the hybrid combiner. The info string is fixed and
# version-pinned so that old clients can't be tricked into using a different
# construction. combine_shared_secrets() appends a 0x00 separator and the
# ASCII algorithm name to this prefix to form the full HKDF `info` value.
_HYBRID_COMBINER_INFO = b"quantum-safe hybrid KEM v1"
# Length enforced by SharedSecret.__init__ and requested from HKDF by
# combine_shared_secrets().
_SHARED_SECRET_LEN = 32  # bytes — 256-bit symmetric key


class SharedSecret:
    """The shared secret output of a KEM operation.

    This is 32 bytes of symmetric key material derived from the KEM.
    Like SecretKey, it zeroizes on deletion.

    You should use this as input to a KDF (e.g. HKDF) to derive actual
    encryption keys — don't use it directly as an AES key without further
    processing. The library does this for you in
    quantum_safe.protocols.envelope.

    Zeroization is best-effort: the internal buffer is wiped in ``__del__``,
    but ``bytes(secret)`` and ``derive_key`` necessarily hand out / create
    immutable ``bytes`` copies that this class cannot wipe.
    """

    __slots__ = ("_data", "_algorithm", "_is_hybrid")

    def __init__(self, data: bytes, algorithm: str, is_hybrid: bool = False) -> None:
        """Wrap ``data`` as a shared secret.

        Args:
            data: The raw secret; must be exactly ``_SHARED_SECRET_LEN`` bytes.
            algorithm: Name of the KEM that produced the secret.
            is_hybrid: True when the secret came out of a hybrid combiner.

        Raises:
            ValueError: If ``data`` has the wrong length.
        """
        if len(data) != _SHARED_SECRET_LEN:
            raise ValueError(
                f"SharedSecret must be exactly {_SHARED_SECRET_LEN} bytes, got {len(data)}"
            )
        # Store in a mutable buffer so we can zero it
        self._data = bytearray(data)
        self._algorithm = algorithm
        self._is_hybrid = is_hybrid

    @property
    def algorithm(self) -> str:
        """Name of the KEM algorithm this secret belongs to."""
        return self._algorithm

    @property
    def is_hybrid(self) -> bool:
        """Whether this secret was derived from a hybrid KEM."""
        return self._is_hybrid

    def __bytes__(self) -> bytes:
        # Returns an immutable copy; the caller owns its lifetime.
        return bytes(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __eq__(self, other: object) -> bool:
        """Constant-time comparison against another secret or raw bytes.

        Only the key material is compared; ``algorithm`` and ``is_hybrid``
        are ignored. Defining ``__eq__`` without ``__hash__`` keeps the class
        unhashable, so secrets cannot end up as dict keys or set members.
        """
        # compare_digest accepts any bytes-like object, so the buffers are
        # passed directly instead of being copied into un-wipeable bytes.
        if isinstance(other, SharedSecret):
            return hmac.compare_digest(self._data, other._data)
        if isinstance(other, (bytes, bytearray)):
            return hmac.compare_digest(self._data, other)
        return NotImplemented

    def __repr__(self) -> str:
        # Never include key material in the repr (it ends up in logs).
        return f"SharedSecret(algo={self._algorithm!r}, <{len(self._data)} bytes REDACTED>)"

    def __del__(self) -> None:
        # Best-effort wipe of the buffer. Errors are deliberately swallowed:
        # __del__ may run on a half-constructed instance (no _data yet) or
        # during interpreter shutdown, and must never raise.
        try:
            n = len(self._data)
            if n:
                ctypes.memset((ctypes.c_char * n).from_buffer(self._data), 0, n)
        except Exception:  # noqa: BLE001, S110
            pass

    def derive_key(
        self,
        length: int = 32,
        salt: bytes | None = None,
        info: bytes = b"",
    ) -> bytes:
        """Derive a key from this shared secret using HKDF-SHA256.

        This is a convenience wrapper. For full control, use
        cryptography.hazmat.primitives.kdf.hkdf directly.

        Args:
            length: Desired output length in bytes (max 255 * 32 = 8160).
            salt: Optional salt. Defaults to a zero-filled string if None.
            info: Application-specific context. Include your app name and
                version to prevent cross-context key reuse.

        Returns:
            Raw key bytes of the requested length.

        Example:
            enc_key = ss.derive_key(32, info=b"myapp-encryption-v1")
            mac_key = ss.derive_key(32, info=b"myapp-mac-v1")
        """
        # Imported lazily so the module can be imported without pulling in
        # the cryptography backend until a key is actually derived.
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.kdf.hkdf import HKDF

        hkdf = HKDF(
            algorithm=hashes.SHA256(),
            length=length,
            salt=salt,
            info=info,
        )
        return hkdf.derive(bytes(self._data))
class CipherText:
    """The ciphertext output of KEM encapsulation.

    This is what the encapsulator transmits to the decapsulator. It's not
    secret. Note that the ciphertext carries no separate MAC: for ML-KEM a
    modified ciphertext does not make decapsulation raise — FIPS 203 uses
    implicit rejection, so the decapsulator silently derives an unrelated
    pseudorandom shared secret and the tampering only surfaces when the
    derived keys fail to agree.

    Size varies by algorithm:
        ML-KEM-512:   768 bytes
        ML-KEM-768:  1088 bytes
        ML-KEM-1024: 1568 bytes

    Instances are immutable value objects: they compare by
    ``(algorithm, data)`` and are hashable on the same pair.
    """

    __slots__ = ("_data", "_algorithm")

    # Expected ciphertext sizes per algorithm (FIPS 203 parameter-set sizes)
    _EXPECTED_SIZES: ClassVar[dict[str, int]] = {
        "ML-KEM-512": 768,
        "ML-KEM-768": 1088,
        "ML-KEM-1024": 1568,
    }

    def __init__(self, data: bytes, algorithm: str) -> None:
        """Wrap raw ciphertext bytes.

        Args:
            data: Non-empty ciphertext (any of bytes / bytearray / memoryview).
            algorithm: Name of the KEM that produced it.

        Raises:
            ValueError: If ``data`` is empty.
            TypeError: If ``data`` is not a bytes-like object.
        """
        if not data:
            raise ValueError("ciphertext cannot be empty")
        if not isinstance(data, (bytes, bytearray, memoryview)):
            raise TypeError(
                f"ciphertext must be bytes-like, got {type(data).__name__}"
            )
        # Snapshot into immutable bytes: a caller-owned bytearray must not be
        # able to change the ciphertext later, and __bytes__ is required to
        # return a real bytes object.
        self._data = bytes(data)
        self._algorithm = algorithm

        # Warn (not error) on size mismatch — the backend may use a slightly
        # different internal format for hybrid ciphertexts
        expected = self._EXPECTED_SIZES.get(algorithm)
        if expected is not None and len(self._data) != expected:
            import warnings

            warnings.warn(
                f"CipherText for {algorithm} has unexpected size "
                f"(expected {expected}, got {len(self._data)}). "
                "This may indicate a backend format difference.",
                stacklevel=2,
            )

    @property
    def algorithm(self) -> str:
        """Name of the KEM algorithm this ciphertext belongs to."""
        return self._algorithm

    @property
    def data(self) -> bytes:
        """The raw ciphertext bytes."""
        return self._data

    def __bytes__(self) -> bytes:
        return self._data

    def __len__(self) -> int:
        return len(self._data)

    def __repr__(self) -> str:
        return f"CipherText(algo={self._algorithm!r}, size={len(self._data)}B)"

    def __eq__(self, other: object) -> bool:
        # Ciphertexts are public, so a plain (non-constant-time) compare is fine.
        if isinstance(other, CipherText):
            return self._algorithm == other._algorithm and self._data == other._data
        return NotImplemented

    def __hash__(self) -> int:
        # Kept in lockstep with __eq__ so equal ciphertexts hash equally.
        return hash((self._algorithm, self._data))
class HybridCipherText:
    """Ciphertext from a hybrid KEM: classical ephemeral + PQC encapsulation.

    The wire format is:
        classical_ct_len (2 bytes, big-endian uint16) || classical_ct || pqc_ct

    This framing allows the receiver to split the two components without
    needing out-of-band length information. Because the prefix is a uint16,
    the classical component is limited to 65535 bytes.
    """

    __slots__ = ("_classical_ct", "_pqc_ct", "_algorithm")

    # The length prefix is a 2-byte big-endian uint16
    _LEN_PREFIX_FORMAT = ">H"
    _LEN_PREFIX_SIZE = 2
    # Largest classical_ct length the uint16 prefix can represent
    _MAX_CLASSICAL_LEN = 0xFFFF

    def __init__(
        self,
        classical_ct: bytes,
        pqc_ct: bytes,
        algorithm: str,
    ) -> None:
        """Bundle the two ciphertext components.

        Args:
            classical_ct: Non-empty classical component (at most 65535 bytes).
            pqc_ct: Non-empty post-quantum component.
            algorithm: Name of the hybrid KEM.

        Raises:
            ValueError: If a component is empty, or ``classical_ct`` is too
                long to be framed by the uint16 length prefix.
            TypeError: If a component is not a bytes-like object.
        """
        for label, value in (("classical_ct", classical_ct), ("pqc_ct", pqc_ct)):
            if not value:
                raise ValueError(f"{label} cannot be empty")
            if not isinstance(value, (bytes, bytearray, memoryview)):
                raise TypeError(
                    f"{label} must be bytes-like, got {type(value).__name__}"
                )
        # Snapshot into immutable bytes so caller-owned buffers can't change
        # the ciphertext after construction.
        classical = bytes(classical_ct)
        # Reject here rather than letting to_bytes() fail later with a raw
        # struct.error when the length doesn't fit the prefix.
        if len(classical) > self._MAX_CLASSICAL_LEN:
            raise ValueError(
                f"classical_ct too long for wire format "
                f"(max {self._MAX_CLASSICAL_LEN} bytes, got {len(classical)})"
            )
        self._classical_ct = classical
        self._pqc_ct = bytes(pqc_ct)
        self._algorithm = algorithm

    @property
    def algorithm(self) -> str:
        """Name of the hybrid KEM algorithm."""
        return self._algorithm

    @property
    def classical_ct(self) -> bytes:
        """The classical ciphertext component."""
        return self._classical_ct

    @property
    def pqc_ct(self) -> bytes:
        """The post-quantum ciphertext component."""
        return self._pqc_ct

    def to_bytes(self) -> bytes:
        """Encode as length-prefixed wire format."""
        prefix = struct.pack(self._LEN_PREFIX_FORMAT, len(self._classical_ct))
        return prefix + self._classical_ct + self._pqc_ct

    def __bytes__(self) -> bytes:
        # Same encoding as to_bytes(), so bytes(x) works as it does for
        # CipherText and len(bytes(x)) == len(x).
        return self.to_bytes()

    @classmethod
    def from_bytes(cls, data: bytes, algorithm: str) -> HybridCipherText:
        """Decode from length-prefixed wire format.

        Args:
            data: The encoded ciphertext as produced by ``to_bytes``.
            algorithm: Name of the hybrid KEM, attached to the result.

        Raises:
            DecapsulationError: If ``data`` is malformed — too short for the
                prefix, truncated, or with an empty classical or PQC part.
        """
        if len(data) < cls._LEN_PREFIX_SIZE:
            raise DecapsulationError(algo=algorithm)

        (classical_len,) = struct.unpack_from(cls._LEN_PREFIX_FORMAT, data, 0)
        start = cls._LEN_PREFIX_SIZE
        end = start + classical_len

        # A zero-length classical part, a truncated classical part, and a
        # missing PQC part are all malformed input. They are reported
        # uniformly instead of letting the constructor's ValueError leak out.
        if classical_len == 0 or len(data) <= end:
            raise DecapsulationError(algo=algorithm)

        return cls(
            classical_ct=data[start:end],
            pqc_ct=data[end:],
            algorithm=algorithm,
        )

    def __len__(self) -> int:
        # Length of the wire encoding, prefix included.
        return self._LEN_PREFIX_SIZE + len(self._classical_ct) + len(self._pqc_ct)

    def __repr__(self) -> str:
        return (
            f"HybridCipherText(algo={self._algorithm!r}, "
            f"classical={len(self._classical_ct)}B, "
            f"pqc={len(self._pqc_ct)}B)"
        )
def combine_shared_secrets(
    classical_ss: bytes,
    pqc_ss: bytes,
    algorithm: str,
    classical_ct: bytes,
    pqc_ct: bytes,
) -> SharedSecret:
    """Derive the hybrid shared secret from its classical and PQC halves.

    Both component secrets are fed through a single HKDF-SHA256 invocation:

        combined = HKDF-SHA256(
            ikm  = classical_ss || pqc_ss,
            salt = classical_ct || pqc_ct,
            info = info_string || 0x00 || algo_name
        )

    Salting with the concatenated ciphertexts ties the output to this
    particular exchange, and the info value gives per-algorithm domain
    separation. The module cites draft-ietf-tls-hybrid-design §3 as the
    basis for the combiner.

    Args:
        classical_ss: Shared secret from X25519 (32 bytes)
        pqc_ss: Shared secret from ML-KEM (32 bytes)
        algorithm: Algorithm string for domain separation; must be ASCII
        classical_ct: Classical ciphertext (X25519 public key of encapsulator)
        pqc_ct: PQC ciphertext

    Returns:
        A 32-byte SharedSecret, flagged as hybrid, derived from both
        components.
    """
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.kdf.hkdf import HKDF

    # Classical secret first, then PQC secret.
    key_material = classical_ss + pqc_ss
    # Both ciphertexts, in the same order, bind the result to this exchange.
    exchange_salt = classical_ct + pqc_ct
    # Versioned prefix, NUL separator, then the algorithm name.
    domain_tag = _HYBRID_COMBINER_INFO + b"\x00" + algorithm.encode("ascii")

    kdf = HKDF(
        algorithm=hashes.SHA256(),
        length=_SHARED_SECRET_LEN,
        salt=exchange_salt,
        info=domain_tag,
    )
    return SharedSecret(
        data=kdf.derive(key_material),
        algorithm=algorithm,
        is_hybrid=True,
    )