Skip to content

Commit 1e99432

Browse files
authored
Merge branch 'devitocodes:main' into master
2 parents 42e0e70 + 17ca3b9 commit 1e99432

10 files changed

Lines changed: 115 additions & 38 deletions

File tree

devito/operator/profiling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def add_glb_vanilla(self, key, time):
453453
if not self.input:
454454
return
455455

456-
ops = sum(v.ops for v in self.input.values())
456+
ops = sum(v.ops for v in self.input.values() if not np.isnan(v.ops))
457457
traffic = sum(v.traffic for v in self.input.values())
458458

459459
gflops = float(ops)/10**9
Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
import sympy
2-
31
from devito.ir import cluster_pass
4-
from devito.symbolics import reuse_if_untouched, q_leaf
5-
from devito.symbolics.unevaluation import Add, Mul, Pow
2+
from devito.symbolics import unevaluate as _unevaluate
63

74
__all__ = ['unevaluate']
85

@@ -12,22 +9,3 @@ def unevaluate(cluster):
129
exprs = [_unevaluate(e) for e in cluster.exprs]
1310

1411
return cluster.rebuild(exprs=exprs)
15-
16-
17-
mapper = {
18-
sympy.Add: Add,
19-
sympy.Mul: Mul,
20-
sympy.Pow: Pow
21-
}
22-
23-
24-
def _unevaluate(expr):
25-
if q_leaf(expr):
26-
return expr
27-
28-
args = [_unevaluate(a) for a in expr.args]
29-
30-
try:
31-
return mapper[expr.func](*args)
32-
except KeyError:
33-
return reuse_if_untouched(expr, args)

devito/passes/iet/definitions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from devito.passes.iet.engine import iet_pass
1717
from devito.passes.iet.langbase import LangBB
1818
from devito.symbolics import (Byref, DefFunction, FieldFromPointer, IndexedPointer,
19-
SizeOf, VOID, pow_to_mul)
19+
SizeOf, VOID, pow_to_mul, unevaluate)
2020
from devito.tools import as_mapper, as_list, as_tuple, filter_sorted, flatten
2121
from devito.types import (Array, ComponentAccess, CustomDimension, DeviceMap,
2222
DeviceRM, Eq, Symbol)
@@ -119,7 +119,7 @@ def _alloc_array_on_global_mem(self, site, obj, storage):
119119

120120
# Create input array
121121
name = '%s_init' % obj.name
122-
initvalue = np.array([pow_to_mul(i) for i in obj.initvalue])
122+
initvalue = np.array([unevaluate(pow_to_mul(i)) for i in obj.initvalue])
123123
src = Array(name=name, dtype=obj.dtype, dimensions=obj.dimensions,
124124
space='host', scope='stack', initvalue=initvalue)
125125

