|
6 | 6 | import threading |
7 | 7 | import contextlib |
8 | 8 | import time |
| 9 | +import inspect |
9 | 10 |
|
10 | | -from async_generator import ( |
11 | | - async_generator, |
12 | | - yield_, |
13 | | - isasyncgenfunction, |
14 | | - asynccontextmanager, |
15 | | -) |
| 11 | +try: |
| 12 | + from async_generator import yield_, async_generator |
| 13 | +except ImportError: # pragma: no cover |
| 14 | + async_generator = yield_ = None |
16 | 15 |
|
17 | 16 | from ... import _core |
18 | 17 | from ...testing import wait_all_tasks_blocked |
@@ -142,7 +141,8 @@ def protected_manager(): |
142 | 141 | raise KeyError |
143 | 142 |
|
144 | 143 |
|
145 | | -async def test_agen_protection(): |
| 144 | +@pytest.mark.skipif(async_generator is None, reason="async_generator not installed") |
| 145 | +async def test_async_generator_agen_protection(): |
146 | 146 | @_core.enable_ki_protection |
147 | 147 | @async_generator |
148 | 148 | async def agen_protected1(): |
@@ -180,45 +180,49 @@ async def agen_unprotected2(): |
180 | 180 | finally: |
181 | 181 | assert not _core.currently_ki_protected() |
182 | 182 |
|
| 183 | + await _check_agen(agen_protected1) |
| 184 | + await _check_agen(agen_protected2) |
| 185 | + await _check_agen(agen_unprotected1) |
| 186 | + await _check_agen(agen_unprotected2) |
| 187 | + |
| 188 | + |
| 189 | +async def test_native_agen_protection(): |
183 | 190 | # Native async generators |
184 | 191 | @_core.enable_ki_protection |
185 | | - async def agen_protected3(): |
| 192 | + async def agen_protected(): |
186 | 193 | assert _core.currently_ki_protected() |
187 | 194 | try: |
188 | 195 | yield |
189 | 196 | finally: |
190 | 197 | assert _core.currently_ki_protected() |
191 | 198 |
|
192 | 199 | @_core.disable_ki_protection |
193 | | - async def agen_unprotected3(): |
| 200 | + async def agen_unprotected(): |
194 | 201 | assert not _core.currently_ki_protected() |
195 | 202 | try: |
196 | 203 | yield |
197 | 204 | finally: |
198 | 205 | assert not _core.currently_ki_protected() |
199 | 206 |
|
200 | | - for agen_fn in [ |
201 | | - agen_protected1, |
202 | | - agen_protected2, |
203 | | - agen_protected3, |
204 | | - agen_unprotected1, |
205 | | - agen_unprotected2, |
206 | | - agen_unprotected3, |
207 | | - ]: |
208 | | - async for _ in agen_fn(): # noqa |
| 207 | + await _check_agen(agen_protected) |
| 208 | + await _check_agen(agen_unprotected) |
| 209 | + |
| 210 | + |
| 211 | +async def _check_agen(agen_fn): |
| 212 | + async for _ in agen_fn(): # noqa |
| 213 | + assert not _core.currently_ki_protected() |
| 214 | + |
| 215 | + # asynccontextmanager insists that the function passed must itself be an |
| 216 | + # async gen function, not a wrapper around one |
| 217 | + if inspect.isasyncgenfunction(agen_fn): |
| 218 | + async with contextlib.asynccontextmanager(agen_fn)(): |
209 | 219 | assert not _core.currently_ki_protected() |
210 | 220 |
|
211 | | - # asynccontextmanager insists that the function passed must itself be an |
212 | | - # async gen function, not a wrapper around one |
213 | | - if isasyncgenfunction(agen_fn): |
214 | | - async with asynccontextmanager(agen_fn)(): |
215 | | - assert not _core.currently_ki_protected() |
216 | | - |
217 | | - # Another case that's tricky due to: |
218 | | - # https://bugs.python.org/issue29590 |
219 | | - with pytest.raises(KeyError): |
220 | | - async with asynccontextmanager(agen_fn)(): |
221 | | - raise KeyError |
| 221 | + # Another case that's tricky due to: |
| 222 | + # https://bugs.python.org/issue29590 |
| 223 | + with pytest.raises(KeyError): |
| 224 | + async with contextlib.asynccontextmanager(agen_fn)(): |
| 225 | + raise KeyError |
222 | 226 |
|
223 | 227 |
|
224 | 228 | # Test the case where there's no magic local anywhere in the call stack |
|
0 commit comments