Skip to content

Commit 370d143

Browse files
authored
Merge pull request #2840 from devitocodes/the-TMA
compiler: Enhance IR to support more advanced parlang (CUDA/HIP/SYCL) features
2 parents 5dc569c + ed3d9c5 commit 370d143

25 files changed

Lines changed: 337 additions & 124 deletions

devito/arch/archinfo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,7 +1139,7 @@ def supports(self, query, language=None):
11391139
warning(f"Couldn't establish if `query={query}` is supported on this "
11401140
"system. Assuming it is not.")
11411141
return False
1142-
elif query == 'async-loads' and cc >= 80:
1142+
elif query == 'async-pipe' and cc >= 80:
11431143
# Asynchronous pipeline loads -- introduced in Ampere
11441144
return True
11451145
elif query in ('tma', 'thread-block-cluster') and cc >= 90: # noqa: SIM103
@@ -1156,7 +1156,7 @@ class Volta(NvidiaDevice):
11561156
class Ampere(Volta):
11571157

11581158
def supports(self, query, language=None):
1159-
if query == 'async-loads':
1159+
if query == 'async-pipe':
11601160
return True
11611161
else:
11621162
return super().supports(query, language)

devito/ir/cgen/printer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,13 @@ def _print_BitwiseNot(self, expr):
281281

282282
def _print_BitwiseBinaryOp(self, expr):
283283
arg0, arg1 = expr.args
284+
285+
prec = precedence(expr)
286+
if not arg0.is_Atom:
287+
arg0 = self.parenthesize(arg0, prec)
288+
if not arg1.is_Atom:
289+
arg1 = self.parenthesize(arg1, prec)
290+
284291
return f'{self._print(arg0)} {expr.op} {self._print(arg1)}'
285292

286293
def _print_Add(self, expr, order=None):

devito/ir/clusters/cluster.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,16 +171,35 @@ def free_symbols(self):
171171
def dimensions(self):
172172
return set().union(*[i._defines for i in self.ispace.dimensions])
173173

174+
@cached_property
175+
def exprs_dimensions(self):
176+
"""
177+
The Dimensions that appear explicitly in the Cluster expressions.
178+
"""
179+
dims_explicit = {i for i in self.free_symbols if i.is_Dimension}
180+
dims_implicit = {d for e in self.exprs for d in e.implicit_dims}
181+
return dims_explicit | dims_implicit
182+
183+
@cached_property
184+
def guards_dimensions(self):
185+
"""
186+
The Dimensions that appear explicitly in the Cluster guards.
187+
"""
188+
syms_guards = {d for e in self.guards.values() for d in e.free_symbols}
189+
dims_guards = {i for i in syms_guards if i.is_Dimension}
190+
return dims_guards
191+
174192
@cached_property
175193
def used_dimensions(self):
176194
"""
177-
The Dimensions that *actually* appear among the expressions in ``self``.
178-
These do not necessarily coincide the IterationSpace Dimensions; for
179-
example, reduction or redundant (i.e., invariant) Dimensions won't
180-
appear in an expression.
195+
All the Dimensions that appear explicitly either within the expressions
196+
or the guards.
197+
198+
Note that, in some cases, some of the IterationSpace Dimensions might
199+
not appear here among the used Dimensions -- for example, reduction or
200+
redundant (i.e., invariant) Dimensions.
181201
"""
182-
idims = set.union(*[set(e.implicit_dims) for e in self.exprs])
183-
return {i for i in self.free_symbols if i.is_Dimension} | idims
202+
return self.exprs_dimensions | self.guards_dimensions
184203

185204
@cached_property
186205
def dist_dimensions(self):

devito/ir/equations/equation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,11 @@ def __repr__(self):
9292
if not self.is_Reduction:
9393
return super().__repr__()
9494
elif self.operation is OpInc:
95-
return f'{self.lhs} += {self.rhs}'
95+
return f'Inc({self.lhs}, {self.rhs})'
9696
else:
97-
return f'{self.lhs} = {self.operation}({self.rhs})'
97+
return f'Eq({self.lhs}, {self.operation}({self.rhs}))'
98+
99+
__str__ = __repr__
98100

99101
# Pickling support
100102
__reduce_ex__ = Pickable.__reduce_ex__

