Skip to content

Commit bfffc68

Browse files
committed
Enable .pgpass support for SSH tunnel connections
Preserve original hostname for .pgpass lookup using PostgreSQL's host/hostaddr parameters: host keeps the original DB hostname (for .pgpass and SSL), hostaddr gets 127.0.0.1 (the tunnel endpoint). Changes: - main.py: Use hostaddr instead of replacing host with 127.0.0.1 - pgexecute.py: Simplify DSN filtering to keep dsn, password, hostaddr - tests: Add 3 new tests, update existing to verify host preservation Made with ❤️ and 🤖 Claude
1 parent c84d913 commit bfffc68

4 files changed

Lines changed: 70 additions & 24 deletions

File tree

changelog.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ Features:
77
* Add cursor shape support for vi mode. When ``vi = True``, the terminal cursor now
88
reflects the current editing mode: beam in INSERT, block in NORMAL, underline in REPLACE.
99
Uses prompt_toolkit's ``ModalCursorShapeConfig``.
10+
* Enable ``.pgpass`` support for SSH tunnel connections.
11+
* Preserve original hostname for ``.pgpass`` lookup using PostgreSQL's ``hostaddr`` parameter
12+
* SSH tunnel endpoint (``127.0.0.1``) is passed via ``hostaddr``, keeping ``host`` for ``.pgpass``
13+
* Works with both DSN and host/port connection styles
1014

1115
4.4.0 (2025-12-24)
1216
==================

pgcli/main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,11 +710,15 @@ def should_ask_for_password(exc):
710710
self.logger.handlers = logger_handlers
711711

712712
atexit.register(self.ssh_tunnel.stop)
713-
host = "127.0.0.1"
713+
# Preserve original host for .pgpass lookup and SSL certificate verification.
714+
# Use hostaddr to specify the actual connection endpoint (SSH tunnel).
715+
hostaddr = "127.0.0.1"
714716
port = self.ssh_tunnel.local_bind_ports[0]
715717

716718
if dsn:
717-
dsn = make_conninfo(dsn, host=host, port=port)
719+
dsn = make_conninfo(dsn, host=host, hostaddr=hostaddr, port=port)
720+
else:
721+
kwargs["hostaddr"] = hostaddr
718722

719723
# Attempt to connect to the database.
720724
# Note that passwd may be empty on the first attempt. If connection

pgcli/pgexecute.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,11 @@ def connect(
212212
new_params.update(kwargs)
213213

214214
if new_params["dsn"]:
215-
new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}
215+
# When using DSN, only keep dsn, password, and hostaddr (for SSH tunnels)
216+
new_params = {
217+
k: v for k, v in new_params.items()
218+
if k in ("dsn", "password", "hostaddr")
219+
}
216220

217221
if new_params["password"]:
218222
new_params["dsn"] = make_conninfo(new_params["dsn"], password=new_params.pop("password"))

tests/test_ssh_tunnel.py

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,13 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
5050
mock_pgexecute.assert_called_once()
5151

5252
call_args, call_kwargs = mock_pgexecute.call_args
53-
assert call_args == (
54-
db_params["database"],
55-
db_params["user"],
56-
db_params["passwd"],
57-
"127.0.0.1",
58-
pgcli.ssh_tunnel.local_bind_ports[0],
59-
"",
60-
notify_callback,
61-
)
53+
# Original host is preserved for .pgpass lookup, hostaddr has tunnel endpoint
54+
assert call_args[0] == db_params["database"]
55+
assert call_args[1] == db_params["user"]
56+
assert call_args[2] == db_params["passwd"]
57+
assert call_args[3] == db_params["host"] # original host preserved
58+
assert call_args[4] == pgcli.ssh_tunnel.local_bind_ports[0]
59+
assert call_kwargs.get("hostaddr") == "127.0.0.1"
6260
mock_ssh_tunnel_forwarder.reset_mock()
6361
mock_pgexecute.reset_mock()
6462

@@ -86,15 +84,9 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
8684
mock_pgexecute.assert_called_once()
8785

8886
call_args, call_kwargs = mock_pgexecute.call_args
89-
assert call_args == (
90-
db_params["database"],
91-
db_params["user"],
92-
db_params["passwd"],
93-
"127.0.0.1",
94-
pgcli.ssh_tunnel.local_bind_ports[0],
95-
"",
96-
notify_callback,
97-
)
87+
assert call_args[3] == db_params["host"] # original host preserved
88+
assert call_args[4] == pgcli.ssh_tunnel.local_bind_ports[0]
89+
assert call_kwargs.get("hostaddr") == "127.0.0.1"
9890
mock_ssh_tunnel_forwarder.reset_mock()
9991
mock_pgexecute.reset_mock()
10092

@@ -104,13 +96,55 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
10496
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
10597
pgcli.connect(dsn=dsn)
10698

107-
expected_dsn = f"user={db_params['user']} password={db_params['passwd']} host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}"
108-
10999
mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
110100
mock_pgexecute.assert_called_once()
111101

112102
call_args, call_kwargs = mock_pgexecute.call_args
113-
assert expected_dsn in call_args
103+
# DSN should contain original host AND hostaddr for tunnel
104+
dsn_arg = call_args[5]
105+
assert f"host={db_params['host']}" in dsn_arg
106+
assert "hostaddr=127.0.0.1" in dsn_arg
107+
assert f"port={pgcli.ssh_tunnel.local_bind_ports[0]}" in dsn_arg
108+
109+
110+
def test_ssh_tunnel_preserves_original_host_for_pgpass(
111+
mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
112+
) -> None:
113+
"""Verify that the original hostname is preserved for .pgpass lookup."""
114+
tunnel_url = "bastion.example.com"
115+
original_host = "production.db.example.com"
116+
117+
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
118+
pgcli.connect(database="mydb", host=original_host, user="dbuser", passwd="dbpass")
119+
120+
call_args, call_kwargs = mock_pgexecute.call_args
121+
assert call_args[3] == original_host # host preserved
122+
assert call_kwargs.get("hostaddr") == "127.0.0.1" # tunnel endpoint
123+
124+
125+
def test_ssh_tunnel_with_dsn_preserves_host(
126+
mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
127+
) -> None:
128+
"""DSN connections should include hostaddr for tunnel while preserving host."""
129+
tunnel_url = "bastion.example.com"
130+
dsn = "host=production.db.example.com port=5432 dbname=mydb user=dbuser"
131+
132+
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
133+
pgcli.connect(dsn=dsn)
134+
135+
call_args, call_kwargs = mock_pgexecute.call_args
136+
dsn_arg = call_args[5]
137+
assert "host=production.db.example.com" in dsn_arg
138+
assert "hostaddr=127.0.0.1" in dsn_arg
139+
140+
141+
def test_no_ssh_tunnel_does_not_set_hostaddr(mock_pgexecute: MagicMock) -> None:
142+
"""Without SSH tunnel, hostaddr should not be set."""
143+
pgcli = PGCli()
144+
pgcli.connect(database="mydb", host="localhost", user="user", passwd="pass")
145+
146+
call_args, call_kwargs = mock_pgexecute.call_args
147+
assert "hostaddr" not in call_kwargs
114148

115149

116150
def test_cli_with_tunnel() -> None:

0 commit comments

Comments
 (0)