# Source code for quantum_safe.migrate.shims

"""
quantum_safe.migrate.shims
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Drop-in replacement shims for common classical cryptography APIs.

These shims let you migrate a codebase incrementally: swap the import and
keep the same encrypt/decrypt and encode/decode call shapes, now with PQC
protection. Key creation and key types do differ from the classical APIs;
see the notes below.

Every shim call is logged so you can see exactly which code paths are
still using the shim vs. direct quantum-safe calls. Once a path is fully
migrated to native quantum-safe calls, the shim is no longer needed.

Available shims
---------------
FernetShim     — Replaces cryptography.fernet.Fernet
                 (symmetric encryption with HMAC)
                 → Envelope.seal() / Envelope.open() with HybridKEM

JWTShim        — Replaces PyJWT's jwt.encode() / jwt.decode()
                 → JWTSigner / JWTVerifier with ML-DSA

Usage::

    # Before migration:
    from cryptography.fernet import Fernet
    key = Fernet.generate_key()
    f = Fernet(key)
    token = f.encrypt(data)

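    # During migration — same encrypt()/decrypt() calls, logs every call:
    from quantum_safe.migrate.shims import FernetShim
    f = FernetShim()         # no key argument: a keypair is generated internally
    token = f.encrypt(data)  # encrypt()/decrypt() call sites stay unchanged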

    # After full migration:
    from quantum_safe.protocols import Envelope
    sealed = Envelope.seal(data, recipient_public_key)
    # ... use native API ...

Note on FernetShim semantics
----------------------------
Fernet uses a symmetric key, which means anyone with the key can both
encrypt and decrypt. Envelope.seal() uses asymmetric KEM — you encrypt
to a public key, and only the holder of the secret key can decrypt.

The FernetShim cannot be a true drop-in for Fernet because the security
model is fundamentally different. Instead, it:
  1. Auto-generates a keypair on construction (stateful)
  2. Encrypts to that keypair's public key
  3. Decrypts with that keypair's secret key

This maintains the same interface (encrypt/decrypt with the same object)
while upgrading the underlying construction to PQC.

For code that shares a Fernet key between processes, use the
quantum_safe.protocols.Envelope API directly with explicit key management.
"""

from __future__ import annotations

import logging
import warnings
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from quantum_safe.types import PublicKey, SecretKey

_logger = logging.getLogger(__name__)


class _ShimBase:
    """Shared plumbing for the migration shims: call counting and debug logging.

    Subclasses override ``_shim_name``. The counter lives on the class, not on
    instances: the first increment made through a subclass creates that
    subclass's own ``_call_count`` attribute, so each shim class keeps a
    separate tally from then on.
    """

    _shim_name: str = "ShimBase"
    _call_count: int = 0

    @classmethod
    def _log_shim_call(cls, method: str, note: str = "") -> None:
        """Count one shim call and emit a debug log record for it.

        Args:
            method: Name of the shim method that was invoked.
            note: Optional free-text detail appended to the log line.
        """
        # Only add the " - " separator when there is something to append.
        suffix = f" - {note}" if note else ""
        cls._call_count = cls._call_count + 1
        _logger.debug(
            "[quantum-safe shim] %s.%s() called (total calls: %d)%s",
            cls._shim_name,
            method,
            cls._call_count,
            suffix,
        )

    @classmethod
    def shim_stats(cls) -> dict[str, Any]:
        """Return the shim's name and how many calls it has logged so far."""
        stats: dict[str, Any] = {}
        stats["shim"] = cls._shim_name
        stats["call_count"] = cls._call_count
        return stats


class FernetShim(_ShimBase):
    """Migration shim offering a Fernet-like ``encrypt()`` / ``decrypt()`` pair.

    Instead of Fernet's shared symmetric key, each instance generates its own
    HybridKEM keypair and seals data to it through ``Envelope`` (described by
    this module as hybrid KEM + AES-256-GCM). See the module docstring for the
    difference in security model.

    Differences from ``cryptography.fernet.Fernet``:

    - There is no ``generate_key()`` class method (keys are asymmetric now).
    - The constructor takes a ``backend`` name, not a key.
    - ``encrypt(data)`` returns a serialized ``SealedMessage`` as bytes.
    - ``decrypt(token)`` only opens tokens sealed to this instance's keypair.

    Call ``shim_stats()`` to see how often this shim is being used.
    """

    _shim_name = "FernetShim"

    def __init__(self, backend: str = "auto") -> None:
        """Warn about shim usage, then create the KEM and a fresh keypair.

        Args:
            backend: Backend selector forwarded to ``HybridKEM``.
        """
        message = (
            "FernetShim: you are using a migration shim. "
            "This replaces Fernet with quantum-safe encryption (HybridKEM + AES-256-GCM). "
            "The security model is different — see quantum_safe.migrate.shims docstring. "
            "Migrate to quantum_safe.protocols.Envelope when ready."
        )
        # stacklevel=2 attributes the warning to the code constructing the shim.
        warnings.warn(message, DeprecationWarning, stacklevel=2)

        # Imported lazily, as in the rest of this module.
        from quantum_safe.kem.hybrid import HybridKEM

        kem = HybridKEM(backend=backend)
        self._kem = kem
        self._keypair = kem.generate_keypair()
        self._log_shim_call("__init__")

    def encrypt(self, data: bytes) -> bytes:
        """Seal ``data`` to this instance's public key and return the bytes form."""
        from quantum_safe.protocols.envelope import Envelope

        self._log_shim_call("encrypt")
        return Envelope.seal(data, self._keypair.public, kem=self._kem).to_bytes()

    def decrypt(self, token: bytes) -> bytes:
        """Open a token produced by ``encrypt()`` using this instance's secret key."""
        from quantum_safe.protocols.envelope import Envelope, SealedMessage

        self._log_shim_call("decrypt")
        return Envelope.open(
            SealedMessage.from_bytes(token),
            self._keypair.secret,
            kem=self._kem,
        )

    @property
    def public_key(self) -> PublicKey:
        """Public half of the instance keypair; share this with senders."""
        return self._keypair.public

    @property
    def secret_key(self) -> SecretKey:
        """Secret half of the instance keypair; keep this private."""
        return self._keypair.secret
class JWTShim(_ShimBase):
    """Migration shim exposing PyJWT-style ``encode()`` / ``decode()`` functions.

    Replaces classical JWT signing (RS256, ES256, HS256) with hybrid PQC
    signing (Ed25519+ML-DSA-65) by delegating to ``JWTSigner`` / ``JWTVerifier``.

    Usage::

        # Before:
        import jwt
        token = jwt.encode({"sub": "user"}, private_key, algorithm="RS256")
        claims = jwt.decode(token, public_key, algorithms=["RS256"])

        # After (drop-in):
        from quantum_safe.migrate.shims import JWTShim as jwt
        token = jwt.encode({"sub": "user"}, keypair, algorithm="Ed25519+ML-DSA-65")
        claims = jwt.decode(token, public_key, algorithms=["Ed25519+ML-DSA-65"])

    The ``key`` parameter accepts a quantum_safe.types.KeyPair for encoding and
    a quantum_safe.types.PublicKey for decoding.
    """

    _shim_name = "JWTShim"

    # Time claims that PyJWT callers commonly supply as ``datetime`` objects.
    _TIME_CLAIMS = ("exp", "iat", "nbf")

    @staticmethod
    def _as_timestamp(value: Any) -> Any:  # noqa: ANN401
        """Return ``value`` as an integer UNIX timestamp if it is a ``datetime``.

        Naive datetimes are interpreted as UTC, matching PyJWT's conversion.
        Any other value is returned unchanged.
        """
        from calendar import timegm
        from datetime import datetime

        if isinstance(value, datetime):
            return timegm(value.utctimetuple())
        return value

    @staticmethod
    def encode(
        payload: dict[str, Any],
        key: Any,  # noqa: ANN401
        algorithm: str = "Ed25519+ML-DSA-65",
        **kwargs: Any,  # noqa: ANN401
    ) -> str:
        """Sign a JWT payload.

        Args:
            payload: Claims dict. ``exp``, ``iat`` and ``nbf`` may be numeric
                timestamps or ``datetime`` objects (as PyJWT allows).
            key: quantum_safe.types.KeyPair for signing.
            algorithm: Accepted for PyJWT signature compatibility; this shim
                does not forward it to ``JWTSigner``.
            **kwargs: Accepted for PyJWT signature compatibility and ignored;
                any names supplied are recorded in the shim's debug log.

        Returns:
            JWT token string.

        Raises:
            TypeError: If ``key`` is not a quantum_safe ``KeyPair``.
        """
        # Surface silently-dropped PyJWT options in the shim log.
        ignored = f"ignoring unsupported kwargs: {sorted(kwargs)}" if kwargs else ""
        JWTShim._log_shim_call("encode", ignored)
        warnings.warn(
            "JWTShim.encode() is a migration shim. "
            "Migrate to quantum_safe.protocols.jwt.JWTSigner when ready.",
            DeprecationWarning,
            stacklevel=2,
        )

        from quantum_safe.protocols.jwt import JWTSigner
        from quantum_safe.types import KeyPair

        if not isinstance(key, KeyPair):
            raise TypeError(
                f"JWTShim.encode() requires a quantum_safe KeyPair, got {type(key).__name__}. "
                f"Generate one with HybridSign().generate_keypair()."
            )

        signer = JWTSigner(key, issuer=payload.get("iss"))

        # "iss" is handed to JWTSigner via its issuer argument, so it is left
        # out of the claims. datetime-valued time claims are converted to
        # integer timestamps so the int() arithmetic below cannot fail on them.
        claims = {
            name: JWTShim._as_timestamp(value) if name in JWTShim._TIME_CLAIMS else value
            for name, value in payload.items()
            if name != "iss"
        }

        # NOTE(review): 0 is passed when no "exp" is given — presumably
        # "no expiry"; confirm against JWTSigner.sign().
        expires_in = 0
        if "exp" in claims:
            import time

            remaining = int(claims["exp"]) - int(time.time())
            # An "exp" that is already in the past is clamped to one second.
            expires_in = max(remaining, 1)

        return signer.sign(claims, expires_in=expires_in)

    @staticmethod
    def decode(
        token: str,
        key: Any,  # noqa: ANN401
        algorithms: list[str] | None = None,
        **kwargs: Any,  # noqa: ANN401
    ) -> dict[str, Any]:
        """Verify and decode a JWT.

        Args:
            token: JWT string from encode().
            key: quantum_safe.types.PublicKey for verification.
            algorithms: Ignored by this shim. NOTE(review): the original note
                says the algorithm is inferred from the token header — confirm
                in JWTVerifier.
            **kwargs: Accepted for PyJWT signature compatibility and ignored
                (e.g. ``audience`` or ``options`` are NOT applied); any names
                supplied are recorded in the shim's debug log.

        Returns:
            Verified claims dict.

        Raises:
            TypeError: If ``key`` is not a quantum_safe ``PublicKey``.
        """
        # Surface silently-dropped PyJWT options in the shim log.
        ignored = f"ignoring unsupported kwargs: {sorted(kwargs)}" if kwargs else ""
        JWTShim._log_shim_call("decode", ignored)
        warnings.warn(
            "JWTShim.decode() is a migration shim. "
            "Migrate to quantum_safe.protocols.jwt.JWTVerifier when ready.",
            DeprecationWarning,
            stacklevel=2,
        )

        from quantum_safe.protocols.jwt import JWTVerifier
        from quantum_safe.types import PublicKey

        if not isinstance(key, PublicKey):
            raise TypeError(
                f"JWTShim.decode() requires a quantum_safe PublicKey, got {type(key).__name__}."
            )

        verifier = JWTVerifier(key)
        return verifier.verify(token)