Skip to content

Commit 883a710

Browse files
committed
compiler: Patch IterationSpace.intersection
1 parent 2506149 commit 883a710

2 files changed

Lines changed: 26 additions & 2 deletions

File tree

devito/ir/support/space.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,10 +378,16 @@ def generate(cls, op, *interval_groups, relations=None):
378378
379379
>>> ig = IntervalGroup.generate('intersection', ig0, ig1, ig2)
380380
"""
381+
if op == 'intersection':
382+
dims = set.intersection(*[set(ig.dimensions) for ig in interval_groups])
383+
else:
384+
dims = set().union(*[ig.dimensions for ig in interval_groups])
385+
381386
mapper = {}
382387
for ig in interval_groups:
383388
for i in ig:
384-
mapper.setdefault(i.dim, []).append(i)
389+
if i.dim in dims:
390+
mapper.setdefault(i.dim, []).append(i)
385391

386392
intervals = []
387393
for v in mapper.values():

tests/test_ir.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
Vector, AFFINE, REGULAR, IRREGULAR, mocksym0,
1414
mocksym1)
1515
from devito.ir.support.space import (NullInterval, Interval, Forward, Backward,
16-
IterationSpace)
16+
IntervalGroup, IterationSpace)
1717
from devito.ir.support.guards import GuardOverflow
1818
from devito.symbolics import DefFunction, FieldFromPointer
1919
from devito.tools import prod
@@ -508,6 +508,24 @@ def test_intervals_switch(self, x, y):
508508
assert iy.switch(x) == ix
509509
assert ix.switch(y).switch(x) == ix
510510

511+
def test_space_intersection(self, x, y):
512+
ig0 = IntervalGroup([Interval(x, 1, -1)])
513+
ig1 = IntervalGroup([Interval(x, 2, -2), Interval(y, 3, -3)])
514+
515+
ig = IntervalGroup.generate('intersection', ig0, ig1)
516+
517+
assert len(ig) == 1
518+
assert ig[0] == Interval(x, 2, -2)
519+
520+
# Now the same but with IterationSpaces
521+
ispace0 = IterationSpace(ig0)
522+
ispace1 = IterationSpace(ig1)
523+
524+
ispace = IterationSpace.intersection(ispace0, ispace1)
525+
526+
assert len(ispace) == 1
527+
assert ispace.intervals == ig
528+
511529

512530
class TestDependenceAnalysis:
513531

0 commit comments

Comments
 (0)