Skip to content

Commit 1df126b

Browse files
committed
compiler: Add ThreadArrive and TensorMove
1 parent 3f7f6dc commit 1df126b

4 files changed

Lines changed: 69 additions & 42 deletions

File tree

devito/ir/clusters/cluster.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from devito.symbolics import estimate_cost
1717
from devito.tools import as_tuple, filter_ordered, flatten, infer_dtype
1818
from devito.types import (
19-
CriticalRegion, Fence, ThreadCommit, ThreadPoolSync, ThreadWait, WeakFence
19+
CriticalRegion, Fence, Indexed, TensorMove, ThreadArrive, ThreadCommit,
20+
ThreadPoolSync, ThreadWait, WeakFence
2021
)
2122

2223
__all__ = ["Cluster", "ClusterGroup"]
@@ -310,14 +311,43 @@ def is_critical_region(self):
310311
def is_thread_pool_sync(self):
311312
return self._is_type(ThreadPoolSync)
312313

314+
@cached_property
315+
def is_shm_write(self):
316+
return all(w._mem_shared for w in self.scope.writes)
317+
313318
@cached_property
314319
def is_thread_commit(self):
315320
return self._is_type(ThreadCommit)
316321

322+
@cached_property
323+
def is_thread_arrive(self):
324+
return self._is_type(ThreadArrive)
325+
317326
@cached_property
318327
def is_thread_wait(self):
319328
return self._is_type(ThreadWait)
320329

330+
@cached_property
331+
def is_word_move(self):
332+
return (self._is_type(Indexed) and
333+
all(e.rhs.function._mem_heap for e in self.exprs))
334+
335+
@cached_property
336+
def is_tensor_move(self):
337+
return self._is_type(TensorMove)
338+
339+
@cached_property
340+
def is_word_move_to_mem_shared(self):
341+
return self.is_word_move and self.is_shm_write
342+
343+
@cached_property
344+
def is_tensor_move_to_mem_shared(self):
345+
return self.is_tensor_move and self.is_shm_write
346+
347+
@cached_property
348+
def is_glb_load_to_mem_shared(self):
349+
return self.is_word_move_to_mem_shared or self.is_tensor_move_to_mem_shared
350+
321351
@cached_property
322352
def is_async(self):
323353
"""
@@ -557,6 +587,10 @@ def dspace(self):
557587
def is_halo_touch(self):
558588
return all(i.is_halo_touch for i in self)
559589

590+
@cached_property
591+
def is_glb_load_to_mem_shared(self):
592+
return all(i.is_glb_load_to_mem_shared for i in self)
593+
560594
@cached_property
561595
def dtype(self):
562596
"""

devito/ir/support/properties.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,6 @@ def __init__(self, name, val=None):
9797
A Dimension along which prefetching is feasible and beneficial.
9898
"""
9999

100-
PREFETCHABLE_SHM = Property('prefetchable-shm')
101-
"""
102-
A Dimension along which shared-memory prefetching is feasible and beneficial.
103-
"""
104-
105100
INIT_CORE_SHM = Property('init-core-shm')
106101
"""
107102
A Dimension along which the shared-memory CORE data region is initialized.
@@ -190,32 +185,6 @@ def update_properties(properties, exprs):
190185
if not exprs:
191186
return properties
192187

193-
# Auto-detect prefetchable Dimensions
194-
dims = set()
195-
flag = False
196-
for e in as_tuple(exprs):
197-
w, r = e.args
198-
199-
# Ensure it's in the form `Indexed = Indexed`
200-
try:
201-
wf, rf = w.function, r.function
202-
except AttributeError:
203-
break
204-
205-
if not rf or not wf._mem_shared:
206-
break
207-
dims.update({d.parent for d in wf.dimensions if d.parent in properties})
208-
209-
if not rf._mem_heap:
210-
break
211-
else:
212-
flag = True
213-
214-
if flag:
215-
properties = properties.prefetchable_shm(dims)
216-
else:
217-
properties = properties.drop(properties=PREFETCHABLE_SHM)
218-
219188
# Remove properties that are trivially incompatible with `exprs`
220189
if not all(e.lhs.function._mem_shared for e in as_tuple(exprs)):
221190
drop = {INIT_CORE_SHM, INIT_HALO_LEFT_SHM, INIT_HALO_RIGHT_SHM}
@@ -284,9 +253,6 @@ def prefetchable(self, dims, v=PREFETCHABLE):
284253
m[d] = self.get(d, set()) | {v}
285254
return Properties(m)
286255

