Skip to content

Commit 1e68909

Browse files
committed
compiler: Cache Scopes + Dependences, remove QueueStateful
1 parent 74de496 commit 1e68909

7 files changed

Lines changed: 203 additions & 86 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from devito.ir.equations import OpMin, OpMax, identity_mapper
1313
from devito.ir.clusters.analysis import analyze
1414
from devito.ir.clusters.cluster import Cluster, ClusterGroup
15-
from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass
15+
from devito.ir.clusters.visitors import Queue, cluster_pass
16+
from devito.ir.support import Scope
1617
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
1718
from devito.mpi.reduction_scheme import DistReduce
1819
from devito.symbolics import (limits_mapper, retrieve_indexed, uxreplace,
@@ -77,7 +78,7 @@ def impose_total_ordering(clusters):
7778
return processed
7879

7980

80-
class Schedule(QueueStateful):
81+
class Schedule(Queue):
8182

8283
"""
8384
This special Queue produces a new sequence of "scheduled" Clusters, which
@@ -135,7 +136,7 @@ def callback(self, clusters, prefix, backlog=None, known_break=None):
135136
# `clusters` are supposed to share it
136137
candidates = prefix[-1].dim._defines
137138

138-
scope = self._fetch_scope(clusters)
139+
scope = Scope(flatten(c.exprs for c in clusters))
139140

140141
# Handle the nastiest case -- ambiguity due to the presence of both a
141142
# flow- and an anti-dependence.

devito/ir/clusters/analysis.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,39 @@
1-
from devito.ir.clusters.visitors import QueueStateful
1+
from devito.ir.clusters.cluster import Cluster
2+
from devito.ir.clusters.visitors import Queue
23
from devito.ir.support import (AFFINE, PARALLEL, PARALLEL_INDEP, PARALLEL_IF_ATOMIC,
3-
SEQUENTIAL)
4+
SEQUENTIAL, Property, Scope)
5+
from devito.ir.support.space import IterationSpace
46
from devito.tools import as_tuple, flatten, timed_pass
7+
from devito.types.dimension import Dimension
58

69
__all__ = ['analyze']
710

811

12+
# Describes properties fetched by a `Detector`
13+
Properties = dict[Cluster, dict[Dimension, set[Property]]]
14+
15+
916
@timed_pass()
1017
def analyze(clusters):
11-
state = QueueStateful.State()
18+
properties: Properties = {}
1219

1320
# Collect properties
14-
clusters = Parallelism(state).process(clusters)
15-
clusters = Affiness(state).process(clusters)
21+
clusters = Parallelism().process(clusters, properties=properties)
22+
clusters = Affiness().process(clusters, properties=properties)
1623

1724
# Reconstruct Clusters attaching the discovered properties
18-
processed = [c.rebuild(properties=state.properties.get(c)) for c in clusters]
25+
processed = [c.rebuild(properties=properties.get(c)) for c in clusters]
1926

2027
return processed
2128

2229

23-
class Detector(QueueStateful):
30+
class Detector(Queue):
2431

25-
def process(self, elements):
26-
return self._process_fatd(elements, 1)
32+
def process(self, clusters: list[Cluster], properties: Properties) -> list[Cluster]:
33+
return self._process_fatd(clusters, 1, properties=properties)
2734

28-
def callback(self, clusters, prefix):
35+
def callback(self, clusters: list[Cluster], prefix: IterationSpace | None,
36+
properties: Properties) -> list[Cluster]:
2937
if not prefix:
3038
return clusters
3139

@@ -41,11 +49,19 @@ def callback(self, clusters, prefix):
4149
# Update `self.state`
4250
if retval:
4351
for c in clusters:
44-
properties = self.state.properties.setdefault(c, {})
45-
properties.setdefault(d, set()).update(retval)
52+
c_properties = properties.setdefault(c, {})
53+
c_properties.setdefault(d, set()).update(retval)
4654

4755
return clusters
4856

57+
def _callback(self, clusters: list[Cluster], dim: Dimension,
58+
prefix: IterationSpace | None) -> set[Property]:
59+
"""
60+
Callback to be implemented by subclasses. It should return a set of
61+
properties for the given dimension.
62+
"""
63+
raise NotImplementedError()
64+
4965

5066
class Parallelism(Detector):
5167

@@ -72,27 +88,27 @@ class Parallelism(Detector):
7288
the 'write' is known to be an associative and commutative increment
7389
"""
7490

75-
def _callback(self, clusters, d, prefix):
91+
def _callback(self, clusters, dim, prefix):
7692
# Rule out if non-unitary increment Dimension (e.g., `t0=(time+1)%2`)
77-
if any(c.sub_iterators[d] for c in clusters):
78-
return SEQUENTIAL
93+
if any(c.sub_iterators[dim] for c in clusters):
94+
return {SEQUENTIAL}
7995

8096
# All Dimensions up to and including `i-1`
8197
prev = flatten(i.dim._defines for i in prefix[:-1])
8298

8399
is_parallel_indep = True
84100
is_parallel_atomic = False
85101

86-
scope = self._fetch_scope(clusters)
102+
scope = Scope(flatten(c.exprs for c in clusters))
87103
for dep in scope.d_all_gen():
88-
test00 = dep.is_indep(d) and not dep.is_storage_related(d)
104+
test00 = dep.is_indep(dim) and not dep.is_storage_related(dim)
89105
test01 = all(dep.is_reduce_atmost(i) for i in prev)
90106
if test00 and test01:
91107
continue
92108

93109
test1 = len(prev) > 0 and any(dep.is_carried(i) for i in prev)
94110
if test1:
95-
is_parallel_indep &= (dep.distance_mapper.get(d.root) == 0)
111+
is_parallel_indep &= (dep.distance_mapper.get(dim.root) == 0)
96112
continue
97113

98114
if dep.function in scope.initialized:
@@ -103,14 +119,14 @@ def _callback(self, clusters, d, prefix):
103119
is_parallel_atomic = True
104120
continue
105121

106-
return SEQUENTIAL
122+
return {SEQUENTIAL}
107123

108124
if is_parallel_atomic:
109-
return PARALLEL_IF_ATOMIC
125+
return {PARALLEL_IF_ATOMIC}
110126
elif is_parallel_indep:
111127
return {PARALLEL, PARALLEL_INDEP}
112128
else:
113-
return PARALLEL
129+
return {PARALLEL}
114130

115131

116132
class Affiness(Detector):
@@ -119,8 +135,11 @@ class Affiness(Detector):
119135
Detect the AFFINE Dimensions.
120136
"""
121137

122-
def _callback(self, clusters, d, prefix):
123-
scope = self._fetch_scope(clusters)
138+
def _callback(self, clusters, dim, prefix):
139+
scope = Scope(flatten(c.exprs for c in clusters))
124140
accesses = [a for a in scope.accesses if not a.is_scalar]
125-
if all(a.is_regular and a.affine_if_present(d._defines) for a in accesses):
126-
return AFFINE
141+
142+
if all(a.is_regular and a.affine_if_present(dim._defines) for a in accesses):
143+
return {AFFINE}
144+
145+
return set()

devito/ir/clusters/visitors.py

Lines changed: 3 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
from collections import defaultdict
21
from collections.abc import Iterable
32

43
from itertools import groupby
54

6-
from devito.ir.support import IterationSpace, Scope, null_ispace
7-
from devito.tools import as_tuple, flatten, timed_pass
5+
from devito.ir.support import IterationSpace, null_ispace
6+
from devito.tools import flatten, timed_pass
87

9-
__all__ = ['Queue', 'QueueStateful', 'cluster_pass']
8+
__all__ = ['Queue', 'cluster_pass']
109

1110

1211
class Queue:
@@ -113,48 +112,6 @@ def _process_fatd(self, clusters, level, prefix=None, **kwargs):
113112
return processed
114113

115114

116-
class QueueStateful(Queue):
117-
118-
"""
119-
A Queue carrying along some state. This is useful when one wants to avoid
120-
expensive re-computations of information.
121-
"""
122-
123-
class State:
124-
125-
def __init__(self):
126-
self.properties = {}
127-
self.scopes = {}
128-
129-
def __init__(self, state=None):
130-
super().__init__()
131-
self.state = state or QueueStateful.State()
132-
133-
def _fetch_scope(self, clusters):
134-
exprs = flatten(c.exprs for c in as_tuple(clusters))
135-
key = tuple(exprs)
136-
if key not in self.state.scopes:
137-
self.state.scopes[key] = Scope(exprs)
138-
return self.state.scopes[key]
139-
140-
def _fetch_properties(self, clusters, prefix):
141-
# If the situation is:
142-
#
143-
# t
144-
# x0
145-
# <some clusters>
146-
# x1
147-
# <some other clusters>
148-
#
149-
# then retain only the "common" properties, that is those along `t`
150-
properties = defaultdict(set)
151-
for c in clusters:
152-
v = self.state.properties.get(c, {})
153-
for i in prefix:
154-
properties[i.dim].update(v.get(i.dim, set()))
155-
return properties
156-
157-
158115
class Prefix(IterationSpace):
159116

160117
def __init__(self, ispace, guards, properties, syncs):

devito/ir/support/basic.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from collections.abc import Iterable
12
from itertools import chain, product
23
from functools import cached_property
4+
from typing import Callable
35

4-
from sympy import S
6+
from sympy import S, Expr
57
import sympy
68

79
from devito.ir.support.space import Backward, null_ispace
@@ -12,7 +14,7 @@
1214
uxreplace)
1315
from devito.tools import (Tag, as_mapper, as_tuple, is_integer, filter_sorted,
1416
flatten, memoized_meth, memoized_generator, smart_gt,
15-
smart_lt)
17+
smart_lt, CacheInstances)
1618
from devito.types import (ComponentAccess, Dimension, DimensionTuple, Fence,
1719
CriticalRegion, Function, Symbol, Temp, TempArray,
1820
TBArray)
@@ -624,7 +626,7 @@ def is_imaginary(self):
624626
return S.ImaginaryUnit in self.distance
625627

626628

627-
class Dependence(Relation):
629+
class Dependence(Relation, CacheInstances):
628630

629631
"""
630632
A data dependence between two TimedAccess objects.
@@ -823,17 +825,26 @@ def project(self, function):
823825
return DependenceGroup(i for i in self if i.function is function)
824826

825827

826-
class Scope:
828+
class Scope(CacheInstances):
827829

828-
def __init__(self, exprs, rules=None):
830+
# Describes a rule for dependencies
831+
Rule = Callable[[TimedAccess, TimedAccess], bool]
832+
833+
@classmethod
834+
def _preprocess_args(cls, exprs: Expr | Iterable[Expr],
835+
**kwargs) -> tuple[tuple, dict]:
836+
return (as_tuple(exprs),), kwargs
837+
838+
def __init__(self, exprs: tuple[Expr],
839+
rules: Rule | tuple[Rule] | None = None) -> None:
829840
"""
830841
A Scope enables data dependence analysis on a totally ordered sequence
831842
of expressions.
832843
"""
833-
self.exprs = as_tuple(exprs)
844+
self.exprs = exprs
834845

835846
# A set of rules to drive the collection of dependencies
836-
self.rules = as_tuple(rules)
847+
self.rules: tuple[Scope.Rule] = as_tuple(rules) # type: ignore[assignment]
837848
assert all(callable(i) for i in self.rules)
838849

839850
@memoized_generator

devito/operator/operator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
from devito.symbolics import estimate_cost, subs_op_args
3232
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3333
flatten, filter_sorted, frozendict, is_integer,
34-
split, timed_pass, timed_region, contains_val)
34+
split, timed_pass, timed_region, contains_val,
35+
CacheInstances)
3536
from devito.types import (Buffer, Evaluable, host_layer, device_layer,
3637
disk_layer)
3738
from devito.types.dimension import Thickness
@@ -245,6 +246,9 @@ def _build(cls, expressions, **kwargs):
245246
op._dtype, op._dspace = irs.clusters.meta
246247
op._profiler = profiler
247248

