|
25 | 25 | from aioresponses import aioresponses as aioresponses_ctx |
26 | 26 | from pyarrow import ipc |
27 | 27 |
|
28 | | -from vgi_rpc.external import ExternalLocationConfig, UploadUrl |
| 28 | +from vgi_rpc.external import ( |
| 29 | + ExternalLocationConfig, |
| 30 | + UploadUrl, |
| 31 | + make_external_location_batch, |
| 32 | + resolve_external_location, |
| 33 | +) |
29 | 34 | from vgi_rpc.http import ( |
30 | 35 | _ARROW_CONTENT_TYPE, |
31 | 36 | MAX_REQUEST_BYTES_HEADER, |
@@ -673,6 +678,132 @@ def test_small_batches_inline(self) -> None: |
673 | 678 | client.close() |
674 | 679 |
|
675 | 680 |
|
class TestHttpExternalSHA256:
    """SHA-256 checksum tests for ExternalLocation over HTTP transport."""

    def _make_config(self, storage: MockStorage, threshold: int = 100) -> ExternalLocationConfig:
        """Create an ExternalLocationConfig with low threshold for testing.

        Retries are disabled (``max_retries=0``, ``retry_delay_seconds=0.0``)
        so a failing fetch surfaces immediately instead of being retried.
        """
        return ExternalLocationConfig(
            storage=storage,
            externalize_threshold_bytes=threshold,
            max_retries=0,
            retry_delay_seconds=0.0,
        )

    def _make_client(self, config: ExternalLocationConfig) -> _SyncTestClient:
        """Create a _SyncTestClient wrapping an RpcServer with external storage."""
        server = RpcServer(_ExternalService, _ExternalServiceImpl(), external_location=config)
        return make_sync_client(server, signing_key=b"test-key")

    def _mock_aio_dynamic(self, storage: MockStorage, mock: aioresponses_ctx) -> None:
        """Register pattern-based HEAD + GET callbacks backed by ``storage.data``.

        Any URL under ``https://mock.storage/`` is answered from
        ``storage.data`` at request time; unknown URLs get a 404.
        """
        pattern = re.compile(r"^https://mock\.storage/.*$")

        def _respond(url_: Any, *, include_body: bool) -> CallbackResult:
            # Shared lookup for both verbs: HEAD reports only the length,
            # GET additionally returns the stored bytes.
            url_str = str(url_)
            if url_str not in storage.data:
                return CallbackResult(status=404)
            body = storage.data[url_str]
            headers = {"Content-Length": str(len(body))}
            if include_body:
                return CallbackResult(status=200, body=body, headers=headers)
            return CallbackResult(status=200, headers=headers)

        def _head_callback(url_: Any, **kwargs: Any) -> CallbackResult:
            return _respond(url_, include_body=False)

        def _get_callback(url_: Any, **kwargs: Any) -> CallbackResult:
            return _respond(url_, include_body=True)

        # repeat=True keeps each handler registered after it matches, so the
        # number of requests a test may issue is not capped.
        mock.head(pattern, callback=_head_callback, repeat=True)
        mock.get(pattern, callback=_get_callback, repeat=True)

    def test_sha256_present_on_http_externalize(self) -> None:
        """Externalized echo over HTTP round-trips and populates storage."""
        import hashlib

        storage = MockStorage()
        config = self._make_config(storage, threshold=10)
        client = self._make_client(config)

        try:
            with aioresponses_ctx() as mock:
                self._mock_aio_dynamic(storage, mock)
                with http_connect(_ExternalService, client=client, external_location=config) as proxy:
                    result = proxy.echo_large(data="x" * 200)

            assert result == "x" * 200
            assert len(storage.data) >= 1

            # NOTE(review): this only confirms a digest can be computed over
            # the uploaded bytes (a SHA-256 hexdigest is always 64 chars); it
            # does not compare against the checksum stored in the pointer
            # metadata. Tighten once that metadata is reachable from the test.
            uploaded_bytes = next(iter(storage.data.values()))
            expected_sha256 = hashlib.sha256(uploaded_bytes).hexdigest()
            assert len(expected_sha256) == 64
        finally:
            # Close the client even when the call or an assertion fails.
            client.close()

    def test_sha256_verified_on_http_roundtrip(self) -> None:
        """Full HTTP round-trip with SHA-256 verification succeeds."""
        storage = MockStorage()
        config = self._make_config(storage, threshold=10)
        client = self._make_client(config)

        try:
            with aioresponses_ctx() as mock:
                self._mock_aio_dynamic(storage, mock)
                with http_connect(_ExternalService, client=client, external_location=config) as proxy:
                    result = proxy.echo_large(data="hello world " * 50)

            assert result == "hello world " * 50
        finally:
            client.close()

    def test_sha256_mismatch_over_http(self) -> None:
        """SHA-256 mismatch during HTTP resolution raises RuntimeError."""
        from io import BytesIO

        storage = MockStorage()
        config = self._make_config(storage, threshold=10)

        # Create a real data batch and serialize it
        schema = pa.schema([pa.field("data", pa.string())])
        data_batch = pa.RecordBatch.from_pydict({"data": ["x" * 200]}, schema=schema)
        buf = BytesIO()
        with ipc.new_stream(buf, schema) as writer:
            writer.write_batch(data_batch)
        ipc_bytes = buf.getvalue()

        url = "https://mock.storage/sha-mismatch"
        storage.data[url] = ipc_bytes

        # Create a pointer batch with a deliberately wrong SHA-256
        pointer, cm = make_external_location_batch(schema, url, sha256="0" * 64)

        with aioresponses_ctx() as mock:
            self._mock_aio_dynamic(storage, mock)
            # Only the call under test sits inside pytest.raises, so an
            # unexpected error during mock setup cannot satisfy the assertion.
            with pytest.raises(RuntimeError, match="SHA-256 checksum mismatch"):
                resolve_external_location(pointer, cm, config)

    def test_sha256_http_stream_externalized(self) -> None:
        """Externalized stream batches over HTTP are delivered intact."""
        storage = MockStorage()
        config = self._make_config(storage, threshold=100)
        client = self._make_client(config)

        received_logs: list[Message] = []

        try:
            with aioresponses_ctx() as mock:
                self._mock_aio_dynamic(storage, mock)
                with http_connect(
                    _ExternalService, client=client, on_log=received_logs.append, external_location=config
                ) as proxy:
                    batches = list(proxy.stream_large(count=3, size=50))

            assert len(batches) == 3
            for ab in batches:
                assert ab.batch.num_rows == 50
            # Verify storage was used (batches were externalized)
            assert len(storage.data) >= 1
        finally:
            client.close()
| 806 | + |
676 | 807 | # --------------------------------------------------------------------------- |
677 | 808 | # Auth test protocol + implementation |
678 | 809 | # --------------------------------------------------------------------------- |
|
0 commit comments