devito/ir/iet/visitors.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,9 @@ def _gen_value(self, obj, mode=1, masked=()):
321321
qualifiers = [v for k, v in self._qualifiers_mapper.items()
322322
if getattr(obj.function, k, False) and v not in masked]
323323

324+
if obj.is_LocalObject and mode == 2:
325+
qualifiers.extend(as_tuple(obj._C_tag))
326+
324327
if (obj._mem_stack or obj._mem_constant) and mode == 1:
325328
strtype = self.ccode(obj._C_typedata)
326329
strshape = ''.join(f'[{self.ccode(i)}]' for i in obj.symbolic_shape)

devito/ir/support/guards.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,9 @@ def pairwise_or(*guards):
500500

501501
# Analysis
502502
for guard in guards:
503-
if guard is true or guard is None:
503+
if guard is true:
504+
return true
505+
elif guard is None:
504506
continue
505507
elif isinstance(guard, And):
506508
components = guard.args

devito/ir/support/space.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -789,17 +789,19 @@ def __init__(self, intervals, sub_iterators=None, directions=None):
789789
super().__init__(intervals)
790790

791791
# Normalize sub-iterators
792-
sub_iterators = dict([(k, tuple(filter_ordered(as_tuple(v))))
793-
for k, v in (sub_iterators or {}).items()])
792+
sub_iterators = sub_iterators or {}
793+
sub_iterators = {d: tuple(filter_ordered(as_tuple(v)))
794+
for d, v in sub_iterators.items() if d in self.intervals}
794795
sub_iterators.update({i.dim: () for i in self.intervals
795796
if i.dim not in sub_iterators})
796797
self._sub_iterators = frozendict(sub_iterators)
797798

798799
# Normalize directions
799-
if directions is None:
800-
self._directions = frozendict([(i.dim, Any) for i in self.intervals])
801-
else:
802-
self._directions = frozendict(directions)
800+
directions = directions or {}
801+
directions = {d: v for d, v in directions.items() if d in self.intervals}
802+
directions.update({i.dim: Any for i in self.intervals
803+
if i.dim not in directions})
804+
self._directions = frozendict(directions)
803805