287-
def prefetchable_shm(self, dims):
288-
return self.prefetchable(dims, PREFETCHABLE_SHM)
289-
290256
def block(self, dims, kind='default'):
291257
if kind == 'default':
292258
p = TILABLE
@@ -357,9 +323,6 @@ def _is_property_any(self, dims, v):
357323
def is_prefetchable(self, dims=None, v=PREFETCHABLE):
358324
return self._is_property_any(dims, PREFETCHABLE)
359325

360-
def is_prefetchable_shm(self, dims=None):
361-
return self._is_property_any(dims, PREFETCHABLE_SHM)
362-
363326
def is_core_init(self, dims=None):
364327
return self._is_property_any(dims, INIT_CORE_SHM)
365328

devito/passes/clusters/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def _key(self, c):
232232
weak.append(c.properties.is_core_init())
233233

234234
# Prefetchable Clusters (global loads into shared memory) should get merged, if possible
235-
weak.append(c.properties.is_prefetchable_shm())
235+
weak.append(c.is_glb_load_to_mem_shared)
236236

237237
# Promoting adjacency of IndexDerivatives will maximize their reuse
238238
weak.append(any(search(c.exprs, IndexDerivative)))

devito/types/parallel.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
from functools import cached_property
1212

1313
import numpy as np
14+
from sympy import Expr
1415

1516
from devito.exceptions import InvalidArgument
1617
from devito.parameters import configuration
17-
from devito.symbolics import search
18+
from devito.symbolics import Reserved, Terminal, search
1819
from devito.tools import as_list, as_tuple, is_integer
1920
from devito.types.array import Array, ArrayObject
2021
from devito.types.basic import Scalar, Symbol
@@ -35,7 +36,9 @@
3536
'QueueID',
3637
'SharedData',
3738
'TBArray',
39+
'TensorMove',
3840
'ThreadArray',
41+
'ThreadArrive',
3942
'ThreadCommit',
4043
'ThreadID',
4144
'ThreadPoolSync',
@@ -365,12 +368,24 @@ class ThreadCommit(Fence):
365368
pass
366369

367370

371+
class ThreadArrive(Fence):
372+
373+
"""
374+
A generic arrive operation for a single thread, typically used to signal
375+
the arrival at a certain point through a suitable synchronization object.
376+
"""
377+
378+
pass
379+
380+
368381
class ThreadWait(Fence):
369382

370383
"""
371384
A generic wait operation for a single thread, typically used to synchronize
372-
after a memory operation issued at a specific program point with a
373-
ThreadCommit operation.
385+
with other threads over:
386+
387+
* a memory operation issued by a prior ThreadCommit operation.
388+
* the consumption of a shared resource via a ThreadArrive operation.
374389
"""
375390

376391
pass
@@ -386,3 +401,18 @@ def __init_finalize__(self, *args, **kwargs):
386401
kwargs['liveness'] = 'eager'
387402

388403
super().__init_finalize__(*args, **kwargs)
404+
405+
406+
class TensorMove(Expr, Reserved, Terminal):
407+
408+
"""
409+
Represent the LOAD/STORE of a multi-dimensional block of data from/to a higher
410+
level of the memory hierarchy.
411+
"""
412+
413+
func = Reserved._rebuild
414+
415+
def _ccode(self, printer):
416+
return str(self)
417+
418+
_sympystr = _ccode

0 commit comments

Comments
 (0)