devito/symbolics/extended_sympy.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Extended SymPy hierarchy.
33
"""
4+
import re
45

56
import numpy as np
67
import sympy
@@ -426,13 +427,29 @@ def __new__(cls, base, dtype=None, stars=None, reinterpret=False, **kwargs):
426427
# E.g. void
427428
pass
428429

430+
dtype, stars = cls._process_dtype(dtype, stars)
431+
429432
obj = super().__new__(cls, base)
430433
obj._stars = stars or ''
431434
obj._dtype = dtype
432435
obj._reinterpret = reinterpret
433436

434437
return obj
435438

439+
@classmethod
440+
def _process_dtype(cls, dtype, stars):
441+
if not isinstance(dtype, str) or stars is not None:
442+
return dtype, stars
443+
444+
# String dtype, e.g. "float", "int*", "foo**"
445+
match = re.fullmatch(r'(\w+)\s*(\*+)?', dtype)
446+
if match:
447+
dtype = match.group(1)
448+
stars = match.group(2) or ''
449+
return dtype, stars
450+
else:
451+
return dtype, stars
452+
436453
def _hashable_content(self):
437454
return super()._hashable_content() + (self._stars,)
438455

@@ -461,7 +478,10 @@ def _C_ctype(self):
461478

462479
@property
463480
def _op(self):
464-
return f'({ctypes_to_cstr(self._C_ctype)})'
481+
cstr = ctypes_to_cstr(self._C_ctype)
482+
if self.stars:
483+
cstr = f"{cstr}{self.stars}"
484+
return f'({cstr})'
465485

466486
def __str__(self):
467487
return f"{self._op}{self.base}"

devito/symbolics/manipulation.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
from devito.symbolics.extended_sympy import DefFunction, rfunc
1414
from devito.symbolics.queries import q_leaf
1515
from devito.symbolics.search import retrieve_indexed, retrieve_functions
16-
from devito.symbolics.unevaluation import Mul as UMul
16+
from devito.symbolics.unevaluation import (
17+
Add as UnevalAdd, Mul as UnevalMul, Pow as UnevalPow
18+
)
1719
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
1820
from devito.types.basic import Basic, Indexed
1921
from devito.types.array import ComponentAccess
@@ -22,7 +24,7 @@
2224

2325
__all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args',
2426
'normalize_args', 'uxreplace', 'Uxmapper', 'subs_if_composite',
25-
'reuse_if_untouched', 'evalrel', 'flatten_args']
27+
'reuse_if_untouched', 'evalrel', 'flatten_args', 'unevaluate']
2628

2729

2830
def uxreplace(expr, rule):
@@ -338,7 +340,7 @@ def pow_to_mul(expr):
338340
# but at least we traverse the base looking for other Pows
339341
return expr.func(pow_to_mul(base), exp, evaluate=False)
340342
elif exp > 0:
341-
return UMul(*[pow_to_mul(base)]*int(exp), evaluate=False)
343+
return UnevalMul(*[pow_to_mul(base)]*int(exp), evaluate=False)
342344
elif exp < 0:
343345
# Reciprocal powers become inverse of the negative power
344346
# for example Pow(expr, -2) becomes Pow(expr * expr, -1)
@@ -502,3 +504,18 @@ def evalrel(func=min, input=None, assumptions=None):
502504
except TypeError:
503505
pass
504506
return rfunc(func, *input)
507+
508+
509+
uneval_mapper = {Add: UnevalAdd, Mul: UnevalMul, Pow: UnevalPow}
510+
511+
512+
def unevaluate(expr):
513+
if q_leaf(expr):
514+
return expr
515+
516+
args = [unevaluate(a) for a in expr.args]
517+
518+
try:
519+
return uneval_mapper[expr.func](*args)
520+
except KeyError:
521+
return reuse_if_untouched(expr, args)

devito/tools/dtypes_lowering.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,10 @@ class c_restrict_void_p(ctypes.c_void_p):
256256

257257
def ctypes_to_cstr(ctype, toarray=None):
258258
"""Translate ctypes types into C strings."""
259-
if ctype in ctypes_vector_mapper.values():
259+
if isinstance(ctype, str):
260+
# Already a C string
261+
return ctype
262+
elif ctype in ctypes_vector_mapper.values():
260263
retval = ctype.__name__
261264
elif isinstance(ctype, CustomDtype):
262265
retval = str(ctype)

devito/types/array.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
from sympy import Expr, cacheit
66

7-
from devito.tools import (Reconstructable, as_tuple, c_restrict_void_p,
7+
from devito.tools import (Pickable, as_tuple, c_restrict_void_p,
88
dtype_to_ctype, dtypes_vector_mapper, is_integer)
99
from devito.types.basic import AbstractFunction, LocalType
1010
from devito.types.utils import CtypesFactory, DimensionTuple
@@ -533,10 +533,11 @@ def handles(self):
533533
return self.components
534534

535535

536-
class ComponentAccess(Expr, Reconstructable):
536+
class ComponentAccess(Expr, Pickable):
537537

538538
_component_names = ('x', 'y', 'z', 'w')
539539

540+
__rargs__ = ('arg',)
540541
__rkwargs__ = ('index',)
541542

542543
def __new__(cls, arg, index=0, **kwargs):
@@ -558,7 +559,7 @@ def __str__(self):
558559

559560
__repr__ = __str__
560561

561-
func = Reconstructable._rebuild
562+
func = Pickable._rebuild
562563

563564
def _sympystr(self, printer):
564565
return str(self)
@@ -567,6 +568,10 @@ def _sympystr(self, printer):
567568
def base(self):
568569
return self.args[0]
569570

571+
@property
572+
def arg(self):
573+
return self.base
574+
570575
@property
571576
def index(self):
572577
return self._index

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ py-cpuinfo<10
66
cgen>=2020.1
77
codepy>=2019.1
88
click<9.0
9-
multidict
9+
multidict<6.3
1010
anytree>=2.4.3,<=2.12.1
1111
cloudpickle
1212
packaging

tests/test_pickle.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ctypes
22
import pickle as pickle0
3-
import cloudpickle as pickle1
43

4+
import cloudpickle as pickle1
55
import pytest
66
import numpy as np
77
from sympy import Symbol
@@ -19,7 +19,7 @@
1919
from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar,
2020
PointerArray, Lock, PThreadArray, SharedData, Timer,
2121
DeviceID, NPThreads, ThreadID, TempFunction, Indirection,
22-
FIndexed)
22+
FIndexed, ComponentAccess)
2323
from devito.types.basic import BoundSymbol, AbstractSymbol
2424
from devito.tools import EnrichedTuple
2525
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
@@ -416,6 +416,20 @@ def test_findexed(self, pickle):
416416
assert new_fi.indices == (x+1, y, z-2)
417417
assert new_fi.strides_map == fi.strides_map
418418

419+
def test_component_access(self, pickle):
420+
grid = Grid(shape=(3, 3, 3))
421+
x, y, z = grid.dimensions
422+
423+
f = Function(name='f', grid=grid)
424+
425+
ca = ComponentAccess(f.indexify(), 1)
426+
427+
pkl_ca = pickle.dumps(ca)
428+
new_ca = pickle.loads(pkl_ca)
429+
430+
assert new_ca.index == 1
431+
assert new_ca.function.name == f.name
432+
419433
def test_symbolics(self, pickle):
420434
a = Symbol('a')
421435

tests/test_symbolics.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def test_rvalue():
432432
assert str(Rvalue(ctype, ns, init)) == 'my::namespace::dummytype{}'
433433

434434

435-
def test_cast():
435+
def test_basecast():
436436
s = Symbol(name='s', dtype=np.float32)
437437

438438
class BarCast(BaseCast):
@@ -448,6 +448,28 @@ class BarCast(BaseCast):
448448
assert v != v1
449449

450450

451+
def test_str_cast():
452+
s = Symbol(name='s', dtype=np.float32)
453+
454+
v = Cast(s, 'foo')
455+
assert not v.stars
456+
assert v.dtype == 'foo'
457+
assert v._op == '(foo)'
458+
assert ccode(v) == '(foo)s'
459+
460+
v = Cast(s, 'foo*')
461+
assert v.stars == '*'
462+
assert v.dtype == 'foo'
463+
assert v._op == '(foo*)'
464+
assert ccode(v) == '(foo*)s'
465+
466+
v = Cast(s, 'foo **')
467+
assert v.stars == '**'
468+
assert v.dtype == 'foo'
469+
assert v._op == '(foo**)'
470+
assert ccode(v) == '(foo**)s'
471+
472+
451473
def test_findexed():
452474
grid = Grid(shape=(3, 3, 3))
453475
x, y, z = grid.dimensions
@@ -474,6 +496,24 @@ def test_findexed():
474496
assert new_fi.strides_map == strides_map
475497

476498

499+
def test_component_access():
500+
grid = Grid(shape=(3, 3, 3))
501+
x, y, z = grid.dimensions
502+
503+
f = Function(name='f', grid=grid)
504+
505+
cf0 = ComponentAccess(f.indexify(), 0)
506+
cf1 = ComponentAccess(f.indexify(), 1)
507+
508+
assert ccode(cf0) == 'f[x][y][z].x'
509+
assert ccode(cf1) == 'f[x][y][z].y'
510+
511+
# Reconstruction
512+
cf2 = cf1.func(*cf1.args)
513+
assert cf2.index == cf1.index
514+
assert cf2 == cf1
515+
516+
477517
def test_canonical_ordering_of_weights():
478518
grid = Grid(shape=(3, 3, 3))
479519
x, y, z = grid.dimensions

0 commit comments

Comments
 (0)