@@ -50,15 +50,13 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
5050 mock_pgexecute .assert_called_once ()
5151
5252 call_args , call_kwargs = mock_pgexecute .call_args
53- assert call_args == (
54- db_params ["database" ],
55- db_params ["user" ],
56- db_params ["passwd" ],
57- "127.0.0.1" ,
58- pgcli .ssh_tunnel .local_bind_ports [0 ],
59- "" ,
60- notify_callback ,
61- )
53+ # Original host is preserved for .pgpass lookup, hostaddr has tunnel endpoint
54+ assert call_args [0 ] == db_params ["database" ]
55+ assert call_args [1 ] == db_params ["user" ]
56+ assert call_args [2 ] == db_params ["passwd" ]
57+ assert call_args [3 ] == db_params ["host" ] # original host preserved
58+ assert call_args [4 ] == pgcli .ssh_tunnel .local_bind_ports [0 ]
59+ assert call_kwargs .get ("hostaddr" ) == "127.0.0.1"
6260 mock_ssh_tunnel_forwarder .reset_mock ()
6361 mock_pgexecute .reset_mock ()
6462
@@ -86,15 +84,9 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
8684 mock_pgexecute .assert_called_once ()
8785
8886 call_args , call_kwargs = mock_pgexecute .call_args
89- assert call_args == (
90- db_params ["database" ],
91- db_params ["user" ],
92- db_params ["passwd" ],
93- "127.0.0.1" ,
94- pgcli .ssh_tunnel .local_bind_ports [0 ],
95- "" ,
96- notify_callback ,
97- )
87+ assert call_args [3 ] == db_params ["host" ] # original host preserved
88+ assert call_args [4 ] == pgcli .ssh_tunnel .local_bind_ports [0 ]
89+ assert call_kwargs .get ("hostaddr" ) == "127.0.0.1"
9890 mock_ssh_tunnel_forwarder .reset_mock ()
9991 mock_pgexecute .reset_mock ()
10092
@@ -104,13 +96,55 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
10496 pgcli = PGCli (ssh_tunnel_url = tunnel_url )
10597 pgcli .connect (dsn = dsn )
10698
107- expected_dsn = f"user={ db_params ['user' ]} password={ db_params ['passwd' ]} host=127.0.0.1 port={ pgcli .ssh_tunnel .local_bind_ports [0 ]} "
108-
10999 mock_ssh_tunnel_forwarder .assert_called_once_with (** expected_tunnel_params )
110100 mock_pgexecute .assert_called_once ()
111101
112102 call_args , call_kwargs = mock_pgexecute .call_args
113- assert expected_dsn in call_args
103+ # DSN should contain original host AND hostaddr for tunnel
104+ dsn_arg = call_args [5 ]
105+ assert f"host={ db_params ['host' ]} " in dsn_arg
106+ assert "hostaddr=127.0.0.1" in dsn_arg
107+ assert f"port={ pgcli .ssh_tunnel .local_bind_ports [0 ]} " in dsn_arg
108+
109+
110+ def test_ssh_tunnel_preserves_original_host_for_pgpass (
111+ mock_ssh_tunnel_forwarder : MagicMock , mock_pgexecute : MagicMock
112+ ) -> None :
113+ """Verify that the original hostname is preserved for .pgpass lookup."""
114+ tunnel_url = "bastion.example.com"
115+ original_host = "production.db.example.com"
116+
117+ pgcli = PGCli (ssh_tunnel_url = tunnel_url )
118+ pgcli .connect (database = "mydb" , host = original_host , user = "dbuser" , passwd = "dbpass" )
119+
120+ call_args , call_kwargs = mock_pgexecute .call_args
121+ assert call_args [3 ] == original_host # host preserved
122+ assert call_kwargs .get ("hostaddr" ) == "127.0.0.1" # tunnel endpoint
123+
124+
125+ def test_ssh_tunnel_with_dsn_preserves_host (
126+ mock_ssh_tunnel_forwarder : MagicMock , mock_pgexecute : MagicMock
127+ ) -> None :
128+ """DSN connections should include hostaddr for tunnel while preserving host."""
129+ tunnel_url = "bastion.example.com"
130+ dsn = "host=production.db.example.com port=5432 dbname=mydb user=dbuser"
131+
132+ pgcli = PGCli (ssh_tunnel_url = tunnel_url )
133+ pgcli .connect (dsn = dsn )
134+
135+ call_args , call_kwargs = mock_pgexecute .call_args
136+ dsn_arg = call_args [5 ]
137+ assert "host=production.db.example.com" in dsn_arg
138+ assert "hostaddr=127.0.0.1" in dsn_arg
139+
140+
141+ def test_no_ssh_tunnel_does_not_set_hostaddr (mock_pgexecute : MagicMock ) -> None :
142+ """Without SSH tunnel, hostaddr should not be set."""
143+ pgcli = PGCli ()
144+ pgcli .connect (database = "mydb" , host = "localhost" , user = "user" , passwd = "pass" )
145+
146+ call_args , call_kwargs = mock_pgexecute .call_args
147+ assert "hostaddr" not in call_kwargs
114148
115149
116150def test_cli_with_tunnel () -> None :
0 commit comments