|
40 | 40 | from vgi_rpc.metadata import ( |
41 | 41 | LOCATION_FETCH_MS_KEY, |
42 | 42 | LOCATION_KEY, |
| 43 | + LOCATION_SHA256_KEY, |
43 | 44 | LOCATION_SOURCE_KEY, |
44 | 45 | LOG_LEVEL_KEY, |
45 | 46 | encode_metadata, |
@@ -792,6 +793,179 @@ def test_metadata_survives_roundtrip(self) -> None: |
792 | 793 | assert resolved_cm.get(b"user.key") == b"user.value" |
793 | 794 |
|
794 | 795 |
|
| 796 | +# =========================================================================== |
| 797 | +# Unit tests — SHA-256 checksum |
| 798 | +# =========================================================================== |
| 799 | + |
| 800 | + |
class TestSHA256Checksum:
    """Tests for the SHA-256 checksum carried in pointer-batch metadata."""

    @staticmethod
    def _sha256_hex(custom_metadata: "pa.KeyValueMetadata | None") -> str:
        """Return the ``LOCATION_SHA256_KEY`` entry of *custom_metadata* as text.

        Fails the calling test when the metadata object or the key is absent.
        """
        assert custom_metadata is not None
        raw = custom_metadata.get(LOCATION_SHA256_KEY)
        assert raw is not None
        # Tolerate either a bytes or a str value, exactly as the tests always did.
        return raw.decode() if isinstance(raw, bytes) else raw

    @staticmethod
    def _assert_sha256_hex(digest: str) -> None:
        """Assert that *digest* is a 64-character hex string encoding 32 bytes."""
        assert len(digest) == 64  # hex-encoded SHA-256
        # bytes.fromhex raises ValueError on non-hex characters; the length
        # check above alone would accept any 64-character string.
        assert len(bytes.fromhex(digest)) == 32

    def test_sha256_present_on_externalize_batch(self) -> None:
        """Pointer batch from maybe_externalize_batch has SHA-256 metadata."""
        storage = MockStorage()
        config = ExternalLocationConfig(storage=storage, externalize_threshold_bytes=10)

        batch = pa.RecordBatch.from_pydict({"value": list(range(100))}, schema=_SCHEMA)
        result_batch, result_cm = maybe_externalize_batch(batch, None, config)

        assert result_batch.num_rows == 0
        self._assert_sha256_hex(self._sha256_hex(result_cm))

    def test_sha256_present_on_externalize_collector(self) -> None:
        """Pointer batch from maybe_externalize_collector has SHA-256 metadata."""
        storage = MockStorage()
        config = ExternalLocationConfig(storage=storage, externalize_threshold_bytes=10)

        out = OutputCollector(_SCHEMA)
        out.emit_pydict({"value": list(range(100))})

        result = maybe_externalize_collector(out, config)
        assert len(result) == 1
        _batch, cm = result[0]
        self._assert_sha256_hex(self._sha256_hex(cm))

    def test_sha256_matches_raw_ipc_bytes(self) -> None:
        """SHA-256 in metadata matches hash of raw IPC bytes (pre-compression)."""
        import hashlib

        storage = MockStorage()
        config = ExternalLocationConfig(storage=storage, externalize_threshold_bytes=10)

        batch = pa.RecordBatch.from_pydict({"value": list(range(100))}, schema=_SCHEMA)
        _result_batch, result_cm = maybe_externalize_batch(batch, None, config)

        expected_hex = self._sha256_hex(result_cm)

        # Uploaded data is raw IPC (no compression configured)
        uploaded_bytes = next(iter(storage.data.values()))
        actual_hex = hashlib.sha256(uploaded_bytes).hexdigest()
        assert actual_hex == expected_hex

    def test_sha256_verified_on_fetch(self) -> None:
        """Resolution succeeds when SHA-256 matches."""
        storage = MockStorage()
        config = ExternalLocationConfig(storage=storage, externalize_threshold_bytes=10, max_retries=0)

        batch = pa.RecordBatch.from_pydict({"value": list(range(100))}, schema=_SCHEMA)
        ext_batch, ext_cm = maybe_externalize_batch(batch, None, config)

        with _mock_aio(storage):
            resolved, _ = resolve_external_location(ext_batch, ext_cm, config)

        assert resolved.num_rows == 100

    def test_sha256_mismatch_raises(self) -> None:
        """Resolution fails when SHA-256 does not match."""
        storage = MockStorage()
        config = ExternalLocationConfig(storage=storage, externalize_threshold_bytes=10, max_retries=0)

        batch = pa.RecordBatch.from_pydict({"value": list(range(100))}, schema=_SCHEMA)
        ext_batch, ext_cm = maybe_externalize_batch(batch, None, config)

        # Tamper with the SHA-256 in the metadata
        assert ext_cm is not None
        location_val = ext_cm.get(LOCATION_KEY)
        assert location_val is not None
        tampered_cm = pa.KeyValueMetadata(
            {
                LOCATION_KEY: location_val,
                LOCATION_SHA256_KEY: b"0" * 64,
            }
        )

        with (
            _mock_aio(storage),
            pytest.raises(RuntimeError, match="SHA-256 checksum mismatch"),
        ):
            resolve_external_location(ext_batch, tampered_cm, config)

    def test_sha256_absent_skips_verification(self) -> None:
        """Old pointer batches without SHA-256 metadata resolve successfully."""
        storage = MockStorage()
        config = ExternalLocationConfig(storage=storage, externalize_threshold_bytes=10, max_retries=0)

        # Create IPC data and a pointer batch WITHOUT SHA-256
        data_batch = pa.RecordBatch.from_pydict({"value": [42]}, schema=_SCHEMA)
        ipc_bytes = _serialize_ipc(_SCHEMA, [(data_batch, None)])
        url = "https://mock.storage/no-sha"
        storage.data[url] = ipc_bytes

        pointer, cm = make_external_location_batch(_SCHEMA, url)  # no sha256 param
        assert cm.get(LOCATION_SHA256_KEY) is None

        with _mock_aio(storage):
            resolved, _ = resolve_external_location(pointer, cm, config)

        assert resolved.num_rows == 1
        assert resolved.column("value")[0].as_py() == 42

    def test_sha256_roundtrip_with_compression(self) -> None:
        """SHA-256 checksum works correctly with zstd compression."""
        storage = MockStorage()
        config = ExternalLocationConfig(
            storage=storage,
            externalize_threshold_bytes=10,
            compression=Compression(),
            max_retries=0,
        )

        batch = pa.RecordBatch.from_pydict({"value": list(range(100))}, schema=_SCHEMA)
        ext_batch, ext_cm = maybe_externalize_batch(batch, None, config)

        assert ext_cm is not None
        assert ext_cm.get(LOCATION_SHA256_KEY) is not None

        with _mock_aio(storage):
            resolved, _ = resolve_external_location(ext_batch, ext_cm, config)

        assert resolved.num_rows == 100
        assert resolved.column("value")[0].as_py() == 0
        assert resolved.column("value")[99].as_py() == 99

    def test_sha256_with_compression_is_pre_compression(self) -> None:
        """SHA-256 is of raw IPC bytes, not compressed bytes."""
        import hashlib

        import zstandard

        storage = MockStorage()
        config = ExternalLocationConfig(
            storage=storage,
            externalize_threshold_bytes=10,
            compression=Compression(),
        )

        batch = pa.RecordBatch.from_pydict({"value": list(range(100))}, schema=_SCHEMA)
        _result_batch, result_cm = maybe_externalize_batch(batch, None, config)

        sha256_str = self._sha256_hex(result_cm)

        # Uploaded data is compressed — SHA-256 should NOT match compressed bytes
        uploaded_compressed = next(iter(storage.data.values()))
        compressed_hash = hashlib.sha256(uploaded_compressed).hexdigest()
        assert sha256_str != compressed_hash, "SHA-256 should be of raw IPC, not compressed bytes"

        # Decompress and verify SHA-256 matches the raw IPC
        raw_ipc = zstandard.ZstdDecompressor().decompress(uploaded_compressed)
        raw_hash = hashlib.sha256(raw_ipc).hexdigest()
        assert sha256_str == raw_hash
| 967 | + |
| 968 | + |
795 | 969 | # =========================================================================== |
796 | 970 | # Integration tests — pipe transport with MockStorage |
797 | 971 | # =========================================================================== |
|
0 commit comments