804806
def __repr__(self):
805807
ret = ', '.join([f"{repr(i)}{repr(self.directions[i.dim])}"
@@ -821,8 +823,7 @@ def __lt__(self, other):
821823
return len(self.itintervals) < len(other.itintervals)
822824

823825
def __hash__(self):
824-
return hash((super().__hash__(), self.sub_iterators,
825-
self.directions))
826+
return hash((super().__hash__(), self.sub_iterators, self.directions))
826827

827828
def __contains__(self, d):
828829
try:

devito/passes/clusters/cse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313

1414
from devito.finite_differences.differentiable import IndexDerivative
1515
from devito.ir import Cluster, Scope, cluster_pass
16-
from devito.symbolics import estimate_cost, q_leaf, q_terminal
16+
from devito.symbolics import Reserved, estimate_cost, q_leaf, q_terminal, search
1717
from devito.symbolics.manipulation import _uxreplace
18-
from devito.symbolics.search import search
1918
from devito.tools import DAG, as_list, as_tuple, extract_dtype, frozendict
2019
from devito.types import Eq, Symbol, Temp
2120

@@ -411,6 +410,7 @@ def _(expr):
411410

412411
@_catch.register(Indexed)
413412
@_catch.register(Symbol)
413+
@_catch.register(Reserved)
414414
def _(expr):
415415
"""
416416
Handler for objects preventing CSE to propagate through their arguments.

devito/passes/clusters/derivatives.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from sympy import S
55

6-
from devito.finite_differences import IndexDerivative
6+
from devito.finite_differences import IndexDerivative, Weights
77
from devito.ir import Backward, Forward, Interval, IterationSpace, Queue
88
from devito.passes.clusters.misc import fuse
99
from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace
@@ -91,17 +91,39 @@ def _core(expr, c, ispace, weights, reusables, mapper, **kwargs):
9191

9292

9393
@_core.register(Symbol)
94-
@_core.register(Indexed)
9594
@_core.register(BasicWrapperMixin)
9695
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
9796
return expr, []
9897

9998

99+
@_core.register(Indexed)
100+
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
101+
if not isinstance(expr.function, Weights):
102+
return expr, []
103+
104+
# Lower or reuse a previously lowered Weights array
105+
sregistry = kwargs['sregistry']
106+
subs_user = kwargs['subs']
107+
108+
w0 = expr.function
109+
k = tuple(w0.weights)
110+
try:
111+
w = weights[k]
112+
except KeyError:
113+
name = sregistry.make_name(prefix='w')
114+
dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32
115+
initvalue = tuple(i.subs(subs_user) for i in k)
116+
w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue)
117+
118+
rebuilt = expr._subs(w0.indexed, w.indexed)
119+
120+
return rebuilt, []
121+
122+
100123
@_core.register(IndexDerivative)
101124
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
102125
sregistry = kwargs['sregistry']
103126
options = kwargs['options']
104-
subs_user = kwargs['subs']
105127

106128
try:
107129
cbk0 = deriv_schedule_registry[options['deriv-schedule']]
@@ -114,18 +136,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
114136

115137
# Create the concrete Weights array, or reuse an already existing one
116138
# if possible
117-
name = sregistry.make_name(prefix='w')
118-
w0 = ideriv.weights.function
119-
dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32
120-
k = tuple(w0.weights)
121-
try:
122-
w = weights[k]
123-
except KeyError:
124-
initvalue = tuple(i.subs(subs_user) for i in k)
125-
w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue)
139+
w, _ = _core(ideriv.weights, c, ispace, weights, reusables, mapper, **kwargs)
126140

127141
# Replace the abstract Weights array with the concrete one
128-
subs = {w0.indexed: w.indexed}
142+
subs = {ideriv.weights.base: w.base}
129143
init = uxreplace(init, subs)
130144
ideriv = uxreplace(ideriv, subs)
131145

@@ -152,13 +166,13 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
152166
ispace1 = IterationSpace.union(ispace, ispace0, relations=extra)
153167

154168
# The Symbol that will hold the result of the IndexDerivative computation
155-
# NOTE: created before recurring so that we ultimately get a sound ordering
169+
# NOTE: created before recursing so that we ultimately get a sound ordering
156170
try:
157171
s = reusables.pop()
158-
assert np.can_cast(s.dtype, dtype)
172+
assert np.can_cast(s.dtype, w.dtype)
159173
except KeyError:
160174
name = sregistry.make_name(prefix='r')
161-
s = Symbol(name=name, dtype=dtype)
175+
s = Symbol(name=name, dtype=w.dtype)
162176

163177
# Go inside `expr` and recursively lower any nested IndexDerivatives
164178
expr, processed = _core(expr, c, ispace1, weights, reusables, mapper, **kwargs)

devito/passes/clusters/misc.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def callback(self, clusters, prefix):
5555
continue
5656

5757
# Is `c` a real candidate -- is there at least one invariant Dimension?
58-
if any(d._defines & hope_invariant for d in c.used_dimensions):
58+
if any(d._defines & hope_invariant for d in c.exprs_dimensions):
5959
processed.append(c)
6060
continue
6161

@@ -69,16 +69,16 @@ def callback(self, clusters, prefix):
6969
# All of the inner Dimensions must appear in the write-to region
7070
# otherwise we would violate data dependencies. Consider
7171
#
72-
# 1) 2) 3)
73-
# for i for i for i
74-
# for x for x for x
75-
# r = f(a[x]) for y for y
76-
# r[x] = f(a[x, y]) r[x, y] = f(a[x, y])
72+
# 1) 2) 3)
73+
# for i for i for i
74+
# for x for x for x
75+
# r = f(a[x]) for y for y
76+
# r[x] = f(a[x, y]) r[x, y] = f(a[x, y])
7777
#
7878
# In 1) and 2) lifting is infeasible; in 3) the statement can
7979
# be lifted outside the `i` loop as `r`'s write-to region contains
8080
# both `x` and `y`
81-
xed = {d._defines for d in c.used_dimensions if d not in outer}
81+
xed = {d._defines for d in c.exprs_dimensions if d not in outer}
8282
if not all(i & set(w.dimensions) for i, w in product(xed, c.scope.writes)):
8383
processed.append(c)
8484
continue

0 commit comments

Comments
 (0)