Skip to content

Commit abc1879

Browse files
mlouboutFabioLuporini
authored andcommitted
compiler: prevent undefined Temp through global init
1 parent 28ded5a commit abc1879

6 files changed

Lines changed: 152 additions & 31 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
CustomDimension, Eq, Hyperplane, IncrDimension, Indexed, ModuloDimension, Size,
2626
StencilDimension, Symbol, Temp, TempArray, TempFunction
2727
)
28+
from devito.types.dimension import ConditionalDimension, SubsamplingFactor
2829
from devito.types.grid import MultiSubDimension
2930

3031
__all__ = ['cire']
@@ -923,15 +924,25 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
923924
callback = lambda idx: obj[ # noqa: B023
924925
[i + s for i, s in zip(idx, shift, strict=True)] # noqa: B023
925926
]
927+
guards = meta.guards
926928
else:
927929
# Degenerate case: scalar expression
928930
assert writeto.size == 0
929931

930-
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
931-
obj = Temp(name=name, dtype=dtype)
932-
expression = Eq(obj, uxreplace(pivot, subs))
932+
guards = None
933+
is_cond = any(isinstance(d, (SubsamplingFactor, ConditionalDimension))
934+
for d in pivot.free_symbols)
935+
if meta.guards and is_cond:
936+
# Scalar alias that depends on a guard, unsafe to lift out of the guard
937+
# Do not alias
938+
expression = None
939+
callback = lambda idx: uxreplace(pivot, subs) # noqa: B023
940+
else:
941+
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
942+
obj = Temp(name=name, dtype=dtype)
943+
expression = Eq(obj, uxreplace(pivot, subs))
933944

934-
callback = lambda idx: obj # noqa: B023
945+
callback = lambda idx: obj # noqa: B023
935946

