From 4c1baf264a2bc69f260a81546d0243d2c4fe979c Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 13 May 2026 15:04:17 -0400 Subject: [PATCH 1/2] enforce immutability in ci --- .github/workflows/build_and_test.yml | 5 ++ .../enforce_alembic_revision_immutability.py | 77 ++++++++++++++++--- 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 7c284ffe2a..7d080382ed 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -33,6 +33,11 @@ jobs: contents: read steps: - uses: actions/checkout@v5 + with: + # Full history is required so pre-commit hooks (notably + # enforce_alembic_revision_immutability) can compute merge-bases and + # diff ranges against origin/main. + fetch-depth: 0 - uses: actions/setup-python@v6 with: diff --git a/build_scripts/enforce_alembic_revision_immutability.py b/build_scripts/enforce_alembic_revision_immutability.py index 1e1971542b..efcb5b8b34 100644 --- a/build_scripts/enforce_alembic_revision_immutability.py +++ b/build_scripts/enforce_alembic_revision_immutability.py @@ -4,36 +4,89 @@ """ Migration history must be immutable. This hook enforces that by preventing deletion or updates to migration scripts. -Checks both staged changes (local pre-commit) and the full branch diff against origin/main (CI). +Checks staged changes (local pre-commit), the full branch diff against origin/main (CI PRs), +and the previous commit (CI merge-queue / push-to-main). """ +import os import subprocess import sys _VERSIONS_PATH = "pyrit/memory/alembic/versions/" -def _git(*args: str) -> str: - result = subprocess.run(["git", *args], capture_output=True, text=True) - return result.stdout.strip() +def _git(*args: str) -> subprocess.CompletedProcess[str]: + return subprocess.run(["git", *args], capture_output=True, text=True) -def _has_non_add_changes(diff_spec: list[str]) -> bool: - output = _git("diff", "--name-status", *diff_spec, "--", _VERSIONS_PATH) - return any(line and not line.startswith("A") for line in output.splitlines()) +def _git_stdout(*args: str) -> str: + return _git(*args).stdout.strip() + + +def _get_violations(diff_spec: list[str]) -> list[str]: + """Return lines from ``git diff --name-status`` that are not pure additions.""" + output = _git_stdout("diff", "--name-status", *diff_spec, "--", _VERSIONS_PATH) + return [line for line in output.splitlines() if line and not line.startswith("A")] + + +def _in_ci() -> bool: + return os.environ.get("CI", "").lower() in {"1", "true"} or "GITHUB_ACTIONS" in os.environ + + +def _fail_ci(reason: str) -> bool: + """Fail closed in CI when we can't perform the check, pass through locally.""" + if _in_ci(): + print(f"[ERROR] Cannot verify alembic revision immutability: {reason}") + print(" Ensure the CI checkout has full history (fetch-depth: 0).") + return True + return False def has_revision_violations() -> bool: # Local pre-commit: check staged changes - if _has_non_add_changes(["--cached"]): + violations = _get_violations(["--cached"]) + if violations: + _report(violations) + return True + + # CI (PR): check full branch diff against origin/main + merge_base = _git_stdout("merge-base", "origin/main", "HEAD") + head_sha = _git_stdout("rev-parse", "HEAD") + if merge_base and merge_base != head_sha: + violations = _get_violations([f"{merge_base}...HEAD"]) + if violations: + _report(violations) + return True + elif not merge_base: + # On CI this is almost always a shallow-clone problem and must not be + # treated as "no violations". Locally (e.g. a brand-new repo with no + # origin/main) it's expected, so we only fail in CI. + if _fail_ci("git merge-base origin/main HEAD returned empty"): + return True + + # CI (merge-queue / push-to-main): compare HEAD against its first parent. + # In a merge queue the branch *is* main, so merge-base == HEAD and the + # check above produces an empty diff. Comparing HEAD~1..HEAD catches + # deletions or modifications introduced by the merge commit. + head_parent = _git("rev-parse", "--verify", "HEAD~1") + if head_parent.returncode == 0: + violations = _get_violations(["HEAD~1..HEAD"]) + if violations: + _report(violations) + return True + elif _fail_ci("HEAD~1 is not available (shallow clone?)"): return True - # CI: check full branch diff against origin/main - merge_base = _git("merge-base", "origin/main", "HEAD") - return bool(merge_base and _has_non_add_changes([f"{merge_base}...HEAD"])) + return False + + +def _report(violations: list[str]) -> None: + print("[ERROR] Migration scripts can only be added, not modified or deleted.") + print("The following disallowed changes were detected:") + for v in violations: + print(f" {v}") if __name__ == "__main__": if has_revision_violations(): - print("[ERROR] Migration scripts can only be added, not modified or deleted.") sys.exit(1) From 01f72f15e161183ce466c17fcae0dfc87b65fbf4 Mon Sep 17 00:00:00 2001 From: hannahwestra25 Date: Wed, 13 May 2026 18:07:47 -0400 Subject: [PATCH 2/2] pr comment --- .../enforce_alembic_revision_immutability.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/build_scripts/enforce_alembic_revision_immutability.py b/build_scripts/enforce_alembic_revision_immutability.py index efcb5b8b34..9fa5e1122e 100644 --- a/build_scripts/enforce_alembic_revision_immutability.py +++ b/build_scripts/enforce_alembic_revision_immutability.py @@ -49,25 +49,22 @@ def has_revision_violations() -> bool: _report(violations) return True - # CI (PR): check full branch diff against origin/main - merge_base = _git_stdout("merge-base", "origin/main", "HEAD") - head_sha = _git_stdout("rev-parse", "HEAD") - if merge_base and merge_base != head_sha: - violations = _get_violations([f"{merge_base}...HEAD"]) + # CI (PR): diff branch against its merge-base with origin/main. + # The three-dot syntax (A...B) resolves to ``git diff $(merge-base A B) B`` + # automatically, so we don't need a separate merge-base call. When + # origin/main is missing (shallow clone) git exits non-zero. + pr_diff = _git("diff", "--name-status", "origin/main...HEAD", "--", _VERSIONS_PATH) + if pr_diff.returncode == 0: + violations = [line for line in pr_diff.stdout.strip().splitlines() if line and not line.startswith("A")] if violations: _report(violations) return True - elif not merge_base: - # On CI this is almost always a shallow-clone problem and must not be - # treated as "no violations". Locally (e.g. a brand-new repo with no - # origin/main) it's expected, so we only fail in CI. - if _fail_ci("git merge-base origin/main HEAD returned empty"): - return True + elif _fail_ci("origin/main is not available (shallow clone?)"): + return True - # CI (merge-queue / push-to-main): compare HEAD against its first parent. - # In a merge queue the branch *is* main, so merge-base == HEAD and the - # check above produces an empty diff. Comparing HEAD~1..HEAD catches - # deletions or modifications introduced by the merge commit. + # CI (merge-queue / push-to-main): on main the branch *is* origin/main, so + # the diff above is empty. Compare HEAD against its first parent to catch + # deletions or modifications introduced by the merge commit itself. head_parent = _git("rev-parse", "--verify", "HEAD~1") if head_parent.returncode == 0: violations = _get_violations(["HEAD~1..HEAD"])