|
8 | 8 | from operator import ge, gt, le, lt |
9 | 9 |
|
10 | 10 | from functools import singledispatch |
11 | | -from sympy import And, Ge, Gt, Le, Lt, Mul, true |
| 11 | +from sympy import And, Expr, Ge, Gt, Le, Lt, Mul, true |
12 | 12 | from sympy.logic.boolalg import BooleanFunction |
13 | 13 | import numpy as np |
14 | 14 |
|
15 | 15 | from devito.ir.support.space import Forward, IterationDirection |
16 | | -from devito.symbolics import CondEq, CondNe |
| 16 | +from devito.symbolics import CondEq, CondNe, search |
17 | 17 | from devito.tools import Pickable, as_tuple, frozendict, split |
18 | 18 | from devito.types import Dimension, LocalObject |
19 | 19 |
|
20 | 20 | __all__ = ['GuardFactor', 'GuardBound', 'GuardBoundNext', 'BaseGuardBound', |
21 | | - 'BaseGuardBoundNext', 'GuardOverflow', 'Guards', 'GuardExpr'] |
| 21 | + 'BaseGuardBoundNext', 'GuardOverflow', 'Guards', 'GuardExpr', |
| 22 | + 'GuardSwitch', 'GuardCaseSwitch'] |
22 | 23 |
|
23 | 24 |
|
24 | | -class Guard: |
| 25 | +class AbstractGuard: |
| 26 | + pass |
| 27 | + |
| 28 | + |
| 29 | +class Guard(AbstractGuard): |
25 | 30 |
|
26 | 31 | @property |
27 | 32 | def _args_rebuild(self): |
@@ -217,6 +222,35 @@ class GuardOverflowLt(BaseGuardOverflow, Lt): |
217 | 222 | } |
218 | 223 |
|
219 | 224 |
|
| 225 | +class GuardSwitch(AbstractGuard, Expr): |
| 226 | + |
| 227 | + """ |
| 228 | + A switch guard (akin to C's switch-case) that can be used to select |
| 229 | + between multiple cases at runtime. |
| 230 | + """ |
| 231 | + |
| 232 | + def __new__(cls, arg, **kwargs): |
| 233 | + return Expr.__new__(cls, arg) |
| 234 | + |
| 235 | + @property |
| 236 | + def arg(self): |
| 237 | + return self.args[0] |
| 238 | + |
| 239 | + |
| 240 | +class GuardCaseSwitch(GuardSwitch): |
| 241 | + |
| 242 | + """ |
| 243 | + A case within a GuardSwitch. |
| 244 | + """ |
| 245 | + |
| 246 | + def __new__(cls, arg, case, **kwargs): |
| 247 | + return Expr.__new__(cls, arg, case) |
| 248 | + |
| 249 | + @property |
| 250 | + def case(self): |
| 251 | + return self.args[1] |
| 252 | + |
| 253 | + |
220 | 254 | class Guards(frozendict): |
221 | 255 |
|
222 | 256 | """ |
|
0 commit comments