936947
# Create the substitution rules for the aliasing expressions
937948
subs.update({
@@ -957,7 +968,8 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
957968
properties[Hyperplane(writeto.itdims)] = {SEPARABLE}
958969

959970
# Finally, build the alias Cluster
960-
clusters.append(Cluster(expression, ispace, meta.guards, properties))
971+
if expression is not None:
972+
clusters.append(Cluster(expression, ispace, guards, properties))
961973

962974
return clusters, subs
963975

devito/passes/clusters/cse.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ def cse_dtype(exprdtype, cdtype):
4040
Return the dtype of a CSE temporary given the dtype of the expression to be
4141
captured and the cluster's dtype.
4242
"""
43+
if np.issubdtype(cdtype, np.floating) and np.issubdtype(exprdtype, np.integer):
44+
# Integer expression and floating-point cluster: promote to the floating point
45+
# np.promote_types upcasts integers (e.g. int32 -> float64), so we
46+
# need to ensure that the promoted type is not larger than the cluster's dtype
47+
return cdtype
48+
4349
if np.issubdtype(cdtype, np.complexfloating):
4450
return np.promote_types(exprdtype, cdtype(0).real.__class__).type
4551
else:
@@ -97,8 +103,9 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
97103
if cluster.is_fence:
98104
return cluster
99105

100-
make_dtype = lambda e: cse_dtype(e.dtype, dtype)
101-
make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e))
106+
def make(e):
107+
edtype = cse_dtype(e.dtype, dtype)
108+
return CTemp(name=sregistry.make_name(), dtype=edtype)
102109

103110
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)
104111

devito/types/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def _rebuild(self, *args, **kwargs):
15891589
comps = [f.func(*args, name=f.name.replace(self.name, newname), **kwargs)
15901590
for f in self.flat()]
15911591
# Rebuild the matrix with the new components
1592-
return self._new(comps)
1592+
return self._new(*self.shape, comps)
15931593

15941594
func = _rebuild
15951595

examples/mpi/overview.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@
460460
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
461461
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
462462
"\n",
463-
" float r0 = 1.0F/h_x;\n",
463+
" const float r0 = 1.0F/h_x;\n",
464464
"\n",
465465
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
466466
" {\n",

examples/performance/00_overview.ipynb

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@
508508
"}\n",
509509
"STOP(section0,timers)\n",
510510
"\n",
511-
"float r1 = 1.0F/h_y;\n",
511+
"const float r1 = 1.0F/h_y;\n",
512512
"\n",
513513
"for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
514514
"{\n",
@@ -572,7 +572,13 @@
572572
"+ u[t1][x + 4][y + 4][z + 4] = (f[x + 1][y + 1][z + 1]*f[x + 1][y + 1][z + 1])*((-6.66666667e-1F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y + 1][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 2][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 4][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 5][z + 4]) + (-8.33333333e-2F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y + 4][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 5][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 7][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 8][z + 4]) + (8.33333333e-2F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 1][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 3][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 4][z + 4]) + (6.66666667e-1F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y + 3][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 4][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 6][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 7][z + 4]))*sinf(f[x + 1][y + 1][z + 1]);\n",
573573
" }\n",
574574
" }\n",
575-
" }\n",
575+
" }\n"
576+
]
577+
},
578+
{
579+
"name": "stdout",
580+
"output_type": "stream",
581+
"text": [
576582
"\n"
577583
]
578584
}
@@ -652,7 +658,7 @@
652658
},
653659
{
654660
"cell_type": "code",
655-
"execution_count": 13,
661+
"execution_count": 12,
656662
"metadata": {},
657663
"outputs": [
658664
{
@@ -712,7 +718,7 @@
712718
},
713719
{
714720
"cell_type": "code",
715-
"execution_count": 14,
721+
"execution_count": 13,
716722
"metadata": {},
717723
"outputs": [
718724
{
@@ -753,7 +759,14 @@
753759
" }\n",
754760
" }\n",
755761
" STOP(section0,timers)\n",
756-
"}\n"
762+
"}"
763+
]
764+
},
765+
{
766+
"name": "stdout",
767+
"output_type": "stream",
768+
"text": [
769+
"\n"
757770
]
758771
}
759772
],
@@ -772,7 +785,7 @@
772785
},
773786
{
774787
"cell_type": "code",
775-
"execution_count": 15,
788+
"execution_count": 14,
776789
"metadata": {},
777790
"outputs": [
778791
{
@@ -863,7 +876,7 @@
863876
},
864877
{
865878
"cell_type": "code",
866-
"execution_count": 16,
879+
"execution_count": 15,
867880
"metadata": {},
868881
"outputs": [
869882
{
@@ -919,7 +932,7 @@
919932
},
920933
{
921934
"cell_type": "code",
922-
"execution_count": 17,
935+
"execution_count": 16,
923936
"metadata": {},
924937
"outputs": [
925938
{
@@ -976,7 +989,7 @@
976989
},
977990
{
978991
"cell_type": "code",
979-
"execution_count": 18,
992+
"execution_count": 17,
980993
"metadata": {},
981994
"outputs": [],
982995
"source": [
@@ -994,7 +1007,7 @@
9941007
},
9951008
{
9961009
"cell_type": "code",
997-
"execution_count": 19,
1010+
"execution_count": 18,
9981011
"metadata": {},
9991012
"outputs": [
10001013
{
@@ -1044,7 +1057,7 @@
10441057
},
10451058
{
10461059
"cell_type": "code",
1047-
"execution_count": 20,
1060+
"execution_count": 19,
10481061
"metadata": {},
10491062
"outputs": [
10501063
{
@@ -1112,7 +1125,7 @@
11121125
},
11131126
{
11141127
"cell_type": "code",
1115-
"execution_count": 21,
1128+
"execution_count": 20,
11161129
"metadata": {},
11171130
"outputs": [
11181131
{
@@ -1192,7 +1205,7 @@
11921205
" }\n",
11931206
" STOP(section0,timers)\n",
11941207
"\n",
1195-
" float r1 = 1.0F/h_y;\n",
1208+
" const float r1 = 1.0F/h_y;\n",
11961209
"\n",
11971210
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
11981211
" {\n",
@@ -1279,7 +1292,7 @@
12791292
},
12801293
{
12811294
"cell_type": "code",
1282-
"execution_count": 22,
1295+
"execution_count": 21,
12831296
"metadata": {},
12841297
"outputs": [
12851298
{
@@ -1304,7 +1317,7 @@
13041317
"}\n",
13051318
"STOP(section0,timers)\n",
13061319
"\n",
1307-
"float r1 = 1.0F/h_y;\n",
1320+
"const float r1 = 1.0F/h_y;\n",
13081321
"\n",
13091322
"for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
13101323
"{\n",
@@ -1369,7 +1382,7 @@
13691382
},
13701383
{
13711384
"cell_type": "code",
1372-
"execution_count": 23,
1385+
"execution_count": 22,
13731386
"metadata": {},
13741387
"outputs": [
13751388
{
@@ -1404,7 +1417,7 @@
14041417
},
14051418
{
14061419
"cell_type": "code",
1407-
"execution_count": 24,
1420+
"execution_count": 23,
14081421
"metadata": {},
14091422
"outputs": [
14101423
{
@@ -1483,7 +1496,7 @@
14831496
" }\n",
14841497
" STOP(section0,timers)\n",
14851498
"\n",
1486-
" float r1 = 1.0F/h_y;\n",
1499+
" const float r1 = 1.0F/h_y;\n",
14871500
"\n",
14881501
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
14891502
" {\n",
@@ -1546,7 +1559,7 @@
15461559
},
15471560
{
15481561
"cell_type": "code",
1549-
"execution_count": 25,
1562+
"execution_count": 24,
15501563
"metadata": {},
15511564
"outputs": [
15521565
{
@@ -1624,8 +1637,8 @@
16241637
" }\n",
16251638
" STOP(section0,timers)\n",
16261639
"\n",
1627-
" float r1 = 1.0F/h_x;\n",
1628-
" float r2 = 1.0F/h_y;\n",
1640+
" const float r1 = 1.0F/h_x;\n",
1641+
" const float r2 = 1.0F/h_y;\n",
16291642
"\n",
16301643
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
16311644
" {\n",
@@ -1718,7 +1731,7 @@
17181731
"name": "python",
17191732
"nbconvert_exporter": "python",
17201733
"pygments_lexer": "ipython3",
1721-
"version": "3.10.12"
1734+
"version": "3.13.11"
17221735
}
17231736
},
17241737
"nbformat": 4,

tests/test_dse.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from devito.tools import as_tuple
3030
from devito.types import PrecomputedSparseTimeFunction, Scalar, Symbol
31+
from devito.types.misc import Temp
3132
from examples.seismic import AcquisitionGeometry, demo_model
3233
from examples.seismic.acoustic import AcousticWaveSolver
3334
from examples.seismic.tti import AnisotropicWaveSolver
@@ -2685,6 +2686,94 @@ def test_space_and_time_invariant_together(self):
26852686
'tx0_blk0y0_blk0xyzyz'
26862687
)
26872688

2689+
def test_split_cond(self):
2690+
grid = Grid((11, 11))
2691+
time = grid.time_dim
2692+
2693+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2694+
2695+
ct = ConditionalDimension(name='ct', parent=time, factor=2)
2696+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2697+
2698+
eq0 = Eq(u.forward, u + cos(time), implicit_dims=ct)
2699+
eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2)
2700+
eq2 = Eq(u.forward, u.forward + cos(time), implicit_dims=ct)
2701+
2702+
op = Operator([eq0, eq1, eq2])
2703+
cond = FindNodes(Conditional).visit(op)
2704+
assert len(cond) == 3
2705+
# The alias should have been lifted out of the condition
2706+
assert 'float r0 = cos(time);' in str(op.body.body[0])
2707+
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2708+
assert len(scalars) == 1
2709+
2710+
def test_split_cond_multi_alias(self):
2711+
grid = Grid((11, 11))
2712+
time = grid.time_dim
2713+
2714+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2715+
2716+
ct = ConditionalDimension(name='ct', parent=time, factor=2)
2717+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2718+
2719+
eq0 = Eq(u.forward, u + cos(time) + sin(time), implicit_dims=ct)
2720+
eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2)
2721+
eq2 = Eq(u.forward, u.forward + cos(time) - sin(time), implicit_dims=ct)
2722+
2723+
op = Operator([eq0, eq1, eq2])
2724+
cond = FindNodes(Conditional).visit(op)
2725+
assert len(cond) == 3
2726+
# The alias should have been lifted out of the condition
2727+
assert 'float r3 = cos(time);' in str(op.body.body[0])
2728+
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2729+
assert len(scalars) == 5
2730+
2731+
def test_multi_cond_no_split(self):
2732+
grid = Grid((11, 11))
2733+
time = grid.time_dim
2734+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2735+
2736+
ct = ConditionalDimension(name='ct', parent=time, factor=2)
2737+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2738+
2739+
eq0 = Eq(u.forward, u + cos(time), implicit_dims=ct)
2740+
# Not hoisting the init would have this equation split the time
2741+
# loop to initialize the alias for sin(time)
2742+
eq1 = Eq(u.forward, u.forward + sin(time), implicit_dims=ct2)
2743+
eq2 = Eq(u.forward, u.forward - sin(time), implicit_dims=ct)
2744+
2745+
op = Operator([eq0, eq1, eq2])
2746+
2747+
assert_structure(
2748+
op,
2749+
['t', 't,x,y', 't,x,y', 't,x,y'],
2750+
'txyxyxy'
2751+
)
2752+
2753+
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2754+
assert len(scalars) == 4
2755+
2756+
def test_alias_with_conditional(self):
2757+
grid = Grid((11, 11))
2758+
time = grid.time_dim
2759+
2760+
u = TimeFunction(name='u', grid=grid, time_order=2, space_order=2)
2761+
2762+
ct = ConditionalDimension(name='ct', parent=time, factor=2)
2763+
ct2 = ConditionalDimension(name='ct2', parent=time, factor=4)
2764+
2765+
eq0 = Eq(u.forward, u + cos(ct), implicit_dims=ct)
2766+
eq1 = Eq(u.forward, u.forward + 1, implicit_dims=ct2)
2767+
eq2 = Eq(u.forward, u.forward + cos(ct), implicit_dims=ct)
2768+
2769+
op = Operator([eq0, eq1, eq2])
2770+
cond = FindNodes(Conditional).visit(op)
2771+
assert len(cond) == 3
2772+
2773+
# The alias depends on the ConditionalDimension, so it must not be lifted: no Temp expected
2774+
scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)]
2775+
assert not scalars
2776+
26882777

26892778
class TestIsoAcoustic:
26902779

0 commit comments

Comments
 (0)