Skip to content

Commit c308329

Browse files
authored
Merge pull request openwallet-foundation#1808 from andrewwhitehead/fix/put-redirect
Fix put_file when the server returns a redirect
2 parents 592dfd0 + d5ac9dd commit c308329

2 files changed

Lines changed: 115 additions & 26 deletions

File tree

aries_cloudagent/utils/http.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
"""HTTP utility methods."""
22

33
import asyncio
4-
5-
from aiohttp import BaseConnector, ClientError, ClientResponse, ClientSession
4+
import logging
5+
import urllib.parse
6+
7+
from aiohttp import (
8+
BaseConnector,
9+
ClientError,
10+
ClientResponse,
11+
ClientSession,
12+
FormData,
13+
)
614
from aiohttp.web import HTTPConflict
715

816
from ..core.error import BaseError
917

1018
from .repeat import RepeatSequence
1119

1220

21+
LOGGER = logging.getLogger(__name__)
22+
23+
1324
class FetchError(BaseError):
1425
"""Error raised when an HTTP fetch fails."""
1526

@@ -147,7 +158,6 @@ async def put_file(
147158
148159
"""
149160
(data_key, file_path) = [k for k in file_data.items()][0]
150-
data = {**extra_data}
151161
limit = max_attempts if retry else 1
152162

153163
if not session:
@@ -158,17 +168,51 @@ async def put_file(
158168
async for attempt in RepeatSequence(limit, interval, backoff):
159169
try:
160170
async with attempt.timeout(request_timeout):
161-
with open(file_path, "rb") as f:
162-
data[data_key] = f
163-
response: ClientResponse = await session.put(url, data=data)
164-
if (response.status < 200 or response.status >= 300) and (
165-
response.status != HTTPConflict.status_code
166-
):
167-
raise ClientError(
168-
f"Bad response from server: {response.status}, "
169-
f"{response.reason}"
170-
)
171+
formdata = FormData()
172+
try:
173+
fp = open(file_path, "rb")
174+
except OSError as e:
175+
raise PutError("Error opening file for upload") from e
176+
if extra_data:
177+
for k, v in extra_data.items():
178+
formdata.add_field(k, v)
179+
formdata.add_field(
180+
data_key, fp, content_type="application/octet-stream"
181+
)
182+
response: ClientResponse = await session.put(
183+
url, data=formdata, allow_redirects=False
184+
)
185+
if (
186+
# redirect codes
187+
response.status in (301, 302, 303, 307, 308)
188+
and not attempt.final
189+
):
190+
# NOTE: a redirect counts as another upload attempt
191+
to_url = response.headers.get("Location")
192+
if not to_url:
193+
raise PutError("Redirect missing target URL")
194+
try:
195+
parsed_to = urllib.parse.urlsplit(to_url)
196+
parsed_from = urllib.parse.urlsplit(url)
197+
except ValueError:
198+
raise PutError("Invalid redirect URL")
199+
if parsed_to.hostname != parsed_from.hostname:
200+
raise PutError("Redirect denied: hostname mismatch")
201+
url = to_url
202+
LOGGER.info("Upload redirect: %s", to_url)
203+
elif (response.status < 200 or response.status >= 300) and (
204+
response.status != HTTPConflict.status_code
205+
):
206+
raise ClientError(
207+
f"Bad response from server: {response.status}, "
208+
f"{response.reason}"
209+
)
210+
else:
171211
return await (response.json() if json else response.text())
172212
except (ClientError, asyncio.TimeoutError) as e:
213+
if isinstance(e, ClientError):
214+
LOGGER.warning("Upload error: %s", e)
215+
else:
216+
LOGGER.warning("Upload error: request timed out")
173217
if attempt.final:
174-
raise PutError("Exceeded maximum put attempts") from e
218+
raise PutError("Exceeded maximum upload attempts") from e

aries_cloudagent/utils/tests/test_http.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,33 @@
1+
import os
2+
import tempfile
3+
14
from aiohttp import web
2-
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
3-
from asynctest import mock as async_mock, mock_open
5+
from aiohttp.test_utils import AioHTTPTestCase
46

57
from ..http import fetch, fetch_stream, FetchError, put_file, PutError
68

79

10+
class TempFile:
11+
def __init__(self):
12+
self.name = None
13+
14+
def __enter__(self):
15+
file = tempfile.NamedTemporaryFile(delete=False)
16+
file.write(b"test")
17+
file.close()
18+
self.name = file.name
19+
return self.name
20+
21+
def __exit__(self, *args):
22+
if self.name:
23+
os.unlink(self.name)
24+
25+
826
class TestTransportUtils(AioHTTPTestCase):
927
async def setUpAsync(self):
1028
self.fail_calls = 0
1129
self.succeed_calls = 0
30+
self.redirects = 0
1231
await super().setUpAsync()
1332

1433
async def get_application(self):
@@ -19,19 +38,30 @@ async def get_application(self):
1938
web.get("/succeed", self.succeed_route),
2039
web.put("/fail", self.fail_route),
2140
web.put("/succeed", self.succeed_route),
41+
web.put("/redirect", self.redirect_route),
2242
]
2343
)
2444
return app
2545

2646
async def fail_route(self, request):
2747
self.fail_calls += 1
48+
# avoid aiohttp test server issue: https://github.com/aio-libs/aiohttp/issues/3968
49+
await request.read()
2850
raise web.HTTPForbidden()
2951

3052
async def succeed_route(self, request):
3153
self.succeed_calls += 1
3254
ret = web.json_response([True])
3355
return ret
3456

57+
async def redirect_route(self, request):
58+
if self.redirects > 0:
59+
self.redirects -= 1
60+
# avoid aiohttp test server issue: https://github.com/aio-libs/aiohttp/issues/3968
61+
await request.read()
62+
raise web.HTTPRedirection(f"http://localhost:{self.server.port}/success")
63+
return await self.succeed_route(request)
64+
3565
async def test_fetch_stream(self):
3666
server_addr = f"http://localhost:{self.server.port}"
3767
stream = await fetch_stream(
@@ -84,40 +114,55 @@ async def test_fetch_fail(self):
84114
)
85115
assert self.fail_calls == 2
86116

87-
async def test_put_file(self):
117+
async def test_put_file_with_session(self):
88118
server_addr = f"http://localhost:{self.server.port}"
89-
with async_mock.patch("builtins.open", mock_open(read_data="data")):
119+
with TempFile() as tails:
90120
result = await put_file(
91121
f"{server_addr}/succeed",
92-
{"tails": "/tmp/dummy/path"},
122+
{"tails": tails},
93123
{"genesis": "..."},
94124
session=self.client.session,
95125
json=True,
96126
)
97-
assert result == [1]
127+
assert result == [True]
98128
assert self.succeed_calls == 1
99129

100130
async def test_put_file_default_client(self):
101131
server_addr = f"http://localhost:{self.server.port}"
102-
with async_mock.patch("builtins.open", mock_open(read_data="data")):
132+
with TempFile() as tails:
103133
result = await put_file(
104134
f"{server_addr}/succeed",
105-
{"tails": "/tmp/dummy/path"},
135+
{"tails": tails},
106136
{"genesis": "..."},
107137
json=True,
108138
)
109-
assert result == [1]
139+
assert result == [True]
110140
assert self.succeed_calls == 1
111141

112142
async def test_put_file_fail(self):
113143
server_addr = f"http://localhost:{self.server.port}"
114-
with async_mock.patch("builtins.open", mock_open(read_data="data")):
144+
with TempFile() as tails:
115145
with self.assertRaises(PutError):
116-
result = await put_file(
146+
_ = await put_file(
117147
f"{server_addr}/fail",
118-
{"tails": "/tmp/dummy/path"},
148+
{"tails": tails},
119149
{"genesis": "..."},
120150
max_attempts=2,
121151
json=True,
122152
)
123153
assert self.fail_calls == 2
154+
155+
async def test_put_file_redirect(self):
156+
server_addr = f"http://localhost:{self.server.port}"
157+
self.redirects = 1
158+
with TempFile() as tails:
159+
result = await put_file(
160+
f"{server_addr}/redirect",
161+
{"tails": tails},
162+
{"genesis": "..."},
163+
max_attempts=2,
164+
json=True,
165+
)
166+
assert result == [True]
167+
assert self.succeed_calls == 1
168+
assert self.redirects == 0

0 commit comments

Comments
 (0)