|
15 | 15 | from devito.mpi.reduction_scheme import DistReduce |
16 | 16 | from devito.symbolics import estimate_cost |
17 | 17 | from devito.tools import as_tuple, filter_ordered, flatten, infer_dtype |
18 | | -from devito.types import Fence, WeakFence, CriticalRegion |
| 18 | +from devito.types import ( |
| 19 | + Fence, WeakFence, CriticalRegion, ThreadPoolSync, ThreadCommit, ThreadWait |
| 20 | +) |
19 | 21 |
|
20 | 22 | __all__ = ["Cluster", "ClusterGroup"] |
21 | 23 |
|
@@ -262,26 +264,40 @@ def is_wild(self): |
262 | 264 | self.is_weak_fence or |
263 | 265 | self.is_critical_region) |
264 | 266 |
|
| 267 | + def _is_type(self, cls): |
| 268 | + return self.exprs and all(isinstance(e.rhs, cls) for e in self.exprs) |
| 269 | + |
265 | 270 | @cached_property |
266 | 271 | def is_halo_touch(self): |
267 | | - return self.exprs and all(isinstance(e.rhs, HaloTouch) for e in self.exprs) |
| 272 | + return self._is_type(HaloTouch) |
268 | 273 |
|
269 | 274 | @cached_property |
270 | 275 | def is_dist_reduce(self): |
271 | | - return self.exprs and all(isinstance(e.rhs, DistReduce) for e in self.exprs) |
| 276 | + return self._is_type(DistReduce) |
272 | 277 |
|
273 | 278 | @cached_property |
274 | 279 | def is_fence(self): |
275 | | - return (self.exprs and all(isinstance(e.rhs, Fence) for e in self.exprs) or |
276 | | - self.is_critical_region) |
| 280 | + return self._is_type(Fence) or self.is_critical_region |
277 | 281 |
|
278 | 282 | @cached_property |
279 | 283 | def is_weak_fence(self): |
280 | | - return self.exprs and all(isinstance(e.rhs, WeakFence) for e in self.exprs) |
| 284 | + return self._is_type(WeakFence) |
281 | 285 |
|
282 | 286 | @cached_property |
283 | 287 | def is_critical_region(self): |
284 | | - return self.exprs and all(isinstance(e.rhs, CriticalRegion) for e in self.exprs) |
| 288 | + return self._is_type(CriticalRegion) |
| 289 | + |
| 290 | + @cached_property |
| 291 | + def is_thread_pool_sync(self): |
| 292 | + return self._is_type(ThreadPoolSync) |
| 293 | + |
| 294 | + @cached_property |
| 295 | + def is_thread_commit(self): |
| 296 | + return self._is_type(ThreadCommit) |
| 297 | + |
| 298 | + @cached_property |
| 299 | + def is_thread_wait(self): |
| 300 | + return self._is_type(ThreadWait) |
285 | 301 |
|
286 | 302 | @cached_property |
287 | 303 | def is_async(self): |
|
0 commit comments