249+
# Clear build-scoped instance caches
250+
CacheInstances.clear_caches()
251+
248252
return op
249253

250254
def __init__(self, *args, **kwargs):

devito/tools/memoization.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from collections.abc import Hashable
2-
from functools import partial
1+
from collections.abc import Callable, Hashable
2+
from functools import lru_cache, partial
33
from itertools import tee
4+
from typing import TypeVar
45

5-
__all__ = ['memoized_func', 'memoized_meth', 'memoized_generator']
6+
__all__ = ['memoized_func', 'memoized_meth', 'memoized_generator', 'CacheInstances']
67

78

89
class memoized_func:
@@ -125,3 +126,63 @@ def __call__(self, *args, **kwargs):
125126
it = cache[key] if key in cache else self.func(*args, **kwargs)
126127
cache[key], result = tee(it)
127128
return result
129+
130+
131+
# Describes the type of a subclass of CacheInstances
132+
InstanceType = TypeVar('InstanceType', bound='CacheInstances', covariant=True)
133+
134+
135+
class CacheInstancesMeta(type):
136+
"""
137+
Metaclass to wrap construction in an LRU cache.
138+
"""
139+
140+
_cached_types: set[type['CacheInstances']] = set()
141+
142+
def __init__(cls: type[InstanceType], *args) -> None: # type: ignore
143+
super().__init__(*args)
144+
145+
# Register the cached type
146+
CacheInstancesMeta._cached_types.add(cls)
147+
148+
def __call__(cls: type[InstanceType], # type: ignore
149+
*args, **kwargs) -> InstanceType:
150+
if cls._instance_cache is None:
151+
maxsize = cls._instance_cache_size
152+
cls._instance_cache = lru_cache(maxsize=maxsize)(super().__call__)
153+
154+
args, kwargs = cls._preprocess_args(*args, **kwargs)
155+
return cls._instance_cache(*args, **kwargs)
156+
157+
@classmethod
158+
def clear_caches(cls: type['CacheInstancesMeta']) -> None:
159+
"""
160+
Clear all caches for classes using this metaclass.
161+
"""
162+
for cached_type in cls._cached_types:
163+
if cached_type._instance_cache is not None:
164+
cached_type._instance_cache.cache_clear()
165+
166+
167+
class CacheInstances(metaclass=CacheInstancesMeta):
168+
"""
169+
Parent class that wraps construction in an LRU cache.
170+
"""
171+
172+
_instance_cache: Callable | None = None
173+
_instance_cache_size: int = 128
174+
175+
@classmethod
176+
def _preprocess_args(cls, *args, **kwargs):
177+
"""
178+
Preprocess the arguments before caching. This can be overridden in subclasses
179+
to customize argument handling (e.g. to convert to hashable types).
180+
"""
181+
return args, kwargs
182+
183+
@staticmethod
184+
def clear_caches() -> None:
185+
"""
186+
Clears all IR instance caches.
187+
"""
188+
CacheInstancesMeta.clear_caches()

0 commit comments

Comments
 (0)