|
3 | 3 | import inspect |
4 | 4 | import sys |
5 | 5 | import warnings |
6 | | -from asyncio import CancelledError, Event, create_task |
| 6 | +from asyncio import FIRST_COMPLETED, CancelledError, Event, create_task, wait |
7 | 7 | from collections.abc import Awaitable, Coroutine, Sequence |
8 | 8 | from logging import getLogger |
9 | 9 | from types import FunctionType |
@@ -240,20 +240,36 @@ async def __aexit__(self, exc_type: type[BaseException], *exc: Any) -> Any: |
240 | 240 | # propagate non-cancellation exceptions |
241 | 241 | return None |
242 | 242 |
|
| 243 | + if not self._maybe_uncancel_task(): |
| 244 | + return None |
| 245 | + |
| 246 | + wait_for_stop = create_task(self._stop.wait()) |
243 | 247 | try: |
244 | | - await self._stop.wait() |
| 248 | + await wait([wait_for_stop, self.task], return_when=FIRST_COMPLETED) |
245 | 249 | except CancelledError: |
246 | | - if self.task.cancelling() > self._cancel_count: |
247 | | - # Task has been cancelled by something else - propagate it |
248 | | - return None |
249 | | - self.task.uncancel() |
| 250 | + if not self._maybe_uncancel_task(): |
| 251 | + raise |
250 | 252 |
|
251 | 253 | return True |
252 | 254 |
|
253 | 255 | def _cancel_task(self) -> None: |
254 | 256 | self.task.cancel() |
255 | 257 | self._cancel_count += 1 |
256 | 258 |
|
| 259 | + def _maybe_uncancel_task(self) -> bool: |
| 260 | + """Return if task was uncancelled |
| 261 | +
|
| 262 | + If task was not cancelled by this effect then returns. Otherwise, if the task |
| 263 | + was cancelled at all, uncancell it and return True |
| 264 | + """ |
| 265 | + if self.task.cancelling() > self._cancel_count: |
| 266 | + # Task has been cancelled by something else - propagate it |
| 267 | + return False |
| 268 | + elif self._cancel_count: |
| 269 | + for _ in range(self._cancel_count): |
| 270 | + self.task.uncancel() |
| 271 | + return True |
| 272 | + |
257 | 273 |
|
258 | 274 | def _cast_async_effect(function: Callable[..., Any]) -> _AsyncEffectFunc: |
259 | 275 | if inspect.iscoroutinefunction(function): |
|
0 commit comments