|
34 | 34 | from google.cloud.sql.connector.exceptions import ConnectorLoopError |
35 | 35 | from google.cloud.sql.connector.exceptions import IncompatibleDriverError |
36 | 36 | from google.cloud.sql.connector.instance import RefreshAheadCache |
| 37 | +from google.cloud.sql.connector.resolver import DnsResolver |
37 | 38 |
|
38 | 39 |
|
39 | 40 | @pytest.mark.asyncio |
@@ -548,3 +549,113 @@ def test_connect_closed_connector( |
548 | 549 | exc_info.value.args[0] |
549 | 550 | == "Connection attempt failed because the connector has already been closed." |
550 | 551 | ) |
| 552 | + |
| 553 | + |
| 554 | +@pytest.mark.asyncio |
| 555 | +async def test_Connector_connect_async_custom_dns_resolver( |
| 556 | + fake_credentials: Credentials, fake_client: CloudSQLClient |
| 557 | +) -> None: |
| 558 | + """Test that Connector.connect_async uses custom DNS name resolution.""" |
| 559 | + |
| 560 | + # Create a mock DnsResolver that returns a fixed IP |
| 561 | + with patch( |
| 562 | + "google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record" |
| 563 | + ) as mock_resolve_a: |
| 564 | + mock_resolve_a.return_value = ["1.2.3.4"] |
| 565 | + |
| 566 | + # We also need to patch resolve because DnsResolver.resolve does DNS lookup for TXT |
| 567 | + # But we can patch DnsResolver.resolve to return a ConnectionName with domain name |
| 568 | + with patch( |
| 569 | + "google.cloud.sql.connector.resolver.DnsResolver.resolve" |
| 570 | + ) as mock_resolve: |
| 571 | + # This must return a ConnectionName object with domain_name set |
| 572 | + conn_name_with_domain = ConnectionName( |
| 573 | + "test-project", "test-region", "test-instance", "db.example.com" |
| 574 | + ) |
| 575 | + mock_resolve.return_value = conn_name_with_domain |
| 576 | + |
| 577 | + async with Connector( |
| 578 | + credentials=fake_credentials, |
| 579 | + loop=asyncio.get_running_loop(), |
| 580 | + resolver=DnsResolver, |
| 581 | + ) as connector: |
| 582 | + connector._client = fake_client |
| 583 | + |
| 584 | + # patch db connection creation |
| 585 | + with patch( |
| 586 | + "google.cloud.sql.connector.asyncpg.connect" |
| 587 | + ) as mock_connect: |
| 588 | + mock_connect.return_value = True |
| 589 | + |
| 590 | + # Call connect_async |
| 591 | + # Use "db.example.com" as instance connection string (resolver will handle it) |
| 592 | + connection = await connector.connect_async( |
| 593 | + "db.example.com", |
| 594 | + "asyncpg", |
| 595 | + user="my-user", |
| 596 | + password="my-pass", |
| 597 | + db="my-db", |
| 598 | + ) |
| 599 | + |
| 600 | + # Verify mock_connect was called with resolved IP "1.2.3.4" |
| 601 | + # The first arg to mock_connect (which patches connector call) is ip_address |
| 602 | + args, _ = mock_connect.call_args |
| 603 | + assert args[0] == "1.2.3.4" |
| 604 | + assert connection is True |
| 605 | + |
| 606 | + |
| 607 | +@pytest.mark.asyncio |
| 608 | +async def test_Connector_connect_async_custom_dns_resolver_fallback( |
| 609 | + fake_credentials: Credentials, fake_client: CloudSQLClient |
| 610 | +) -> None: |
| 611 | + """Test that Connector.connect_async falls back if DNS resolution fails.""" |
| 612 | + |
| 613 | + # Create a mock DnsResolver that returns empty list (failure) |
| 614 | + with patch( |
| 615 | + "google.cloud.sql.connector.resolver.DnsResolver.resolve_a_record" |
| 616 | + ) as mock_resolve_a: |
| 617 | + mock_resolve_a.return_value = [] |
| 618 | + |
| 619 | + with patch( |
| 620 | + "google.cloud.sql.connector.resolver.DnsResolver.resolve" |
| 621 | + ) as mock_resolve: |
| 622 | + conn_name_with_domain = ConnectionName( |
| 623 | + "test-project", "test-region", "test-instance", "db.example.com" |
| 624 | + ) |
| 625 | + mock_resolve.return_value = conn_name_with_domain |
| 626 | + |
| 627 | + async with Connector( |
| 628 | + credentials=fake_credentials, |
| 629 | + loop=asyncio.get_running_loop(), |
| 630 | + resolver=DnsResolver, |
| 631 | + ) as connector: |
| 632 | + connector._client = fake_client |
| 633 | + |
| 634 | + # Save original IPs to restore later (fake_instance is session-scoped) |
| 635 | + original_ips = fake_client.instance.ip_addrs |
| 636 | + # Set metadata IP to something specific |
| 637 | + fake_client.instance.ip_addrs = {"PRIMARY": "5.6.7.8"} |
| 638 | + |
| 639 | + try: |
| 640 | + with patch( |
| 641 | + "google.cloud.sql.connector.asyncpg.connect" |
| 642 | + ) as mock_connect: |
| 643 | + mock_connect.return_value = True |
| 644 | + |
| 645 | + connection = await connector.connect_async( |
| 646 | + "db.example.com", |
| 647 | + "asyncpg", |
| 648 | + user="my-user", |
| 649 | + password="my-pass", |
| 650 | + db="my-db", |
| 651 | + ) |
| 652 | + |
| 653 | + # Verify mock_connect was called with metadata IP "5.6.7.8" |
| 654 | + args, _ = mock_connect.call_args |
| 655 | + assert args[0] == "5.6.7.8" |
| 656 | + assert connection is True |
| 657 | + finally: |
| 658 | + # Restore original IPs |
| 659 | + fake_client.instance.ip_addrs = original_ips |
| 660 | + |
| 661 | + |
0 commit comments