Skip to content

Commit ef8f581

Browse files
authored
Merge pull request #2742 from devitocodes/improve-guard-expr
compiler: Implement auto-simplification of GuardExpr
2 parents 06a4937 + 0fb2dcc commit ef8f581

2 files changed

Lines changed: 199 additions & 22 deletions

File tree

devito/ir/support/guards.py

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
(e.g., Eq, Cluster, ...) should be evaluated at runtime.
55
"""
66

7+
from operator import ge, gt, le, lt
8+
9+
from functools import singledispatch
710
from sympy import And, Ge, Gt, Le, Lt, Mul, true
811
from sympy.logic.boolalg import BooleanFunction
912
import numpy as np
@@ -284,33 +287,54 @@ class GuardExpr(LocalObject, BooleanFunction):
284287
Being a LocalObject, a GuardExpr may carry an `initvalue`, which is
285288
the value that the guard assumes at the beginning of the scope where
286289
it is defined.
287-
288-
Through the `supersets` argument, a GuardExpr may also carry a set of
289-
GuardExprs that are known to be more restrictive than itself. This is
290-
usesful, e.g., to avoid redundant checks when chaining multiple guards
291-
together (see `simplify_and`).
292290
"""
293291

294292
dtype = np.bool
295293

296-
def __init__(self, name, liveness='eager', supersets=None, **kwargs):
294+
def __init__(self, name, liveness='eager', **kwargs):
297295
super().__init__(name, liveness=liveness, **kwargs)
298296

299-
self.supersets = frozenset(as_tuple(supersets))
297+
@singledispatch
298+
def _handle_boolean(obj, mapper):
299+
raise NotImplementedError(f"Cannot handle boolean of type {type(obj)}")
300300

301-
def _hashable_content(self):
302-
return super()._hashable_content() + (self.supersets,)
301+
@_handle_boolean.register(And)
302+
def _(obj, mapper):
303+
for a in obj.args:
304+
GuardExpr._handle_boolean(a, mapper)
303305

304-
__hash__ = LocalObject.__hash__
306+
@_handle_boolean.register(Le)
307+
@_handle_boolean.register(Ge)
308+
@_handle_boolean.register(Lt)
309+
@_handle_boolean.register(Gt)
310+
def _(obj, mapper):
311+
d, v = obj.args
312+
k = obj.__class__
313+
mapper.setdefault(k, {})[d] = v
314+
315+
@property
316+
def as_mapper(self):
317+
mapper = {}
318+
GuardExpr._handle_boolean(self.initvalue, mapper)
319+
return frozendict(mapper)
305320

306-
def __eq__(self, other):
307-
return (isinstance(other, GuardExpr) and
308-
super().__eq__(other) and
309-
self.supersets == other.supersets)
321+
def sort_key(self, order=None):
322+
# Use the overarching LocalObject name for arguments ordering
323+
class_key, args, exp, coeff = super().sort_key(order=order)
324+
args = (len(args[1]) + 1, (self.name,) + args[1])
325+
return class_key, args, exp, coeff
310326

311327

312328
# *** Utils
313329

