# Source code for quantum_safe.migrate.state

"""
quantum_safe.migrate.state
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Migration state machine for tracking PQC migration progress across a key store.

The real-world migration problem is not "upgrade all keys at once" — it's
"upgrade keys gradually, without breaking existing users, and track where
everything is." This module provides the state machine and audit log for that.

Migration states (from types/keys.py MigrationState):
    CLASSICAL_ONLY     → Key uses only classical crypto. Not yet migrated.
    HYBRID_TRANSITION  → Key has both classical and PQC components. This is
                         the recommended transition state — both components
                         work, so you maintain backward compatibility while
                         gaining PQC protection.
    PQC_PREFERRED      → Hybrid key, but the system now treats the PQC
                         component as authoritative. The classical component
                         is still present for legacy verifiers.
    PQC_ONLY           → Classical component has been removed. Fully migrated.
                         Not backward compatible with classical-only clients.

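Valid transitions:
    CLASSICAL_ONLY → HYBRID_TRANSITION          (first upgrade)
    HYBRID_TRANSITION → PQC_PREFERRED           (gain confidence in PQC)
    PQC_PREFERRED → PQC_ONLY                    (remove classical component)
    HYBRID_TRANSITION → CLASSICAL_ONLY          (rollback — requires a reason)
    PQC_PREFERRED → HYBRID_TRANSITION           (rollback — requires a reason)

Downgrade transitions (anything that goes toward less PQC) are accepted only
when the caller passes ``allow_backward=True`` together with a non-empty
reason string; the reason is stored on the resulting record. PQC_ONLY is
terminal: no transition out of it is accepted.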

Each state change creates a MigrationRecord that can be stored in a database,
an audit log, or a file. The records are immutable once created.
"""

from __future__ import annotations

import threading
import time
import uuid
from dataclasses import dataclass, field
from typing import Any

from quantum_safe._internal import serialization as _ser
from quantum_safe.types.keys import MigrationState

# Valid forward and backward transitions, keyed by the state being left.
# Forward = toward more PQC. Backward = toward less PQC (requires reason).
# Every lookup in this module goes through ``.get(state, set())``, so a state
# with no entry is treated the same as a state mapped to an empty set.
_FORWARD_TRANSITIONS: dict[MigrationState, set[MigrationState]] = {
    MigrationState.CLASSICAL_ONLY: {MigrationState.HYBRID_TRANSITION},
    MigrationState.HYBRID_TRANSITION: {MigrationState.PQC_PREFERRED},
    MigrationState.PQC_PREFERRED: {MigrationState.PQC_ONLY},
    MigrationState.PQC_ONLY: set(),  # terminal state
}

# Each backward entry undoes exactly one forward step. CLASSICAL_ONLY has no
# entry because the forward table never lists it as a target, so there is
# nothing below it to roll back to.
_BACKWARD_TRANSITIONS: dict[MigrationState, set[MigrationState]] = {
    MigrationState.HYBRID_TRANSITION: {MigrationState.CLASSICAL_ONLY},
    MigrationState.PQC_PREFERRED: {MigrationState.HYBRID_TRANSITION},
    # PQC_ONLY → anything is intentionally not allowed without a manual override
}


@dataclass(frozen=True)
class MigrationRecord:
    """An immutable record of a single migration state transition.

    Instances are hashable: the hash is derived from every field except
    ``metadata`` (a dict, which cannot be hashed). Equality still compares
    all fields, including ``metadata``.

    Attributes:
        record_id:      Unique identifier for this record (UUID string).
        key_id:         Application-level identifier for the key being migrated.
                        This is whatever your app uses to identify keys (user ID,
                        key fingerprint, database row ID, etc.).
        from_state:     The state before the transition.
        to_state:       The state after the transition.
        algorithm:      The key algorithm after transition.
        timestamp:      Unix timestamp of the transition.
        actor:          Who initiated the migration (service name, user ID, etc.).
        reason:         Required for backward transitions. Optional for forward.
        metadata:       Arbitrary additional data for audit purposes.
    """

    record_id: str
    key_id: str
    from_state: MigrationState
    to_state: MigrationState
    algorithm: str
    timestamp: float
    actor: str = "system"
    reason: str = ""
    # hash=False keeps the dict out of the generated __hash__; with it included,
    # hash(record) raises TypeError and records cannot be used in sets or as
    # dict keys. The field still participates in ==.
    metadata: dict[str, Any] = field(default_factory=dict, hash=False)

    @property
    def is_forward(self) -> bool:
        """True if this transition moves toward more PQC."""
        return self.to_state in _FORWARD_TRANSITIONS.get(self.from_state, set())

    @property
    def is_backward(self) -> bool:
        """True if this transition is not a recognised forward step.

        This is the strict complement of ``is_forward``: it is also True for
        any from/to pair that does not appear in the forward table at all,
        not only for the listed rollback transitions.
        """
        return not self.is_forward

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict representation of this record.

        States are emitted as their enum ``.value``. The derived ``is_forward``
        flag is included for readers of the audit log; ``from_dict`` ignores it.
        ``metadata`` is returned as a shallow copy so that mutating the result
        cannot modify this (frozen) record.
        """
        return {
            "record_id": self.record_id,
            "key_id": self.key_id,
            "from_state": self.from_state.value,
            "to_state": self.to_state.value,
            "algorithm": self.algorithm,
            "timestamp": self.timestamp,
            "actor": self.actor,
            "reason": self.reason,
            "metadata": dict(self.metadata or {}),
            "is_forward": self.is_forward,
        }

    def to_bytes(self) -> bytes:
        """Serialize ``to_dict()`` with the package's internal serializer."""
        return _ser.dumps(self.to_dict())

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> MigrationRecord:
        """Rebuild a record from the mapping produced by ``to_dict``.

        ``actor``, ``reason`` and ``metadata`` are optional and fall back to
        their field defaults. A missing or null ``metadata`` becomes an empty
        dict, and a supplied one is shallow-copied so the record does not
        alias the caller's mapping.

        Raises:
            KeyError: if a required key is missing.
            ValueError: if a state value is not a valid MigrationState or the
                timestamp cannot be converted to float.
        """
        return cls(
            record_id=d["record_id"],
            key_id=d["key_id"],
            from_state=MigrationState(d["from_state"]),
            to_state=MigrationState(d["to_state"]),
            algorithm=d["algorithm"],
            timestamp=float(d["timestamp"]),
            actor=d.get("actor", "system"),
            reason=d.get("reason", ""),
            metadata=dict(d.get("metadata") or {}),
        )

    @classmethod
    def from_bytes(cls, data: bytes) -> MigrationRecord:
        """Inverse of ``to_bytes``: deserialize, then delegate to ``from_dict``."""
        return cls.from_dict(_ser.loads(data))


class MigrationStateManager:
    """Manages migration state for a collection of keys.

    This class is storage-agnostic — it takes a dict-like store and wraps it
    with validation, history, and audit logging. You provide the storage; we
    provide the business logic.

    Args:
        store: A dict-like object for persistent state storage. For every
               key_id the manager writes two entries: ``"<key_id>_current"``
               (the latest MigrationRecord as bytes) and
               ``"<key_id>_history"`` (the serialized list of all records).
               In production, back this with Redis, DynamoDB, Postgres, etc.
               In tests, a plain dict works fine.

    Example::

        store = {}  # replace with your database abstraction
        mgr = MigrationStateManager(store)

        # First time we see this key
        rec = mgr.transition(
            key_id="user-123",
            from_state=MigrationState.CLASSICAL_ONLY,
            to_state=MigrationState.HYBRID_TRANSITION,
            algorithm="X25519+ML-KEM-768",
            actor="key-rotation-job-v1",
        )
        # rec is stored in `store["user-123_current"]`
        # Full history in `store["user-123_history"]`
    """

    def __init__(self, store: dict[str, bytes]) -> None:
        self._store = store
        # Per-key locks prevent concurrent transitions from racing past the
        # read-then-write check. For multi-process deployments, callers must
        # additionally hold an external distributed lock (e.g. Redis SETNX,
        # database row-level lock) on the key_id before calling transition().
        self._key_locks: dict[str, threading.Lock] = {}
        self._meta_lock = threading.Lock()  # guards _key_locks dict itself

    def _lock_for(self, key_id: str) -> threading.Lock:
        """Return (creating if needed) the per-key lock for key_id."""
        with self._meta_lock:
            if key_id not in self._key_locks:
                self._key_locks[key_id] = threading.Lock()
            return self._key_locks[key_id]

    def _current_key_ids(self) -> list[str]:
        """Return the key_id of every ``*_current`` entry in the store.

        The store's keys are copied into a list before filtering so that the
        scan methods can call back into the store (or another thread can add
        entries) without invalidating the iteration.
        """
        suffix = "_current"
        return [
            store_key[: -len(suffix)]
            for store_key in list(self._store)
            if store_key.endswith(suffix)
        ]

    def transition(
        self,
        key_id: str,
        from_state: MigrationState,
        to_state: MigrationState,
        algorithm: str,
        actor: str = "system",
        reason: str = "",
        metadata: dict[str, Any] | None = None,
        allow_backward: bool = False,
    ) -> MigrationRecord:
        """Record a state transition for a key.

        Args:
            key_id:         Application key identifier.
            from_state:     Expected current state (for optimistic concurrency
                            check). The check is skipped when the key has no
                            readable current record yet.
            to_state:       New target state.
            algorithm:      Key algorithm after this transition.
            actor:          Who is performing the migration.
            reason:         Why (required for backward transitions).
            metadata:       Arbitrary key-value pairs for audit. Copied, so
                            later changes to the caller's dict do not affect
                            the record.
            allow_backward: Set True to explicitly permit backward transitions.
                            Still requires a non-empty reason string.

        Returns:
            The created MigrationRecord.

        Raises:
            ValueError: if the transition is not valid or current state
                        doesn't match from_state.
        """
        # Validate the transition against the static tables before taking
        # the lock; this part does not touch the store.
        forward_targets = _FORWARD_TRANSITIONS.get(from_state, set())
        backward_targets = _BACKWARD_TRANSITIONS.get(from_state, set())
        all_targets = forward_targets | backward_targets

        if to_state not in all_targets:
            raise ValueError(
                f"Invalid transition from {from_state.value!r} to {to_state.value!r}. "
                f"Valid targets: {[s.value for s in all_targets]}"
            )

        is_backward = to_state in backward_targets
        if is_backward:
            if not allow_backward:
                raise ValueError(
                    f"Backward transition from {from_state.value!r} to "
                    f"{to_state.value!r} requires allow_backward=True"
                )
            if not reason:
                raise ValueError("Backward transition requires a non-empty reason string")

        with self._lock_for(key_id):
            # Check current state matches expected from_state
            current = self.get_current_state(key_id)
            if current is not None and current != from_state:
                raise ValueError(
                    f"Key '{key_id}' is in state {current.value!r} but "
                    f"transition expected {from_state.value!r}. "
                    f"Concurrent modification or stale state?"
                )

            record = MigrationRecord(
                record_id=str(uuid.uuid4()),
                key_id=key_id,
                from_state=from_state,
                to_state=to_state,
                algorithm=algorithm,
                timestamp=time.time(),
                actor=actor,
                reason=reason,
                metadata=dict(metadata) if metadata else {},
            )

            # Store current state and append to history
            self._store[f"{key_id}_current"] = record.to_bytes()
            history_key = f"{key_id}_history"
            history = self._load_history(key_id)
            history.append(record.to_dict())
            self._store[history_key] = _ser.dumps(history)

        return record

    def get_current_state(self, key_id: str) -> MigrationState | None:
        """Return the current migration state for a key, or None if unknown.

        "Unknown" covers both a key with no current record and a current
        record that cannot be decoded.
        """
        record = self.get_current_record(key_id)
        return record.to_state if record is not None else None

    def get_current_record(self, key_id: str) -> MigrationRecord | None:
        """Return the full current record for a key.

        Returns None when the key has no current record or the stored bytes
        cannot be decoded into a MigrationRecord.
        """
        current_key = f"{key_id}_current"
        if current_key not in self._store:
            return None
        try:
            return MigrationRecord.from_bytes(self._store[current_key])
        except Exception:  # noqa: BLE001
            # Deliberate best-effort: an unreadable record is reported as absent.
            return None

    def get_history(self, key_id: str) -> list[MigrationRecord]:
        """Return full migration history for a key, oldest first.

        Entries that cannot be turned back into a MigrationRecord are skipped.
        """
        records = []
        for entry in self._load_history(key_id):
            try:
                records.append(MigrationRecord.from_dict(entry))
            except Exception:  # noqa: BLE001, S110
                pass  # skip malformed records
        return records

    def keys_by_state(self, state: MigrationState) -> list[str]:
        """Return all key IDs currently in the given migration state.

        This is a full scan — in production, maintain a secondary index.
        The result is sorted.
        """
        return sorted(
            key_id
            for key_id in self._current_key_ids()
            if self.get_current_state(key_id) == state
        )

    def needs_migration(self) -> list[str]:
        """Return key IDs that are not yet in HYBRID_TRANSITION or better.

        Only keys that already have a ``*_current`` entry are scanned: a key
        is reported when that entry says CLASSICAL_ONLY or cannot be decoded.
        Keys this manager has never recorded a transition for are not in the
        store and therefore cannot appear here. The result is sorted.
        """
        pending = (MigrationState.CLASSICAL_ONLY, None)
        return sorted(
            key_id
            for key_id in self._current_key_ids()
            if self.get_current_state(key_id) in pending
        )

    def migration_progress(self) -> dict[str, int]:
        """Return a count of keys in each migration state.

        The result has one entry per MigrationState value plus ``"unknown"``
        for keys whose current record cannot be decoded.
        """
        counts: dict[str, int] = {s.value: 0 for s in MigrationState}
        counts["unknown"] = 0
        for key_id in self._current_key_ids():
            state = self.get_current_state(key_id)
            # Compare against None explicitly: an enum member whose value is
            # falsy must still be counted under its own state.
            if state is not None:
                counts[state.value] += 1
            else:
                counts["unknown"] += 1
        return counts

    def _load_history(self, key_id: str) -> list[dict[str, Any]]:
        """Return the stored history list for key_id as raw dicts.

        Returns an empty list when there is no history entry, when it cannot
        be deserialized, or when the payload is not a list.
        """
        history_key = f"{key_id}_history"
        if history_key not in self._store:
            return []
        try:
            data = _ser.loads(self._store[history_key])
            return data if isinstance(data, list) else []
        except Exception:  # noqa: BLE001
            return []