Skip to content

Commit 56882e2

Browse files
committed
fix proxyDialAddr and add more tests for it
1 parent 94f58c8 commit 56882e2

2 files changed

Lines changed: 17 additions & 8 deletions

File tree

internal/api/proxy.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ func (c *connWithBufferedReader) Read(p []byte) (int, error) {
2424
// proxyDialAddr returns proxyURL.Host with a default port appended if one is
2525
// not already present (443 for https, 80 for http).
2626
func proxyDialAddr(proxyURL *url.URL) string {
27-
addr := proxyURL.Host
28-
if _, _, err := net.SplitHostPort(addr); err != nil {
29-
if proxyURL.Scheme == "https" {
30-
return net.JoinHostPort(addr, "443")
31-
}
32-
return net.JoinHostPort(addr, "80")
27+
// net.SplitHostPort returns an error when the input doesn't contain a port
28+
if _, _, err := net.SplitHostPort(proxyURL.Host); err == nil {
29+
return proxyURL.Host
30+
}
31+
if proxyURL.Scheme == "https" {
32+
return net.JoinHostPort(proxyURL.Hostname(), "443")
3333
}
34-
return addr
34+
return net.JoinHostPort(proxyURL.Hostname(), "80")
3535
}
3636

3737
// withProxyTransport modifies the given transport to handle proxying of unix, socks5 and http connections.

internal/api/proxy_test.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,19 @@ func TestProxyDialAddr(t *testing.T) {
416416
{"https without port", "https://proxy.example.com", "proxy.example.com:443"},
417417
{"http with port", "http://proxy.example.com:8080", "proxy.example.com:8080"},
418418
{"http without port", "http://proxy.example.com", "proxy.example.com:80"},
419+
{"ipv4 with port", "http://192.168.1.100:3128", "192.168.1.100:3128"},
420+
{"ipv4 without port https", "https://10.0.0.1", "10.0.0.1:443"},
421+
{"ipv4 without port http", "http://172.16.0.5", "172.16.0.5:80"},
422+
{"ipv6 with port", "http://[::1]:8080", "[::1]:8080"},
423+
{"ipv6 without port https", "https://[2001:db8::1]", "[2001:db8::1]:443"},
424+
{"ipv6 without port http", "http://[fe80::1]", "[fe80::1]:80"},
425+
{"localhost with port", "http://localhost:9090", "localhost:9090"},
426+
{"localhost without port https", "https://localhost", "localhost:443"},
427+
{"localhost without port http", "http://localhost", "localhost:80"},
419428
}
420429
for _, tt := range tests {
421430
t.Run(tt.name, func(t *testing.T) {
422-
u, err := url.Parse(tt.url)
431+
u, err := url.ParseRequestURI(tt.url)
423432
if err != nil {
424433
t.Fatalf("parse URL: %v", err)
425434
}

0 commit comments

Comments
 (0)