Skip to content

Commit 8c92d8a

Browse files
committed
feat: fix SSRF bypass vulnerability with IPv6 support
- Update is_safe_url() to use getaddrinfo() instead of gethostbyname() - Resolves ALL IPs (IPv4 + IPv6) to prevent dual-stack bypass - Blocks requests if ANY resolved IP is private/loopback/reserved - Add comprehensive SSRF tests including dual-stack scenario - Update .gitignore with __pycache__
1 parent faf755f commit 8c92d8a

3 files changed

Lines changed: 119 additions & 17 deletions

File tree

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
.aider*
2+
__pycache__/
3+
*.pyc

server.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import cloudscraper
2-
import time
3-
import socket
42
import ipaddress
3+
import socket
4+
import time
55

66
from urllib.parse import unquote, urlparse
77
from flask import Flask, request, Response
@@ -54,28 +54,40 @@ def generate_origin_and_ref(url, headers):
5454
def is_safe_url(url):
5555
"""
5656
Validates URL to prevent SSRF attacks by blocking local/private IP ranges.
57+
Uses getaddrinfo to resolve ALL IP addresses (IPv4 and IPv6) to prevent
58+
bypass via dual-stack hostnames that resolve to both public and private IPs.
59+
60+
Returns:
61+
tuple: (is_safe: bool, error_message: str | None)
5762
"""
5863
try:
5964
parsed = urlparse(url)
6065
if parsed.scheme not in ('http', 'https'):
6166
return False, "Only HTTP/HTTPS protocols are allowed"
6267

6368
hostname = parsed.hostname
69+
6470
if not hostname:
6571
return False, "Invalid hostname"
6672

6773
try:
68-
# Resolve hostname to IP to check against blocklist
69-
ip_str = socket.gethostbyname(hostname)
70-
ip = ipaddress.ip_address(ip_str)
74+
# Resolve hostname to IP(s) to check against blocklist.
75+
# We use getaddrinfo to get all resolved IPs (IPv4 and IPv6) to prevent evasion
76+
# where a hostname resolves to both a safe IP and a private IP.
77+
addr_info = socket.getaddrinfo(hostname, None)
78+
79+
for family, _, _, _, sockaddr in addr_info:
80+
# sockaddr is (address, port) for AF_INET and (address, port, flow info, scope id) for AF_INET6
81+
ip_str = sockaddr[0]
82+
ip = ipaddress.ip_address(ip_str)
83+
84+
if ip.is_loopback or ip.is_private or ip.is_reserved or ip.is_multicast or ip.is_unspecified:
85+
return False, "Access to private/local network is forbidden"
86+
7187
except (socket.gaierror, ValueError):
72-
# If we can't resolve it, it might be safer to block or allow if it's external.
73-
# For security, fail closed if resolution fails.
88+
# If we can't resolve it, fail closed.
7489
return False, "Could not resolve hostname or invalid IP"
7590

76-
if ip.is_loopback or ip.is_private or ip.is_reserved or ip.is_multicast or ip.is_unspecified:
77-
return False, "Access to private/local network is forbidden"
78-
7991
return True, None
8092
except Exception:
8193
return False, "Invalid URL format"
@@ -154,9 +166,9 @@ def get_proxy_request_headers(req, url):
154166
headers = get_headers()
155167
headers['Accept-Encoding'] = 'gzip, deflate, br'
156168

157-
for header in req.headers:
158-
if header[0].lower() not in ['host', 'connection', 'content-length']:
159-
headers[header[0]] = header[1]
169+
for key, value in req.headers.items():
170+
if key.lower() not in ['host', 'connection', 'content-length']:
171+
headers[key] = value
160172
headers = generate_origin_and_ref(url, headers)
161173
return headers
162174

@@ -167,16 +179,16 @@ def handle_proxy(url):
167179
if request.method == 'GET':
168180
full_url = get_proxy_request_url(request, url) # parse request url
169181

170-
# Sentinel: SSRF Protection
182+
# SSRF protection check
171183
is_safe, error_msg = is_safe_url(full_url)
172184
if not is_safe:
173-
return {'error': error_msg}, 400
185+
print(f"SSRF blocked: {full_url} - {error_msg}")
186+
return {'error': error_msg}, 403
174187

175188
headers = get_proxy_request_headers(request, url) # generate headers for the request
176189

177190
try:
178191
start = time.time()
179-
# Sentinel: Added timeout to prevent hanging
180192
response = scraper.get(full_url, headers=headers, timeout=30)
181193
end = time.time()
182194
elapsed = end - start
@@ -187,7 +199,7 @@ def handle_proxy(url):
187199

