Skip to content

Commit ad99102

Browse files
authored
Merge pull request #2566 from devitocodes/cast-str
compiler: Allow str cast type
2 parents 34cdc53 + 2e4f4b0 commit ad99102

6 files changed

Lines changed: 53 additions & 7 deletions

File tree

devito/operator/profiling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def add_glb_vanilla(self, key, time):
453453
if not self.input:
454454
return
455455

456-
ops = sum(v.ops for v in self.input.values())
456+
ops = sum(v.ops for v in self.input.values() if not np.isnan(v.ops))
457457
traffic = sum(v.traffic for v in self.input.values())
458458

459459
gflops = float(ops)/10**9

devito/symbolics/extended_sympy.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Extended SymPy hierarchy.
33
"""
4+
import re
45

56
import numpy as np
67
import sympy
@@ -394,13 +395,29 @@ def __new__(cls, base, dtype=None, stars=None, reinterpret=False, **kwargs):
394395
# E.g. void
395396
pass
396397

398+
dtype, stars = cls._process_dtype(dtype, stars)
399+
397400
obj = super().__new__(cls, base)
398401
obj._stars = stars or ''
399402
obj._dtype = dtype
400403
obj._reinterpret = reinterpret
401404

402405
return obj
403406

407+
@classmethod
408+
def _process_dtype(cls, dtype, stars):
409+
if not isinstance(dtype, str) or stars is not None:
410+
return dtype, stars
411+
412+
# String dtype, e.g. "float", "int*", "foo**"
413+
match = re.fullmatch(r'(\w+)\s*(\*+)?', dtype)
414+
if match:
415+
dtype = match.group(1)
416+
stars = match.group(2) or ''
417+
return dtype, stars
418+
else:
419+
return dtype, stars
420+
404421
def _hashable_content(self):
405422
return super()._hashable_content() + (self._stars,)
406423

@@ -429,7 +446,10 @@ def _C_ctype(self):
429446

430447
@property
431448
def _op(self):
432-
return f'({ctypes_to_cstr(self._C_ctype)})'
449+
cstr = ctypes_to_cstr(self._C_ctype)
450+
if self.stars:
451+
cstr = f"{cstr}{self.stars}"
452+
return f'({cstr})'
433453

434454
def __str__(self):
435455
return f"{self._op}{self.base}"

devito/tools/dtypes_lowering.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,10 @@ class c_restrict_void_p(ctypes.c_void_p):
250250

251251
def ctypes_to_cstr(ctype, toarray=None):
252252
"""Translate ctypes types into C strings."""
253-
if ctype in ctypes_vector_mapper.values():
253+
if isinstance(ctype, str):
254+
# Already a C string
255+
return ctype
256+
elif ctype in ctypes_vector_mapper.values():
254257
retval = ctype.__name__
255258
elif isinstance(ctype, CustomDtype):
256259
retval = str(ctype)

examples/seismic/tutorials/13_LSRTM_acoustic.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@
304304
},
305305
{
306306
"cell_type": "code",
307-
"execution_count": 6,
307+
"execution_count": null,
308308
"metadata": {},
309309
"outputs": [],
310310
"source": [
@@ -329,12 +329,13 @@
329329
" dm_true = (solver.model.vp.data**(-2) - model0.vp.data**(-2))\n",
330330
" \n",
331331
" objective = 0.\n",
332+
" u0 = None\n",
332333
" for i in range(nshots):\n",
333334
" \n",
334335
" #Observed Data using Born's operator\n",
335336
" geometry.src_positions[0, :] = source_locations[i, :]\n",
336337
"\n",
337-
" _, u0, _ = solver.forward(vp=model0.vp, save=True)\n",
338+
" _, u0, _ = solver.forward(vp=model0.vp, save=True, u=u0)\n",
338339
" \n",
339340
" _, _, _,_ = solver.jacobian(dm_true, vp=model0.vp, rec = d_obs)\n",
340341
" \n",

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ py-cpuinfo<10
66
cgen>=2020.1
77
codepy>=2019.1
88
click<9.0
9-
multidict
9+
multidict<6.3
1010
anytree>=2.4.3,<=2.12.1
1111
cloudpickle
1212
packaging

tests/test_symbolics.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def test_rvalue():
419419
assert str(Rvalue(ctype, ns, init)) == 'my::namespace::dummytype{}'
420420

421421

422-
def test_cast():
422+
def test_basecast():
423423
s = Symbol(name='s', dtype=np.float32)
424424

425425
class BarCast(BaseCast):
@@ -435,6 +435,28 @@ class BarCast(BaseCast):
435435
assert v != v1
436436

437437

438+
def test_str_cast():
439+
s = Symbol(name='s', dtype=np.float32)
440+
441+
v = Cast(s, 'foo')
442+
assert not v.stars
443+
assert v.dtype == 'foo'
444+
assert v._op == '(foo)'
445+
assert ccode(v) == '(foo)s'
446+
447+
v = Cast(s, 'foo*')
448+
assert v.stars == '*'
449+
assert v.dtype == 'foo'
450+
assert v._op == '(foo*)'
451+
assert ccode(v) == '(foo*)s'
452+
453+
v = Cast(s, 'foo **')
454+
assert v.stars == '**'
455+
assert v.dtype == 'foo'
456+
assert v._op == '(foo**)'
457+
assert ccode(v) == '(foo**)s'
458+
459+
438460
def test_findexed():
439461
grid = Grid(shape=(3, 3, 3))
440462
x, y, z = grid.dimensions

0 commit comments

Comments
 (0)