Skip to content

Commit 6bdf1e0

Browse files
committed
compiler: Refine GuardExpr
1 parent cd17375 commit 6bdf1e0

1 file changed

Lines changed: 45 additions & 14 deletions

File tree

devito/ir/support/guards.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -277,11 +277,37 @@ def filter(self, key):
277277

278278
class GuardExpr(LocalObject, BooleanFunction):
279279

280+
"""
281+
A boolean symbol that can be used as a guard. As such, it can be chained
282+
with other relations using the standard boolean operators (&, |, ...).
283+
284+
Being a LocalObject, a GuardExpr may carry an `initvalue`, which is
285+
the value that the guard assumes at the beginning of the scope where
286+
it is defined.
287+
288+
Through the `supersets` argument, a GuardExpr may also carry a set of
289+
GuardExprs that are known to be more restrictive than itself. This is
290+
usesful, e.g., to avoid redundant checks when chaining multiple guards
291+
together (see `simplify_and`).
292+
"""
293+
280294
dtype = np.bool
281295

282-
def __init__(self, name, liveness='eager', **kwargs):
296+
def __init__(self, name, liveness='eager', supersets=None, **kwargs):
283297
super().__init__(name, liveness=liveness, **kwargs)
284298

299+
self.supersets = frozenset(as_tuple(supersets))
300+
301+
def _hashable_content(self):
302+
return super()._hashable_content() + (self.supersets,)
303+
304+
__hash__ = LocalObject.__hash__
305+
306+
def __eq__(self, other):
307+
return (isinstance(other, GuardExpr) and
308+
super().__eq__(other) and
309+
self.supersets == other.supersets)
310+
285311

286312
# *** Utils
287313

@@ -301,25 +327,30 @@ def simplify_and(relation, v):
301327
else:
302328
candidates, other = [], [relation, v]
303329

330+
# Quick check based on GuardExpr.supersets to avoid adding `v` to `relation`
331+
# if `relation` already includes a more restrictive guard than `v`
332+
if isinstance(v, GuardExpr):
333+
if any(a in v.supersets for a in candidates):
334+
return relation
335+
304336
covered = False
305337
new_args = []
306338
for a in candidates:
307339
if isinstance(a, GuardExpr) or a.lhs is not v.lhs:
308340
new_args.append(a)
309-
continue
310-
311-
covered = True
312-
try:
313-
if type(a) in (Gt, Ge) and v.rhs > a.rhs:
314-
new_args.append(v)
315-
elif type(a) in (Lt, Le) and v.rhs < a.rhs:
316-
new_args.append(v)
317-
else:
341+
else:
342+
covered = True
343+
try:
344+
if type(a) in (Gt, Ge) and v.rhs > a.rhs:
345+
new_args.append(v)
346+
elif type(a) in (Lt, Le) and v.rhs < a.rhs:
347+
new_args.append(v)
348+
else:
349+
new_args.append(a)
350+
except TypeError:
351+
# E.g., `v.rhs = const + z_M` and `a.rhs = z_M`, so the inequalities
352+
# above are not evaluable to True/False
318353
new_args.append(a)
319-
except TypeError:
320-
# E.g., `v.rhs = const + z_M` and `a.rhs = z_M`, so the inequalities
321-
# above are not evaluable to True/False
322-
new_args.append(a)
323354

324355
if not covered:
325356
new_args.append(v)

0 commit comments

Comments
 (0)