Skip to content

Commit e2c73cf

Browse files
committed
compiler: ensure inf is treated as non-strict direction
1 parent 3c93647 commit e2c73cf

4 files changed

Lines changed: 38 additions & 10 deletions

File tree

devito/ir/support/basic.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def distance(self, other):
351351
# E.g., `self=R<f,[x + 2]>` and `other=W<f,[i + 1]>`
352352
# E.g., `self=R<f,[x]>`, `other=W<f,[x + 1]>`,
353353
# `self.itintervals=(x<0>,)`, `other.itintervals=(x<1>,)`
354-
return vinf(ret)
354+
return v_undef(ret)
355355
except AttributeError:
356356
# E.g., `self=R<f,[cy]>` and `self.itintervals=(y,)` => `sai=None`
357357
pass
@@ -413,10 +413,12 @@ def distance(self, other):
413413
# E.g., `self=R<f,[x, y]>`, `sai=time`,
414414
# `self.itintervals=(time, x, y)`, `n=0`
415415
continue
416+
elif not sai or not oai:
417+
return v_undef(ret, val=S.NaN)
416418
else:
417419
# E.g., `self=R<u,[t+1, ii_src_0+1, ii_src_1+2]>`, `fi=p_src`,
418420
# and `n=1`
419-
return vinf(ret)
421+
return v_undef(ret)
420422

421423
n = len(ret)
422424

@@ -640,7 +642,11 @@ def cause(self):
640642
"""Return the findex causing the dependence."""
641643
for i, j in zip(self.findices, self.distance):
642644
try:
643-
if j > 0:
645+
# If j is S.Infinity, then the direction was not defined
646+
# so we treat it as non cause.
647+
if j is S.NaN:
648+
continue
649+
elif j > 0:
644650
return i._defines
645651
except TypeError:
646652
# Conservatively assume this is an offending dimension
@@ -1353,8 +1359,8 @@ def is_regular(self):
13531359

13541360
# *** Utils
13551361

1356-
def vinf(entries):
1357-
return Vector(*(entries + [S.Infinity]))
1362+
def v_undef(entries, val=S.Infinity):
1363+
return Vector(*(entries + [val]))
13581364

13591365

13601366
def retrieve_accesses(exprs, **kwargs):

devito/passes/iet/mpi.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,11 @@ def _semantical_eq_loc_indices(hsf0, hsf1):
530530

531531
# Special case: they might be syntactically different, but semantically
532532
# equivalent, e.g., `t0` and `t1` with same modulus
533-
if v0.modulo == v1.modulo == 1:
534-
continue
533+
try:
534+
if v0.modulo == v1.modulo == 1:
535+
continue
536+
except AttributeError:
537+
return False
535538

536539
return False
537540

devito/types/dense.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from devito.logger import debug, warning
1717
from devito.mpi import MPI
1818
from devito.parameters import configuration
19-
from devito.symbolics import FieldFromPointer, normalize_args
19+
from devito.symbolics import FieldFromPointer, normalize_args, IndexedPointer
2020
from devito.finite_differences import Differentiable, generate_fd_shortcuts
2121
from devito.finite_differences.tools import fd_weights_registry
2222
from devito.tools import (ReducerMap, as_tuple, c_restrict_void_p, flatten,
@@ -719,7 +719,7 @@ def _C_make_index(self, dim, side=None):
719719
@memoized_meth
720720
def _C_get_field(self, region, dim, side=None):
721721
"""Symbolic representation of a given data region."""
722-
ffp = lambda f, i: FieldFromPointer("%s[%d]" % (f, i), self._C_symbol)
722+
ffp = lambda f, i: IndexedPointer(FieldFromPointer("%s" % f, self._C_symbol), i)
723723
if region is DOMAIN:
724724
offset = ffp(self._C_field_owned_ofs, self._C_make_index(dim, LEFT))
725725
size = ffp(self._C_field_domain_size, self._C_make_index(dim))

tests/test_operator.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
from devito.passes.iet.languages.C import CDataManager
3636
from devito.symbolics import ListInitializer, indexify, retrieve_indexed
3737
from devito.tools import flatten, powerset, timed_region
38-
from devito.types import Array, Barrier, CustomDimension, Indirection, Scalar, Symbol
38+
from devito.types import (
39+
Array, Barrier, CustomDimension, Indirection, Scalar, Symbol, ConditionalDimension
40+
)
3941

4042

4143
def dimify(dimensions):
@@ -2034,6 +2036,23 @@ def test_2194_v2(self, eqns, expected, exp_trees, exp_iters):
20342036
op.apply()
20352037
assert(np.all(u.data[:] == expected[:]))
20362038

2039+
def test_pseudo_time_dep(self):
2040+
"""
2041+
Test taht a data dependency through a field is correctly
2042+
ignore when nor direction dependent
2043+
"""
2044+
grid = Grid((11, 11))
2045+
ct = ConditionalDimension(name='ct', parent=grid.time_dim, factor=2)
2046+
f = TimeFunction(name='f', grid=grid, time_order=1)
2047+
g = Function(name='g', grid=grid)
2048+
2049+
eq = [Eq(f.backward, div(f) + 1),
2050+
Eq(g, g + f.symbolic_shape[1], implicit_dims=ct),
2051+
Eq(g, g + 1, implicit_dims=ct)]
2052+
op = Operator(eq)
2053+
2054+
assert_structure(op, ['t,x,y', 't', 't,x,y'], 't,x,y,x,y')
2055+
20372056

20382057
class TestInternals:
20392058

0 commit comments

Comments
 (0)