|
16 | 16 | from __future__ import annotations |
17 | 17 |
|
18 | 18 | import asyncio |
| 19 | +import builtins |
| 20 | +import functools |
| 21 | +import random |
19 | 22 | import socket |
| 23 | +import time |
| 24 | +import time as time # noqa: PLC0414 # needed in sync version |
20 | 25 | from typing import ( |
21 | 26 | Any, |
22 | 27 | Callable, |
23 | 28 | TypeVar, |
24 | 29 | cast, |
25 | 30 | ) |
26 | 31 |
|
| 32 | +from pymongo import _csot |
27 | 33 | from pymongo.errors import ( |
28 | 34 | OperationFailure, |
| 35 | + PyMongoError, |
29 | 36 | ) |
30 | 37 | from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE |
| 38 | +from pymongo.lock import _async_create_lock |
31 | 39 |
|
32 | 40 | _IS_SYNC = False |
33 | 41 |
|
|
36 | 44 |
|
37 | 45 |
|
38 | 46 | def _handle_reauth(func: F) -> F: |
| 47 | + @functools.wraps(func) |
39 | 48 | async def inner(*args: Any, **kwargs: Any) -> Any: |
40 | 49 | no_reauth = kwargs.pop("no_reauth", False) |
41 | 50 | from pymongo.asynchronous.pool import AsyncConnection |
@@ -68,6 +77,123 @@ async def inner(*args: Any, **kwargs: Any) -> Any: |
68 | 77 | return cast(F, inner) |
69 | 78 |
|
70 | 79 |
|
| 80 | +_MAX_RETRIES = 3 |
| 81 | +_BACKOFF_INITIAL = 0.05 |
| 82 | +_BACKOFF_MAX = 10 |
| 83 | +# DRIVERS-3240 will determine these defaults. |
| 84 | +DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0 |
| 85 | +DEFAULT_RETRY_TOKEN_RETURN = 0.1 |
| 86 | + |
| 87 | + |
| 88 | +def _backoff( |
| 89 | + attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX |
| 90 | +) -> float: |
| 91 | + jitter = random.random() # noqa: S311 |
| 92 | + return jitter * min(initial_delay * (2**attempt), max_delay) |
| 93 | + |
| 94 | + |
| 95 | +class _TokenBucket: |
| 96 | + """A token bucket implementation for rate limiting.""" |
| 97 | + |
| 98 | + def __init__( |
| 99 | + self, |
| 100 | + capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY, |
| 101 | + return_rate: float = DEFAULT_RETRY_TOKEN_RETURN, |
| 102 | + ): |
| 103 | + self.lock = _async_create_lock() |
| 104 | + self.capacity = capacity |
| 105 | + # DRIVERS-3240 will determine how full the bucket should start. |
| 106 | + self.tokens = capacity |
| 107 | + self.return_rate = return_rate |
| 108 | + |
| 109 | + async def consume(self) -> bool: |
| 110 | + """Consume a token from the bucket if available.""" |
| 111 | + async with self.lock: |
| 112 | + if self.tokens >= 1: |
| 113 | + self.tokens -= 1 |
| 114 | + return True |
| 115 | + return False |
| 116 | + |
| 117 | + async def deposit(self, retry: bool = False) -> None: |
| 118 | + """Deposit a token back into the bucket.""" |
| 119 | + retry_token = 1 if retry else 0 |
| 120 | + async with self.lock: |
| 121 | + self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate) |
| 122 | + |
| 123 | + |
| 124 | +class _RetryPolicy: |
| 125 | + """A retry limiter that performs exponential backoff with jitter. |
| 126 | +
|
| 127 | + Retry attempts are limited by a token bucket to prevent overwhelming the server during |
| 128 | + a prolonged outage or high load. |
| 129 | + """ |
| 130 | + |
| 131 | + def __init__( |
| 132 | + self, |
| 133 | + token_bucket: _TokenBucket, |
| 134 | + attempts: int = _MAX_RETRIES, |
| 135 | + backoff_initial: float = _BACKOFF_INITIAL, |
| 136 | + backoff_max: float = _BACKOFF_MAX, |
| 137 | + ): |
| 138 | + self.token_bucket = token_bucket |
| 139 | + self.attempts = attempts |
| 140 | + self.backoff_initial = backoff_initial |
| 141 | + self.backoff_max = backoff_max |
| 142 | + |
| 143 | + async def record_success(self, retry: bool) -> None: |
| 144 | + """Record a successful operation.""" |
| 145 | + await self.token_bucket.deposit(retry) |
| 146 | + |
| 147 | + def backoff(self, attempt: int) -> float: |
| 148 | + """Return the backoff duration for the given .""" |
| 149 | + return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) |
| 150 | + |
| 151 | + async def should_retry(self, attempt: int, delay: float) -> bool: |
| 152 | + """Return if we have budget to retry and how long to backoff.""" |
| 153 | + if attempt > self.attempts: |
| 154 | + return False |
| 155 | + |
| 156 | + # If the delay would exceed the deadline, bail early before consuming a token. |
| 157 | + if _csot.get_timeout(): |
| 158 | + if time.monotonic() + delay > _csot.get_deadline(): |
| 159 | + return False |
| 160 | + |
| 161 | + # Check token bucket last since we only want to consume a token if we actually retry. |
| 162 | + if not await self.token_bucket.consume(): |
| 163 | + # DRIVERS-3246 Improve diagnostics when this case happens. |
| 164 | + # We could add info to the exception and log. |
| 165 | + return False |
| 166 | + return True |
| 167 | + |
| 168 | + |
| 169 | +def _retry_overload(func: F) -> F: |
| 170 | + @functools.wraps(func) |
| 171 | + async def inner(self: Any, *args: Any, **kwargs: Any) -> Any: |
| 172 | + retry_policy = self._retry_policy |
| 173 | + attempt = 0 |
| 174 | + while True: |
| 175 | + try: |
| 176 | + res = await func(self, *args, **kwargs) |
| 177 | + await retry_policy.record_success(retry=attempt > 0) |
| 178 | + return res |
| 179 | + except PyMongoError as exc: |
| 180 | + if not exc.has_error_label("RetryableError"): |
| 181 | + raise |
| 182 | + attempt += 1 |
| 183 | + delay = 0 |
| 184 | + if exc.has_error_label("SystemOverloadedError"): |
| 185 | + delay = retry_policy.backoff(attempt) |
| 186 | + if not await retry_policy.should_retry(attempt, delay): |
| 187 | + raise |
| 188 | + |
| 189 | + # Implement exponential backoff on retry. |
| 190 | + if delay: |
| 191 | + await asyncio.sleep(delay) |
| 192 | + continue |
| 193 | + |
| 194 | + return cast(F, inner) |
| 195 | + |
| 196 | + |
71 | 197 | async def _getaddrinfo( |
72 | 198 | host: Any, port: Any, **kwargs: Any |
73 | 199 | ) -> list[ |
|
0 commit comments