Source code for quantum_safe.migrate.scanner

"""
quantum_safe.migrate.scanner
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

AST-based scanner that finds classical cryptography usage in Python codebases.

The scanner walks Python source files, parses them with the stdlib `ast`
module, and matches patterns against a catalogue of known-classical APIs.
The output is a structured ScanReport that can be serialized to SARIF (for
GitHub Code Scanning / GitLab SAST), plain JSON, or a human-readable table.

Why AST and not grep?
    grep finds string literals, not actual usage. A grep for "RSA" would
    miss `from cryptography.hazmat.primitives.asymmetric import rsa; rsa.generate_private_key(...)`.
    The AST walk sees the import, resolves the alias, and matches the call.
    This reduces both false positives (comments mentioning RSA) and false
    negatives (aliased imports).

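Severity model
    CRITICAL:  DSA key generation, MD5
               (not yet detected: RSA < 2048 bits, key material in source)
    HIGH:      RSA (any key size), ECDSA/ECDH with any curve, classical DH,
               pycryptodome RSA/ECC, classical JWT algorithms (RS*/ES*/PS*)
               (not yet detected: Ed25519-only signing)
    MEDIUM:    AES-128, 3DES, SHA-1
               (not yet detected: PBKDF2 with low iterations)
    INFO:      Classical-safe usage that will need migration eventually
               (e.g. AES-256 is fine today but note it for inventory);
               no rule in this version emits INFO yet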

What we scan for
    - Imports of classical crypto libraries (pyca `cryptography`,
      pycryptodome); an import enables the call checks for that module
    - Direct key generation and primitive calls (rsa.generate_private_key,
      padding.OAEP, hashlib.md5, etc.)
    - Algorithm string literals ("RS256", "AES-128-CBC", "MD5")
    - JWT algorithm identifiers in string constants
    - Not yet implemented: hardcoded key sizes below thresholds

Limitations
    - Dynamic imports (importlib.import_module) are not resolved.
    - Obfuscated code or eval() are not analyzed.
    - Third-party library internals are not followed.
    - Only Python files are supported in this version. TypeScript/Rust
      scanning is planned for v0.2.
"""

from __future__ import annotations

import ast
import fnmatch
import json
import os
import pathlib
from collections.abc import Iterator
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any


class Severity(IntEnum):
    """Ranked seriousness of a finding.

    Members are plain integers underneath, so they compare numerically:
    ``Severity.CRITICAL > Severity.HIGH > Severity.MEDIUM > Severity.INFO``.
    That ordering is what lets callers write ``severity >= Severity.HIGH``.
    """

    # Inventory-only usage.
    INFO = 0
    # Weak primitives, e.g. AES-128, 3DES, SHA-1 in the rule catalogue.
    MEDIUM = 1
    # Not quantum-safe, e.g. RSA, ECDSA/ECDH, classical JWT algorithms.
    HIGH = 2
    # Broken or deprecated, e.g. MD5 and DSA in the rule catalogue.
    CRITICAL = 3


@dataclass
class Finding:
    """One occurrence of classical cryptography in a scanned file.

    Attributes:
        file: Path of the source file, exactly as handed to the scanner.
        line: Line number, counting from 1.
        col: Column number, counting from 1.
        severity: How serious the finding is.
        rule_id: Machine-readable rule identifier such as "QS001".
        message: Human-readable explanation.
        snippet: The offending source line with surrounding whitespace removed.
        fix_hint: Optional remediation advice; empty when none is available.
    """

    file: str
    line: int
    col: int
    severity: Severity
    rule_id: str
    message: str
    snippet: str = ""
    fix_hint: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-friendly mapping; the severity is rendered by name."""
        data: dict[str, Any] = dict(
            file=self.file,
            line=self.line,
            col=self.col,
            severity=self.severity.name,
            rule_id=self.rule_id,
            message=self.message,
        )
        data["snippet"] = self.snippet
        data["fix_hint"] = self.fix_hint
        return data

    def __str__(self) -> str:
        # Severity name is left-aligned in an 8-wide column so output lines up.
        location = f"{self.file}:{self.line}:{self.col}"
        return f"[{self.severity.name:<8}] {location} {self.rule_id} {self.message}"
@dataclass
class ScanReport:
    """Aggregated results from scanning one or more files/directories.

    Attributes:
        root: The directory or file that was scanned.
        files_scanned: Number of Python files analyzed.
        findings: All findings, sorted by (file, line).
        errors: Files that could not be parsed (syntax errors, permission issues).
    """

    root: str
    files_scanned: int = 0
    findings: list[Finding] = field(default_factory=list)
    errors: list[dict[str, str]] = field(default_factory=list)

    # ------------------------------------------------------------------
    # Convenience accessors
    # ------------------------------------------------------------------

    @property
    def critical(self) -> list[Finding]:
        """Findings whose severity is CRITICAL."""
        return [f for f in self.findings if f.severity == Severity.CRITICAL]

    @property
    def high(self) -> list[Finding]:
        """Findings whose severity is HIGH."""
        return [f for f in self.findings if f.severity == Severity.HIGH]

    @property
    def medium(self) -> list[Finding]:
        """Findings whose severity is MEDIUM."""
        return [f for f in self.findings if f.severity == Severity.MEDIUM]

    @property
    def info(self) -> list[Finding]:
        """Findings whose severity is INFO."""
        return [f for f in self.findings if f.severity == Severity.INFO]

    @property
    def has_blocking_findings(self) -> bool:
        """True if there are any HIGH or CRITICAL findings.

        This is the condition that should cause a CI gate to fail.
        """
        return any(f.severity >= Severity.HIGH for f in self.findings)

    def summary(self) -> str:
        """One-line summary for logging, listing non-zero counts per severity."""
        parts = []
        if self.critical:
            parts.append(f"{len(self.critical)} CRITICAL")
        if self.high:
            parts.append(f"{len(self.high)} HIGH")
        if self.medium:
            parts.append(f"{len(self.medium)} MEDIUM")
        if self.info:
            parts.append(f"{len(self.info)} INFO")
        if not parts:
            parts.append("no findings")
        return f"Scanned {self.files_scanned} files in '{self.root}': " + ", ".join(parts)

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    def to_json(self, indent: int = 2) -> str:
        """Serialize the report (including the summary line) to JSON."""
        return json.dumps(
            {
                "root": self.root,
                "files_scanned": self.files_scanned,
                "summary": self.summary(),
                "findings": [f.to_dict() for f in self.findings],
                "errors": self.errors,
            },
            indent=indent,
        )

    def to_sarif(self) -> dict[str, Any]:
        """Produce a SARIF 2.1.0 document for GitHub Code Scanning / GitLab SAST.

        One rule (``reportingDescriptor``) is emitted per distinct ``rule_id``,
        described by the first finding seen for that rule; every finding
        becomes a ``result``.

        Fix hints are published as the rule's ``help`` text, not as result
        ``fixes``. A SARIF ``fix`` object must carry ``artifactChanges``
        (concrete edits), which a free-text hint cannot supply. Emitting a fix
        with only a description makes the document schema-invalid.

        The result can be written to a file and uploaded as a SARIF artifact.
        See: https://docs.github.com/en/code-security/code-scanning/sarif-schema
        """
        # Build the rules list from distinct rule_ids (first finding wins).
        first_by_rule: dict[str, Finding] = {}
        hint_by_rule: dict[str, str] = {}
        for f in self.findings:
            first_by_rule.setdefault(f.rule_id, f)
            if f.fix_hint:
                hint_by_rule.setdefault(f.rule_id, f.fix_hint)

        rules: list[dict[str, Any]] = []
        for rule_id, f in first_by_rule.items():
            rule: dict[str, Any] = {
                "id": rule_id,
                "name": rule_id,
                "shortDescription": {"text": f.message},
                "defaultConfiguration": {"level": _sarif_level(f.severity)},
                "helpUri": "https://quantum-safe-py.readthedocs.io/en/latest/guides/audit.html",
            }
            if rule_id in hint_by_rule:
                rule["help"] = {"text": hint_by_rule[rule_id]}
            rules.append(rule)

        results = [
            {
                "ruleId": f.rule_id,
                "message": {"text": f.message},
                "level": _sarif_level(f.severity),
                "locations": [
                    {
                        "physicalLocation": {
                            # SARIF URIs use "/" separators; as_posix() converts
                            # Windows-style paths.
                            "artifactLocation": {"uri": pathlib.PurePath(f.file).as_posix()},
                            "region": {
                                "startLine": f.line,
                                "startColumn": f.col,
                            },
                        }
                    }
                ],
            }
            for f in self.findings
        ]

        return {
            "$schema": "https://raw.githubusercontent.com/oasis-tcs/sarif-spec/master/Schemata/sarif-schema-2.1.0.json",
            "version": "2.1.0",
            "runs": [
                {
                    "tool": {
                        "driver": {
                            "name": "qs-audit",
                            "version": "0.1.0",
                            "rules": rules,
                        }
                    },
                    "results": results,
                }
            ],
        }
def _sarif_level(severity: Severity) -> str:
    """Translate a Severity into a SARIF ``level`` ("error"/"warning"/"note")."""
    return {
        Severity.CRITICAL: "error",
        Severity.HIGH: "error",
        Severity.MEDIUM: "warning",
        Severity.INFO: "note",
    }[severity]


# ---------------------------------------------------------------------------
# Rule catalogue
# ---------------------------------------------------------------------------
# Each rule is a dict with:
#   id:              Short rule ID, e.g. "QS001"
#   severity:        Severity level
#   message:         Finding text; "{detail}" is replaced with the matched
#                    string pattern (only string-pattern matches supply one)
#   fix_hint:        Replacement suggestion
#   imports:         Module paths whose import enables the rule's "calls"
#   calls:           (module, function) pairs; a call matches when its callee,
#                    resolved through the file's imports, equals
#                    "module.function" or ends with ".module.function"
#   string_patterns: Case-sensitive substrings searched for in string literals

# Tracks imports: maps a local name -> the dotted path it is bound to
# e.g. "rsa" -> "cryptography.hazmat.primitives.asymmetric.rsa"
_Aliases = dict[str, str]


def _is_attr_call(node: ast.AST, aliases: _Aliases, *path: str) -> bool:
    """Return True if `node` is a call of the form a.b.c(*args).

    path is the expected attribute chain, e.g. ("rsa", "generate_private_key").
    The first element is matched against the aliases dict.

    Not used by _ClassicalCryptoVisitor, which matches calls through its own
    _matches_call method.
    """
    if not isinstance(node, ast.Call):
        return False
    func = node.func
    if len(path) == 1:
        return isinstance(func, ast.Name) and aliases.get(func.id, func.id) == path[0]
    if len(path) == 2:
        return (
            isinstance(func, ast.Attribute)
            and func.attr == path[1]
            and isinstance(func.value, ast.Name)
            and aliases.get(func.value.id, func.value.id).endswith(path[0])
        )
    return False


def _string_value(node: ast.AST) -> str | None:
    """Extract string value from an ast.Constant node."""
    if isinstance(node, ast.Constant) and isinstance(node.value, str):
        return node.value
    return None


# Classical API patterns that indicate non-quantum-safe usage
_RULES: list[dict[str, Any]] = [
    # ---- RSA -------------------------------------------------------
    {
        "id": "QS001",
        "severity": Severity.HIGH,
        "message": "RSA key generation detected - RSA is not quantum-safe",
        "fix_hint": "Replace with HybridKEM() for key exchange or HybridSign() for signatures",
        "imports": {"cryptography.hazmat.primitives.asymmetric.rsa"},
        "calls": {("rsa", "generate_private_key")},
    },
    {
        "id": "QS002",
        "severity": Severity.HIGH,
        "message": "RSA PKCS1v15 padding detected - not quantum-safe",
        "fix_hint": "Replace RSA encryption with Envelope.seal() / Envelope.open()",
        "imports": {"cryptography.hazmat.primitives.asymmetric.padding"},
        "calls": {("padding", "PKCS1v15")},
    },
    {
        "id": "QS003",
        "severity": Severity.HIGH,
        "message": "RSA-OAEP encryption detected - not quantum-safe",
        "fix_hint": "Replace with Envelope.seal() which uses HybridKEM + AES-256-GCM",
        "imports": {"cryptography.hazmat.primitives.asymmetric.padding"},
        "calls": {("padding", "OAEP")},
    },
    # ---- ECDSA / ECDH ---------------------------------------------
    {
        "id": "QS010",
        "severity": Severity.HIGH,
        "message": "ECDSA key generation detected - not quantum-safe",
        "fix_hint": "Replace with HybridSign() for signatures",
        "imports": {"cryptography.hazmat.primitives.asymmetric.ec"},
        "calls": {("ec", "generate_private_key")},
    },
    {
        "id": "QS011",
        "severity": Severity.HIGH,
        "message": "ECDH key exchange detected - not quantum-safe",
        "fix_hint": "Replace with HybridKEM() for key exchange",
        "imports": {"cryptography.hazmat.primitives.asymmetric.ec"},
        "calls": {("ec", "ECDH")},
    },
    # ---- DSA -------------------------------------------------------
    {
        "id": "QS015",
        "severity": Severity.CRITICAL,
        "message": "DSA key generation detected - not quantum-safe and deprecated",
        "fix_hint": "Replace with HybridSign() (ML-DSA-65 + Ed25519)",
        "imports": {"cryptography.hazmat.primitives.asymmetric.dsa"},
        "calls": {("dsa", "generate_private_key")},
    },
    # ---- DH (classical Diffie-Hellman) ----------------------------
    {
        "id": "QS016",
        "severity": Severity.HIGH,
        "message": "Classical DH key generation detected - not quantum-safe",
        "fix_hint": "Replace with HybridKEM()",
        "imports": {"cryptography.hazmat.primitives.asymmetric.dh"},
        "calls": {("dh", "generate_parameters")},
    },
    # ---- Weak symmetric -------------------------------------------
    {
        "id": "QS020",
        "severity": Severity.MEDIUM,
        "message": "AES-128 detected - consider upgrading to AES-256",
        "fix_hint": "Use AES-256-GCM. Envelope.seal() uses AES-256-GCM by default.",
        "imports": set(),
        "string_patterns": {"AES-128", "AES128"},
    },
    {
        "id": "QS021",
        "severity": Severity.MEDIUM,
        "message": "3DES / TripleDES detected - deprecated and not quantum-safe",
        "fix_hint": "Replace with AES-256-GCM",
        "imports": {"cryptography.hazmat.primitives.ciphers.algorithms"},
        "calls": {("algorithms", "TripleDES")},
        "string_patterns": {"3DES", "TripleDES", "DES3"},
    },
    # ---- Weak hash ------------------------------------------------
    {
        "id": "QS030",
        "severity": Severity.MEDIUM,
        "message": "SHA-1 detected - cryptographically broken",
        "fix_hint": "Replace with SHA-256 or SHA-3",
        "imports": {"cryptography.hazmat.primitives.hashes", "hashlib"},
        # SHA-224 is a SHA-2 variant, not SHA-1, so hashlib.sha224 is not listed.
        "calls": {("hashes", "SHA1"), ("hashlib", "sha1")},
        # Lower-case "sha1" is the spelling hashlib.new()/hmac expect.
        "string_patterns": {"SHA1", "SHA-1", "sha1"},
    },
    {
        "id": "QS031",
        "severity": Severity.CRITICAL,
        "message": "MD5 detected - cryptographically broken",
        "fix_hint": "Replace with SHA-256 or BLAKE2b",
        "imports": {"cryptography.hazmat.primitives.hashes", "hashlib"},
        "calls": {("hashes", "MD5"), ("hashlib", "md5")},
        "string_patterns": {"MD5", "md5"},
    },
    # hashlib.new("sha1") / hashlib.new("md5") are caught by the lower-case
    # string_patterns above, since the algorithm name is a string literal.
    # ---- JWT algorithm identifiers --------------------------------
    {
        "id": "QS040",
        "severity": Severity.HIGH,
        "message": "Classical JWT algorithm '{detail}' detected",
        "fix_hint": "Use JWTSigner from quantum_safe.protocols.jwt with ML-DSA-65",
        "imports": set(),
        "string_patterns": {"RS256", "RS384", "RS512", "ES256", "ES384", "PS256", "PS384", "PS512"},
    },
    # ---- pycryptodome / pycrypto -----------------------------------
    # Without call patterns these rules could never produce a finding: imports
    # only activate rules, findings come from calls and string literals.
    {
        "id": "QS050",
        "severity": Severity.HIGH,
        "message": "pycryptodome RSA usage detected - not quantum-safe",
        "fix_hint": "Replace with quantum_safe.HybridKEM or HybridSign",
        "imports": {"Crypto.PublicKey.RSA", "Cryptodome.PublicKey.RSA"},
        "calls": {
            ("RSA", "generate"),
            ("RSA", "import_key"),
            ("RSA", "importKey"),
            ("RSA", "construct"),
        },
    },
    {
        "id": "QS051",
        "severity": Severity.HIGH,
        "message": "pycryptodome ECC usage detected - not quantum-safe",
        "fix_hint": "Replace with quantum_safe.HybridSign",
        "imports": {"Crypto.PublicKey.ECC", "Cryptodome.PublicKey.ECC"},
        "calls": {("ECC", "generate"), ("ECC", "import_key"), ("ECC", "construct")},
    },
]

# Build fast lookup: module_path -> [rules that trigger on import].
# sorted() keeps the map's insertion order independent of str hash randomization.
_IMPORT_RULE_MAP: dict[str, list[dict[str, Any]]] = {}
for _rule in _RULES:
    for _imp in sorted(_rule.get("imports", set())):
        _IMPORT_RULE_MAP.setdefault(_imp, []).append(_rule)


# ---------------------------------------------------------------------------
# AST visitor
# ---------------------------------------------------------------------------


class _ClassicalCryptoVisitor(ast.NodeVisitor):
    """Walks an AST and collects classical crypto usage findings.

    Imports activate rules and record name bindings. Calls are then matched
    against the active rules' ``calls`` patterns. String literals, except
    docstrings, are matched against every rule's ``string_patterns``.
    """

    def __init__(self, filename: str, source_lines: list[str]) -> None:
        self._filename = filename
        self._lines = source_lines
        self.findings: list[Finding] = []
        # local name -> dotted path it is bound to (built from imports)
        self._module_aliases: _Aliases = {}
        # local name -> full dotted path, for names bound by ``from X import Y``
        self._from_imports: dict[str, str] = {}
        # modules named in ``from X import *`` (their names are not bound individually)
        self._star_modules: list[str] = []
        # Which rules are triggered by imports we've seen so far
        self._active_rules: list[dict[str, Any]] = []
        # id()s of docstring constants; they are prose, not usage, so skipped
        self._docstring_ids: set[int] = set()

    def _snippet(self, lineno: int) -> str:
        """Return source line ``lineno`` (1-based) stripped, or "" if out of range."""
        if 1 <= lineno <= len(self._lines):
            return self._lines[lineno - 1].strip()
        return ""

    def _add(
        self,
        node: ast.AST,
        rule: dict[str, Any],
        detail: str = "",
    ) -> None:
        """Record a finding for ``rule`` at ``node``'s position."""
        lineno = getattr(node, "lineno", 0)
        col = getattr(node, "col_offset", 0) + 1  # 1-based
        message = rule["message"].replace("{detail}", detail) if detail else rule["message"]
        self.findings.append(
            Finding(
                file=self._filename,
                line=lineno,
                col=col,
                severity=rule["severity"],
                rule_id=rule["id"],
                message=message,
                snippet=self._snippet(lineno),
                fix_hint=rule.get("fix_hint", ""),
            )
        )

    def _activate(self, rule: dict[str, Any]) -> None:
        """Enable ``rule``'s call checks (idempotent)."""
        if rule not in self._active_rules:
            self._active_rules.append(rule)

    # ------------------------------------------------------------------
    # Docstring tracking
    # ------------------------------------------------------------------

    def _visit_docstring_owner(self, node: ast.AST) -> None:
        """Remember the docstring of a module/class/function, then recurse."""
        body = getattr(node, "body", None)
        if body:
            first = body[0]
            if (
                isinstance(first, ast.Expr)
                and isinstance(first.value, ast.Constant)
                and isinstance(first.value.value, str)
            ):
                self._docstring_ids.add(id(first.value))
        self.generic_visit(node)

    visit_Module = _visit_docstring_owner
    visit_ClassDef = _visit_docstring_owner
    visit_FunctionDef = _visit_docstring_owner
    visit_AsyncFunctionDef = _visit_docstring_owner

    # ------------------------------------------------------------------
    # Import tracking
    # ------------------------------------------------------------------

    def visit_Import(self, node: ast.Import) -> None:
        for alias in node.names:
            if alias.asname:
                # ``import a.b.c as x`` binds x to the submodule a.b.c
                self._module_aliases[alias.asname] = alias.name
            else:
                # ``import a.b.c`` binds only the top-level package ``a``
                top = alias.name.split(".")[0]
                self._module_aliases[top] = top
            # Check import-triggered rules
            for rule in _IMPORT_RULE_MAP.get(alias.name, []):
                self._activate(rule)
        self.generic_visit(node)

    @staticmethod
    def _import_activates(module: str, name: str, full_path: str, trigger: str) -> bool:
        """True if ``from <module> import <name>`` brings in ``trigger`` or a member of it.

        Importing an unrelated sibling from a shared parent package does not
        count: ``from cryptography.hazmat.primitives import hashes`` must not
        enable the RSA/EC/DSA/DH rules.
        """
        if full_path == trigger or full_path.startswith(trigger + "."):
            return True
        # ``from pkg import *`` may expose any trigger module under ``pkg``.
        return name == "*" and bool(module) and trigger.startswith(module + ".")

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        module = node.module or ""
        for alias in node.names:
            full_path = f"{module}.{alias.name}" if module else alias.name
            if alias.name == "*":
                # Star imports bind names we cannot see; remember the module
                # so bare calls can still be resolved against it.
                if module:
                    self._star_modules.append(module)
            else:
                local_name = alias.asname or alias.name
                self._from_imports[local_name] = full_path
                # Track module alias for call matching
                self._module_aliases[local_name] = full_path
            for trigger_mod, rules in _IMPORT_RULE_MAP.items():
                if self._import_activates(module, alias.name, full_path, trigger_mod):
                    for rule in rules:
                        self._activate(rule)
        self.generic_visit(node)

    # ------------------------------------------------------------------
    # String constant scanning
    # ------------------------------------------------------------------

    def visit_Constant(self, node: ast.Constant) -> None:
        """Scan string literals (excluding docstrings) for known classical algorithm names."""
        if not isinstance(node.value, str) or id(node) in self._docstring_ids:
            return
        val = node.value
        for rule in _RULES:
            # sorted(): iterating a set of str varies between interpreter runs
            # (hash randomization), which would make {detail} nondeterministic.
            for pattern in sorted(rule.get("string_patterns", ())):
                if pattern in val:
                    self._add(node, rule, detail=pattern)
                    break  # one match per rule per node
        self.generic_visit(node)

    # ------------------------------------------------------------------
    # Call scanning
    # ------------------------------------------------------------------

    def visit_Call(self, node: ast.Call) -> None:
        """Check if this call matches any active rule's call patterns."""
        for rule in self._active_rules:
            if any(self._matches_call(node, call_path) for call_path in rule.get("calls", ())):
                self._add(node, rule)
        self.generic_visit(node)

    @staticmethod
    def _dotted_name(node: ast.AST) -> str | None:
        """Return "a.b.c" for a Name/Attribute chain, or None for anything else."""
        parts: list[str] = []
        while isinstance(node, ast.Attribute):
            parts.append(node.attr)
            node = node.value
        if not isinstance(node, ast.Name):
            return None
        parts.append(node.id)
        return ".".join(reversed(parts))

    def _matches_call(self, node: ast.Call, call_path: tuple[str, ...]) -> bool:
        """Check if a Call node resolves to a (module, function) pattern.

        The callee's dotted name (``f``, ``mod.f``, ``pkg.mod.f``) is resolved
        through the file's imports. It must equal ``module.function`` or end
        with ``.module.function``. The dot boundary keeps e.g. ``spec.ECDH``
        from matching ``("ec", "ECDH")``.
        """
        module_name, func_name = call_path
        dotted = self._dotted_name(node.func)
        if dotted is None:
            return False
        target = f"{module_name}.{func_name}"
        head, sep, rest = dotted.partition(".")
        if head in self._module_aliases:
            candidates = [self._module_aliases[head] + sep + rest]
        else:
            # Unbound head: a local name, or something a star import provided.
            candidates = [dotted] + [f"{mod}.{dotted}" for mod in self._star_modules]
        return any(c == target or c.endswith("." + target) for c in candidates)


# ---------------------------------------------------------------------------
# Public scanner interface
# ---------------------------------------------------------------------------
class Scanner:
    """Scans Python source files for classical cryptography usage.

    Usage::

        report = Scanner.scan_directory("./src")
        print(report.summary())
        if report.has_blocking_findings:
            for f in report.high + report.critical:
                print(f)
            sys.exit(1)
    """

    @classmethod
    def scan_file(cls, filepath: str | pathlib.Path) -> ScanReport:
        """Scan a single Python file.

        Args:
            filepath: Path to a .py file.

        Returns:
            ScanReport with findings for this file. Read and parse failures are
            recorded in ``report.errors`` rather than raised.
        """
        filepath = str(filepath)
        report = ScanReport(root=filepath)
        cls._scan_one(filepath, report)
        cls._sort_findings(report)
        return report

    @classmethod
    def scan_directory(
        cls,
        directory: str | pathlib.Path,
        exclude: list[str] | None = None,
        max_file_size_kb: int = 512,
    ) -> ScanReport:
        """Recursively scan a directory for classical crypto usage.

        Args:
            directory: Root directory to scan. A path to a single file is also
                accepted and scanned directly. Exclusions and the size limit
                apply only to files found by walking a directory.
            exclude: fnmatch-style glob patterns matched against individual
                directory and file names (not full paths). They are added to
                the defaults: .git, __pycache__, .venv, venv, node_modules,
                .mypy_cache, .pytest_cache, dist, build, *.egg-info. Hidden
                directories (leading ".") are always skipped.
            max_file_size_kb: Skip files larger than this (avoid scanning
                minified/generated code).

        Returns:
            ScanReport with aggregated findings.

        Raises:
            FileNotFoundError: If ``directory`` does not exist.
        """
        directory = pathlib.Path(directory)
        if not directory.exists():
            raise FileNotFoundError(f"Directory not found: {directory}")

        default_excludes = {
            ".git",
            "__pycache__",
            ".venv",
            "venv",
            "node_modules",
            ".mypy_cache",
            ".pytest_cache",
            "dist",
            "build",
            "*.egg-info",
        }
        excluded = set(exclude or []) | default_excludes

        report = ScanReport(root=str(directory))
        if directory.is_file():
            # os.walk() yields nothing for a file path, which previously
            # produced a silent, empty report.
            cls._scan_one(str(directory), report)
        else:
            for py_file in cls._iter_python_files(directory, excluded, max_file_size_kb):
                cls._scan_one(str(py_file), report)

        cls._sort_findings(report)
        return report

    @classmethod
    def scan_source(cls, source: str, filename: str = "<string>") -> ScanReport:
        """Scan a source string directly (useful for testing or inline analysis).

        Findings are sorted by (file, line, col), like the other entry points.
        """
        report = ScanReport(root=filename)
        cls._scan_source_text(source, filename, report)
        cls._sort_findings(report)
        return report

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    @staticmethod
    def _sort_findings(report: ScanReport) -> None:
        """Order findings by (file, line, col) in place."""
        report.findings.sort(key=lambda f: (f.file, f.line, f.col))

    @staticmethod
    def _iter_python_files(
        root: pathlib.Path,
        excluded: set[str],
        max_kb: int,
    ) -> Iterator[pathlib.Path]:
        """Yield .py files under root, respecting exclusions.

        Every pattern in ``excluded`` is applied with fnmatch to directory
        names as well as file names, so glob defaults such as "*.egg-info"
        actually prune. Traversal is sorted so the order of ``report.errors``
        is stable between runs.
        """

        def is_excluded(name: str) -> bool:
            return any(fnmatch.fnmatch(name, pat) for pat in excluded)

        for dirpath, dirnames, filenames in os.walk(root):
            # Prune in place so os.walk does not descend into excluded dirs.
            dirnames[:] = sorted(
                d for d in dirnames if not d.startswith(".") and not is_excluded(d)
            )
            for filename in sorted(filenames):
                if not filename.endswith(".py") or is_excluded(filename):
                    continue
                full = pathlib.Path(dirpath) / filename
                try:
                    if full.stat().st_size > max_kb * 1024:
                        continue
                except OSError:
                    continue
                yield full

    @staticmethod
    def _scan_one(filepath: str, report: ScanReport) -> None:
        """Read and scan one file, appending to report."""
        try:
            # "utf-8-sig" strips a leading BOM. Left in place, the BOM makes
            # ast.parse reject the file. It reads BOM-less UTF-8 unchanged.
            source = pathlib.Path(filepath).read_text(encoding="utf-8-sig", errors="replace")
        except OSError as exc:
            report.errors.append({"file": filepath, "error": str(exc)})
            return
        Scanner._scan_source_text(source, filepath, report)

    @staticmethod
    def _scan_source_text(source: str, filename: str, report: ScanReport) -> None:
        """Parse and scan a source text string.

        Parse or analysis failures are appended to ``report.errors`` and the
        file is not counted in ``files_scanned``.
        """
        try:
            tree = ast.parse(source, filename=filename)
        except SyntaxError as exc:
            report.errors.append(
                {
                    "file": filename,
                    "error": f"SyntaxError at line {exc.lineno}: {exc.msg}",
                }
            )
            return
        except ValueError as exc:
            # Some Python versions raise ValueError rather than SyntaxError for
            # sources containing NUL bytes. One bad file must not abort a scan.
            report.errors.append({"file": filename, "error": f"ValueError: {exc}"})
            return

        visitor = _ClassicalCryptoVisitor(filename, source.splitlines())
        try:
            visitor.visit(tree)
        except RecursionError:
            # NodeVisitor recurses once per nesting level.
            report.errors.append(
                {"file": filename, "error": "RecursionError: source nesting too deep to analyze"}
            )
            return
        report.findings.extend(visitor.findings)
        report.files_scanned += 1