Skip to content

Commit a1ebd69

Browse files
authored
Merge pull request #2675 from devitocodes/revert-2669-cached-scopes
Revert "compiler: Cache `Scope` + `Dependence` instances"
2 parents 8f9bc6f + 5f2614b commit a1ebd69

8 files changed

Lines changed: 96 additions & 212 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,7 @@
1212
from devito.ir.equations import OpMin, OpMax, identity_mapper
1313
from devito.ir.clusters.analysis import analyze
1414
from devito.ir.clusters.cluster import Cluster, ClusterGroup
15-
from devito.ir.clusters.visitors import Queue, cluster_pass
16-
from devito.ir.support import Scope
15+
from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass
1716
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
1817
from devito.mpi.reduction_scheme import DistReduce
1918
from devito.symbolics import (limits_mapper, retrieve_indexed, uxreplace,
@@ -78,7 +77,7 @@ def impose_total_ordering(clusters):
7877
return processed
7978

8079

81-
class Schedule(Queue):
80+
class Schedule(QueueStateful):
8281

8382
"""
8483
This special Queue produces a new sequence of "scheduled" Clusters, which
@@ -136,7 +135,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
136135
# `clusters` are supposed to share it
137136
candidates = prefix[-1].dim._defines
138137

139-
scope = Scope(flatten(c.exprs for c in clusters))
138+
scope = self._fetch_scope(clusters)
140139

141140
# Handle the nastiest case -- ambiguity due to the presence of both a
142141
# flow- and an anti-dependence.

devito/ir/clusters/analysis.py

Lines changed: 25 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -1,39 +1,31 @@
1-
from devito.ir.clusters.cluster import Cluster
2-
from devito.ir.clusters.visitors import Queue
1+
from devito.ir.clusters.visitors import QueueStateful
32
from devito.ir.support import (AFFINE, PARALLEL, PARALLEL_INDEP, PARALLEL_IF_ATOMIC,
4-
SEQUENTIAL, Property, Scope)
5-
from devito.ir.support.space import IterationSpace
3+
SEQUENTIAL)
64
from devito.tools import as_tuple, flatten, timed_pass
7-
from devito.types.dimension import Dimension
85

96
__all__ = ['analyze']
107

118

12-
# Describes properties fetched by a `Detector`
13-
Properties = dict[Cluster, dict[Dimension, set[Property]]]
14-
15-
169
@timed_pass()
1710
def analyze(clusters):
18-
properties: Properties = {}
11+
state = QueueStateful.State()
1912

2013
# Collect properties
21-
clusters = Parallelism().process(clusters, properties=properties)
22-
clusters = Affiness().process(clusters, properties=properties)
14+
clusters = Parallelism(state).process(clusters)
15+
clusters = Affiness(state).process(clusters)
2316

2417
# Reconstruct Clusters attaching the discovered properties
25-
processed = [c.rebuild(properties=properties.get(c)) for c in clusters]
18+
processed = [c.rebuild(properties=state.properties.get(c)) for c in clusters]
2619

2720
return processed
2821

2922

30-
class Detector(Queue):
23+
class Detector(QueueStateful):
3124

32-
def process(self, clusters: list[Cluster], properties: Properties) -> list[Cluster]:
33-
return self._process_fatd(clusters, 1, properties=properties)
25+
def process(self, elements):
26+
return self._process_fatd(elements, 1)
3427

35-
def callback(self, clusters: list[Cluster], prefix: IterationSpace | None,
36-
properties: Properties) -> list[Cluster]:
28+
def callback(self, clusters, prefix):
3729
if not prefix:
3830
return clusters
3931

@@ -49,19 +41,11 @@ def callback(self, clusters: list[Cluster], prefix: IterationSpace | None,
4941
# Update `self.state`
5042
if retval:
5143
for c in clusters:
52-
c_properties = properties.setdefault(c, {})
53-
c_properties.setdefault(d, set()).update(retval)
44+
properties = self.state.properties.setdefault(c, {})
45+
properties.setdefault(d, set()).update(retval)
5446

5547
return clusters
5648

57-
def _callback(self, clusters: list[Cluster], dim: Dimension,
58-
prefix: IterationSpace | None) -> set[Property]:
59-
"""
60-
Callback to be implemented by subclasses. It should return a set of
61-
properties for the given dimension.
62-
"""
63-
raise NotImplementedError()
64-
6549

6650
class Parallelism(Detector):
6751

@@ -88,27 +72,27 @@ class Parallelism(Detector):
8872
the 'write' is known to be an associative and commutative increment
8973
"""
9074

91-
def _callback(self, clusters, dim, prefix):
75+
def _callback(self, clusters, d, prefix):
9276
# Rule out if non-unitary increment Dimension (e.g., `t0=(time+1)%2`)
93-
if any(c.sub_iterators[dim] for c in clusters):
94-
return {SEQUENTIAL}
77+
if any(c.sub_iterators[d] for c in clusters):
78+
return SEQUENTIAL
9579

9680
# All Dimensions up to and including `i-1`
9781
prev = flatten(i.dim._defines for i in prefix[:-1])
9882

9983
is_parallel_indep = True
10084
is_parallel_atomic = False
10185

102-
scope = Scope(flatten(c.exprs for c in clusters))
86+
scope = self._fetch_scope(clusters)
10387
for dep in scope.d_all_gen():
104-
test00 = dep.is_indep(dim) and not dep.is_storage_related(dim)
88+
test00 = dep.is_indep(d) and not dep.is_storage_related(d)
10589
test01 = all(dep.is_reduce_atmost(i) for i in prev)
10690
if test00 and test01:
10791
continue
10892

10993
test1 = len(prev) > 0 and any(dep.is_carried(i) for i in prev)
11094
if test1:
111-
is_parallel_indep &= (dep.distance_mapper.get(dim.root) == 0)
95+
is_parallel_indep &= (dep.distance_mapper.get(d.root) == 0)
11296
continue
11397

11498
if dep.function in scope.initialized:
@@ -119,14 +103,14 @@ def _callback(self, clusters, dim, prefix):
119103
is_parallel_atomic = True
120104
continue
121105

122-
return {SEQUENTIAL}
106+
return SEQUENTIAL
123107

124108
if is_parallel_atomic:
125-
return {PARALLEL_IF_ATOMIC}
109+
return PARALLEL_IF_ATOMIC
126110
elif is_parallel_indep:
127111
return {PARALLEL, PARALLEL_INDEP}
128112
else:
129-
return {PARALLEL}
113+
return PARALLEL
130114

131115

132116
class Affiness(Detector):
@@ -135,11 +119,8 @@ class Affiness(Detector):
135119
Detect the AFFINE Dimensions.
136120
"""
137121

138-
def _callback(self, clusters, dim, prefix):
139-
scope = Scope(flatten(c.exprs for c in clusters))
122+
def _callback(self, clusters, d, prefix):
123+
scope = self._fetch_scope(clusters)
140124
accesses = [a for a in scope.accesses if not a.is_scalar]
141-
142-
if all(a.is_regular and a.affine_if_present(dim._defines) for a in accesses):
143-
return {AFFINE}
144-
145-
return set()
125+
if all(a.is_regular and a.affine_if_present(d._defines) for a in accesses):
126+
return AFFINE

devito/ir/clusters/visitors.py

Lines changed: 46 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
1+
from collections import defaultdict
12
from collections.abc import Iterable
23

34
from itertools import groupby
45

5-
from devito.ir.support import IterationSpace, null_ispace
6-
from devito.tools import flatten, timed_pass
6+
from devito.ir.support import IterationSpace, Scope, null_ispace
7+
from devito.tools import as_tuple, flatten, timed_pass
78

8-
__all__ = ['Queue', 'cluster_pass']
9+
__all__ = ['Queue', 'QueueStateful', 'cluster_pass']
910

1011

1112
class Queue:
@@ -112,6 +113,48 @@ def _process_fatd(self, clusters, level, prefix=None, **kwargs):
112113
return processed
113114

114115

116+
class QueueStateful(Queue):
117+
118+
"""
119+
A Queue carrying along some state. This is useful when one wants to avoid
120+
expensive re-computations of information.
121+
"""
122+
123+
class State:
124+
125+
def __init__(self):
126+
self.properties = {}
127+
self.scopes = {}
128+
129+
def __init__(self, state=None):
130+
super().__init__()
131+
self.state = state or QueueStateful.State()
132+
133+
def _fetch_scope(self, clusters):
134+
exprs = flatten(c.exprs for c in as_tuple(clusters))
135+
key = tuple(exprs)
136+
if key not in self.state.scopes:
137+
self.state.scopes[key] = Scope(exprs)
138+
return self.state.scopes[key]
139+
140+
def _fetch_properties(self, clusters, prefix):
141+
# If the situation is:
142+
#
143+
# t
144+
# x0
145+
# <some clusters>
146+
# x1
147+
# <some other clusters>
148+
#
149+
# then retain only the "common" properties, that is those along `t`
150+
properties = defaultdict(set)
151+
for c in clusters:
152+
v = self.state.properties.get(c, {})
153+
for i in prefix:
154+
properties[i.dim].update(v.get(i.dim, set()))
155+
return properties
156+
157+
115158
class Prefix(IterationSpace):
116159

117160
def __init__(self, ispace, guards, properties, syncs):

devito/ir/support/basic.py

Lines changed: 13 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,7 @@
1-
from collections.abc import Iterable
21
from itertools import chain, product
32
from functools import cached_property
4-
from typing import Callable
53

6-
from sympy import S, Expr
4+
from sympy import S
75
import sympy
86

97
from devito.ir.support.space import Backward, null_ispace
@@ -14,7 +12,7 @@
1412
uxreplace)
1513
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
1614
flatten, memoized_meth, memoized_generator, smart_gt,
17-
smart_lt, CacheInstances)
15+
smart_lt)
1816
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
1917
CriticalRegion, Function, Symbol, Temp, TempArray,
2018
TBArray)
@@ -248,7 +246,7 @@ def __eq__(self, other):
248246
self.ispace == other.ispace)
249247

250248
def __hash__(self):
251-
return hash((self.access, self.mode, self.timestamp, self.ispace))
249+
return super().__hash__()
252250

253251
@property
254252
def function(self):
@@ -626,7 +624,7 @@ def is_imaginary(self):
626624
return S.ImaginaryUnit in self.distance
627625

628626

629-
class Dependence(Relation, CacheInstances):
627+
class Dependence(Relation):
630628

631629
"""
632630
A data dependence between two TimedAccess objects.
@@ -825,26 +823,17 @@ def project(self, function):
825823
return DependenceGroup(i for i in self if i.function is function)
826824

827825

828-
class Scope(CacheInstances):
829-
830-
# Describes a rule for dependencies
831-
Rule = Callable[[TimedAccess, TimedAccess], bool]
826+
class Scope:
832827

833-
@classmethod
834-
def _preprocess_args(cls, exprs: Expr | Iterable[Expr],
835-
**kwargs) -> tuple[tuple, dict]:
836-
return (as_tuple(exprs),), kwargs
837-
838-
def __init__(self, exprs: tuple[Expr],
839-
rules: Rule | tuple[Rule] | None = None) -> None:
828+
def __init__(self, exprs, rules=None):
840829
"""
841830
A Scope enables data dependence analysis on a totally ordered sequence
842831
of expressions.
843832
"""
844-
self.exprs = exprs
833+
self.exprs = as_tuple(exprs)
845834

846835
# A set of rules to drive the collection of dependencies
847-
self.rules: tuple[Scope.Rule] = as_tuple(rules) # type: ignore[assignment]
836+
self.rules = as_tuple(rules)
848837
assert all(callable(i) for i in self.rules)
849838

850839
@memoized_generator
@@ -1183,10 +1172,12 @@ def d_from_access_gen(self, accesses):
11831172
Generate all flow, anti, and output dependences involving any of
11841173
the given TimedAccess objects.
11851174
"""
1186-
accesses = set(as_tuple(accesses))
1175+
accesses = as_tuple(accesses)
11871176
for d in self.d_all_gen():
1188-
if accesses & {d.source, d.sink}:
1189-
yield d
1177+
for i in accesses:
1178+
if d.source == i or d.sink == i:
1179+
yield d
1180+
break
11901181

11911182
@memoized_meth
11921183
def d_from_access(self, accesses):

devito/operator/operator.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -31,8 +31,7 @@
3131
from devito.symbolics import estimate_cost, subs_op_args
3232
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3333
flatten, filter_sorted, frozendict, is_integer,
34-
split, timed_pass, timed_region, contains_val,
35-
CacheInstances)
34+
split, timed_pass, timed_region, contains_val)
3635
from devito.types import (Buffer, Evaluable, host_layer, device_layer,
3736
disk_layer)
3837
from devito.types.dimension import Thickness
@@ -246,9 +245,6 @@ def _build(cls, expressions, **kwargs):
246245
op._dtype, op._dspace = irs.clusters.meta
247246
op._profiler = profiler
248247

249-
# Clear build-scoped instance caches
250-
CacheInstances.clear_caches()
251-
252248
return op
253249

254250
def __init__(self, *args, **kwargs):

0 commit comments

Comments
 (0)