|
7 | 7 | import json |
8 | 8 | import os |
9 | 9 | import sys |
| 10 | +import time |
10 | 11 |
|
11 | 12 | import aiohttp |
12 | 13 |
|
13 | 14 | from .cache.default import DefaultCache |
14 | 15 | from .details import Details |
15 | | -from .exceptions import RequestQuotaExceededError |
| 16 | +from .exceptions import RequestQuotaExceededError, TimeoutExceededError |
16 | 17 | from .handler_utils import ( |
17 | 18 | API_URL, |
18 | 19 | COUNTRY_FILE_DEFAULT, |
@@ -197,49 +198,80 @@ async def getBatchDetails( |
197 | 198 | url = API_URL + "/batch" |
198 | 199 | headers = handler_utils.get_headers(self.access_token) |
199 | 200 | headers["content-type"] = "application/json" |
200 | | - reqs = [] |
201 | | - for i in range(0, len(lookup_addresses), batch_size): |
202 | | - chunk = lookup_addresses[i : i + batch_size] |
203 | | - |
204 | | - # do http req |
205 | | - reqs.append( |
206 | | - self.httpsess.post( |
207 | | - url, |
208 | | - data=json.dumps(chunk), |
209 | | - headers=headers, |
210 | | - timeout=timeout_per_batch, |
211 | | - ) |
| 201 | + |
| 202 | + # prepare coroutines that will make reqs and update results. |
| 203 | + reqs = [ |
| 204 | + self._do_batch_req( |
| 205 | + lookup_addresses[i : i + batch_size], |
| 206 | + url, |
| 207 | + headers, |
| 208 | + timeout_per_batch, |
| 209 | + raise_on_fail, |
| 210 | + result, |
| 211 | + ) |
| 212 | + for i in range(0, len(lookup_addresses), batch_size) |
| 213 | + ] |
| 214 | + |
| 215 | + try: |
| 216 | + _, pending = await asyncio.wait( |
| 217 | + {*reqs}, |
| 218 | + timeout=timeout_total, |
| 219 | + return_when=asyncio.FIRST_EXCEPTION, |
212 | 220 | ) |
213 | 221 |
|
214 | | - resps = await asyncio.wait_for( |
215 | | - asyncio.gather(*reqs, return_exceptions=raise_on_fail), |
216 | | - timeout_total |
217 | | - ) |
218 | | - for resp in resps: |
219 | | - # gather data |
220 | | - try: |
221 | | - if resp.status == 429: |
222 | | - raise RequestQuotaExceededError() |
223 | | - resp.raise_for_status() |
224 | | - except Exception as e: |
225 | | - if raise_on_fail: |
226 | | - raise e |
227 | | - else: |
228 | | - return result |
229 | | - |
230 | | - json_resp = await resp.json() |
231 | | - |
232 | | - # format & fill up cache |
233 | | - for ip_address, details in json_resp.items(): |
234 | | - if isinstance(details, dict): |
235 | | - handler_utils.format_details(details, self.countries) |
236 | | - self.cache[ip_address] = details |
237 | | - |
238 | | - # merge cached results with new lookup |
239 | | - result.update(json_resp) |
| 222 | + # if all done, return result. |
| 223 | + if len(pending) == 0: |
| 224 | + return result |
| 225 | + |
| 226 | + # if some had a timeout, first cancel timed out stuff and wait for |
| 227 | + # cleanup. then exit with return_or_fail. |
| 228 | + for co in pending: |
| 229 | + try: |
| 230 | + co.cancel() |
| 231 | + await co |
| 232 | + except asyncio.CancelledError: |
| 233 | + pass |
| 234 | + |
| 235 | + return handler_utils.return_or_fail( |
| 236 | + raise_on_fail, TimeoutExceededError(), result |
| 237 | + ) |
| 238 | + except Exception as e: |
| 239 | + return handler_utils.return_or_fail(raise_on_fail, e, result) |
240 | 240 |
|
241 | 241 | return result |
242 | 242 |
|
| 243 | + async def _do_batch_req( |
| 244 | + self, chunk, url, headers, timeout_per_batch, raise_on_fail, result |
| 245 | + ): |
| 246 | + """ |
| 247 | + Coroutine which will do the actual POST request for getBatchDetails. |
| 248 | + """ |
| 249 | + resp = await self.httpsess.post( |
| 250 | + url, |
| 251 | + data=json.dumps(chunk), |
| 252 | + headers=headers, |
| 253 | + timeout=timeout_per_batch, |
| 254 | + ) |
| 255 | + |
| 256 | + # gather data |
| 257 | + try: |
| 258 | + if resp.status == 429: |
| 259 | + raise RequestQuotaExceededError() |
| 260 | + resp.raise_for_status() |
| 261 | + except Exception as e: |
| 262 | + return handler_utils.return_or_fail(raise_on_fail, e, None) |
| 263 | + |
| 264 | + json_resp = await resp.json() |
| 265 | + |
| 266 | + # format & fill up cache |
| 267 | + for ip_address, details in json_resp.items(): |
| 268 | + if isinstance(details, dict): |
| 269 | + handler_utils.format_details(details, self.countries) |
| 270 | + self.cache[ip_address] = details |
| 271 | + |
| 272 | + # merge cached results with new lookup |
| 273 | + result.update(json_resp) |
| 274 | + |
243 | 275 | def _ensure_aiohttp_ready(self): |
244 | 276 | """Ensures aiohttp internal state is initialized.""" |
245 | 277 | if self.httpsess: |
|
0 commit comments