330+
op_mapper = {
331+
Le: le,
332+
Lt: lt,
333+
Ge: ge,
334+
Gt: gt
335+
}
336+
337+
314338
def simplify_and(relation, v):
315339
"""
316340
Given `x = And(*relation.args, v)`, return `relation` if `x ≡ relation`,
@@ -327,19 +351,44 @@ def simplify_and(relation, v):
327351
else:
328352
candidates, other = [], [relation, v]
329353

330-
# Quick check based on GuardExpr.supersets to avoid adding `v` to `relation`
331-
# if `relation` already includes a more restrictive guard than `v`
332-
if isinstance(v, GuardExpr):
333-
if any(a in v.supersets for a in candidates):
334-
return relation
335-
336354
covered = False
337355
new_args = []
338356
for a in candidates:
339-
if isinstance(a, GuardExpr) or a.lhs is not v.lhs:
357+
if isinstance(v, GuardExpr) and isinstance(a, GuardExpr):
358+
# Attempt optimizing guards in GuardExpr form
359+
covered = True
360+
361+
m0 = v.as_mapper
362+
m1 = a.as_mapper
363+
364+
for cls, op in op_mapper.items():
365+
if cls in m0 and cls in m1:
366+
try:
367+
if set(m0[cls]) != set(m1[cls]):
368+
new_args.extend([a, v])
369+
elif all(op(m0[cls][d], m1[cls][d]) for d in m0[cls]):
370+
new_args.append(v)
371+
elif all(op(m1[cls][d], m0[cls][d]) for d in m1[cls]):
372+
new_args.append(a)
373+
else:
374+
new_args.extend([a, v])
375+
except TypeError:
376+
# E.g., `cls = Le`, then `z <= 2` and `z <= z_M + 1`
377+
new_args.extend([a, v])
378+
379+
elif cls in m0:
380+
new_args.append(v)
381+
382+
elif cls in m1:
383+
new_args.append(a)
384+
385+
elif a.lhs is not v.lhs:
340386
new_args.append(a)
387+
341388
else:
389+
# Attempt optimizing guards in relational form
342390
covered = True
391+
343392
try:
344393
if type(a) in (Gt, Ge) and v.rhs > a.rhs:
345394
new_args.append(v)

tests/test_symbolics.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import pytest
55
import numpy as np
66

7-
from sympy import Expr, Number, Symbol
7+
from sympy import And, Expr, Number, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
1010
Min, Max, Real, Imag, Conj, SubDomain, configuration)
1111
from devito.finite_differences.differentiable import SafeInv, Weights, Mul
1212
from devito.ir import Expression, FindNodes, ccode
13+
from devito.ir.support.guards import GuardExpr, simplify_and
1314
from devito.mpi.halo_scheme import HaloTouch
1415
from devito.symbolics import (
1516
retrieve_functions, retrieve_indexed, evalrel, CallFromPointer, Cast, # noqa
@@ -543,6 +544,133 @@ def test_halo_touch():
543544
assert hash(ht0) == hash(ht0._rebuild())
544545

545546

547+
def test_guard_expr_Le():
548+
grid = Grid(shape=(3, 3, 3))
549+
_, y, z = grid.dimensions
550+
551+
y_M = y.symbolic_max
552+
z_M = z.symbolic_max
553+
554+
g0 = GuardExpr('g0', initvalue=And(y <= y_M + 3, z <= z_M + 3))
555+
g1 = GuardExpr('g1', initvalue=And(y <= y_M + 3, z <= z_M + 4))
556+
557+
v0 = simplify_and(g0, g1)
558+
v1 = simplify_and(g1, g0)
559+
assert v0 is g0
560+
assert v0 is v1
561+
562+
v2 = simplify_and(And(g0, g1), g0)
563+
assert v2 is g0
564+
565+
g3 = GuardExpr('g3', initvalue=And(y <= y_M + 2, z <= z_M + 3))
566+
v3 = simplify_and(g0, g3)
567+
v4 = simplify_and(g3, g0)
568+
assert v3 is g3
569+
assert v4 is v3
570+
571+
g4 = GuardExpr('g4', initvalue=And(y <= y_M + 4, z <= z_M + 2))
572+
v5 = simplify_and(g0, g4)
573+
assert v5 == And(g0, g4)
574+
575+
g5 = GuardExpr('g5', initvalue=And(y <= y_M))
576+
g6 = GuardExpr('g6', initvalue=And(z <= z_M))
577+
v6 = simplify_and(g0, g5)
578+
v7 = simplify_and(g0, g6)
579+
v8 = simplify_and(g5, g6)
580+
assert v6 == And(g0, g5)
581+
assert v7 == And(g0, g6)
582+
assert v8 == And(g5, g6)
583+
584+
585+
def test_guard_expr_Ge():
586+
grid = Grid(shape=(3, 3, 3))
587+
_, y, z = grid.dimensions
588+
589+
y_m = y.symbolic_min
590+
z_m = z.symbolic_min
591+
592+
g0 = GuardExpr('g0', initvalue=And(y >= y_m - 3, z >= z_m - 3))
593+
g1 = GuardExpr('g1', initvalue=And(y >= y_m - 3, z >= z_m - 4))
594+
v0 = simplify_and(g0, g1)
595+
v1 = simplify_and(g1, g0)
596+
assert v0 is g0
597+
assert v0 is v1
598+
599+
v2 = simplify_and(And(g0, g1), g0)
600+
assert v2 is g0
601+
602+
g3 = GuardExpr('g3', initvalue=And(y >= y_m - 2, z >= z_m - 3))
603+
v3 = simplify_and(g0, g3)
604+
v4 = simplify_and(g3, g0)
605+
assert v3 is g3
606+
assert v4 is v3
607+
608+
g4 = GuardExpr('g4', initvalue=And(y >= y_m - 4, z >= z_m - 2))
609+
v5 = simplify_and(g0, g4)
610+
assert v5 == And(g0, g4)
611+
g5 = GuardExpr('g5', initvalue=And(y >= y_m))
612+
g6 = GuardExpr('g6', initvalue=And(z >= z_m))
613+
v6 = simplify_and(g0, g5)
614+
v7 = simplify_and(g0, g6)
615+
v8 = simplify_and(g5, g6)
616+
assert v6 == And(g0, g5)
617+
assert v7 == And(g0, g6)
618+
assert v8 == And(g5, g6)
619+
620+
g7 = GuardExpr('g7', initvalue=And(y >= -2, z >= -2))
621+
v9 = simplify_and(g0, g7)
622+
assert v9 == And(g0, g7)
623+
624+
625+
def test_guard_expr_Le_Ge_mixed():
626+
grid = Grid(shape=(3, 3, 3))
627+
_, y, z = grid.dimensions
628+
629+
y_m = y.symbolic_min
630+
y_M = y.symbolic_max
631+
z_m = z.symbolic_min
632+
z_M = z.symbolic_max
633+
634+
g0 = GuardExpr('g0', initvalue=And(y <= y_M + 3, z <= z_M + 3))
635+
g1 = GuardExpr('g1', initvalue=And(y >= y_m - 3, z >= z_m - 3))
636+
v0 = simplify_and(g0, g1)
637+
v1 = simplify_and(g1, g0)
638+
assert v0 == And(g0, g1)
639+
assert v1 == And(g0, g1)
640+
641+
g2 = GuardExpr('g2', initvalue=And(y <= y_M + 2, z >= z_m - 2))
642+
v2 = simplify_and(g0, g2)
643+
v3 = simplify_and(g2, g0)
644+
assert v2 == And(g0, g2)
645+
assert v3 == And(g0, g2)
646+
647+
g3 = GuardExpr('g3', initvalue=And(y >= y_m - 2, z <= z_M + 2))
648+
v4 = simplify_and(g0, g3)
649+
v5 = simplify_and(g3, g0)
650+
assert v4 == And(g0, g3)
651+
assert v5 == And(g0, g3)
652+
653+
g4 = GuardExpr('g4', initvalue=And(y <= y_M + 2, z >= z_m, z <= z_M + 2))
654+
v6 = simplify_and(g0, g4)
655+
v7 = simplify_and(g4, g0)
656+
assert v6 is g4
657+
assert v7 is g4
658+
659+
g5 = GuardExpr('g5', initvalue=And(y >= y_m - 2, y <= y_M + 2,
660+
z >= z_m - 2, z <= z_M + 2))
661+
v8 = simplify_and(g0, g5)
662+
v9 = simplify_and(g5, g0)
663+
assert v8 is g5
664+
assert v9 is g5
665+
666+
g6 = GuardExpr('g6', initvalue=And(y >= y_m, y <= y_M,
667+
z >= z_m - 3, z <= z_M))
668+
v10 = simplify_and(g5, g6)
669+
v11 = simplify_and(g6, g5)
670+
assert v10 is And(g5, g6)
671+
assert v11 is And(g5, g6)
672+
673+
546674
def test_canonical_ordering_of_weights():
547675
grid = Grid(shape=(3, 3, 3))
548676
x, y, z = grid.dimensions

0 commit comments

Comments
 (0)