188200
except Exception as e:
189201
print(f"Proxy Request Error: {str(e)}")
190-
# Sentinel: Don't leak stack traces or internal details
202+
# Don't leak stack traces or internal details
191203
return {'error': "Proxy request failed. Check server logs for details."}, 500
192204

193205

tests/test_is_safe_url.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import unittest
2+
from unittest.mock import patch
3+
from server import is_safe_url
4+
import socket
5+
6+
7+
class TestIsSafeUrl(unittest.TestCase):
8+
9+
@patch('server.socket.getaddrinfo')
10+
def test_public_ipv4(self, mock_getaddrinfo):
11+
mock_getaddrinfo.return_value = [
12+
(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('8.8.8.8', 80))
13+
]
14+
is_safe, msg = is_safe_url('http://example.com')
15+
self.assertTrue(is_safe)
16+
self.assertIsNone(msg)
17+
18+
@patch('server.socket.getaddrinfo')
19+
def test_private_ipv4(self, mock_getaddrinfo):
20+
mock_getaddrinfo.return_value = [
21+
(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('127.0.0.1', 80))
22+
]
23+
is_safe, msg = is_safe_url('http://example.com')
24+
self.assertFalse(is_safe)
25+
self.assertIn("Access to private/local network is forbidden", msg)
26+
27+
@patch('server.socket.getaddrinfo')
28+
def test_dual_stack_ipv6_private(self, mock_getaddrinfo):
29+
# Scenario: DNS has public IPv4 (8.8.8.8) and private IPv6 (::1)
30+
mock_getaddrinfo.return_value = [
31+
(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('8.8.8.8', 80)),
32+
(socket.AF_INET6, socket.SOCK_STREAM, 6, '', ('::1', 80, 0, 0))
33+
]
34+
35+
is_safe, msg = is_safe_url('http://example.com')
36+
37+
# We expect this to be False (Blocked) because of the private IPv6 address.
38+
self.assertFalse(is_safe, "Should block if ANY resolved IP is private (IPv6 loopback)")
39+
40+
@patch('server.socket.getaddrinfo')
41+
def test_private_ipv4_192_168(self, mock_getaddrinfo):
42+
mock_getaddrinfo.return_value = [
43+
(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('192.168.1.1', 80))
44+
]
45+
is_safe, msg = is_safe_url('http://example.com')
46+
self.assertFalse(is_safe)
47+
self.assertIn("Access to private/local network is forbidden", msg)
48+
49+
@patch('server.socket.getaddrinfo')
50+
def test_private_ipv4_10_0_0_0(self, mock_getaddrinfo):
51+
mock_getaddrinfo.return_value = [
52+
(socket.AF_INET, socket.SOCK_STREAM, 6, '', ('10.0.0.1', 80))
53+
]
54+
is_safe, msg = is_safe_url('http://example.com')
55+
self.assertFalse(is_safe)
56+
self.assertIn("Access to private/local network is forbidden", msg)
57+
58+
@patch('server.socket.getaddrinfo')
59+
def test_ipv6_loopback(self, mock_getaddrinfo):
60+
mock_getaddrinfo.return_value = [
61+
(socket.AF_INET6, socket.SOCK_STREAM, 6, '', ('::1', 80, 0, 0))
62+
]
63+
is_safe, msg = is_safe_url('http://example.com')
64+
self.assertFalse(is_safe)
65+
self.assertIn("Access to private/local network is forbidden", msg)
66+
67+
@patch('server.socket.getaddrinfo')
68+
def test_ipv6_private(self, mock_getaddrinfo):
69+
mock_getaddrinfo.return_value = [
70+
(socket.AF_INET6, socket.SOCK_STREAM, 6, '', ('fe80::1', 80, 0, 0))
71+
]
72+
is_safe, msg = is_safe_url('http://example.com')
73+
self.assertFalse(is_safe)
74+
self.assertIn("Access to private/local network is forbidden", msg)
75+
76+
def test_invalid_url(self):
77+
is_safe, msg = is_safe_url('not-a-url')
78+
self.assertFalse(is_safe)
79+
self.assertIn("Only HTTP/HTTPS protocols are allowed", msg)
80+
81+
def test_invalid_hostname(self):
82+
is_safe, msg = is_safe_url('http://')
83+
self.assertFalse(is_safe)
84+
self.assertIn("Invalid hostname", msg)
85+
86+
87+
if __name__ == '__main__':
88+
unittest.main()

0 commit comments

Comments
 (0)