Skip to content

Commit 04373ef

Browse files
authored
Merge pull request #2858 from devitocodes/alias-multi-cond
compiler: Improve scalar aliases detection and scheduling
2 parents 28ded5a + eb3b4ea commit 04373ef

12 files changed

Lines changed: 272 additions & 55 deletions

File tree

conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,3 +479,9 @@ def check_array(array, exp_halo, exp_shape, rotate=False):
479479

480480
assert tuple(array.halo) == exp_halo
481481
assert tuple(shape) == tuple(exp_shape)
482+
483+
484+
# Main body in Operator IET, depending on ISA
485+
def body0(op):
486+
bidx = 0 if 'sse' not in configuration['platform'].known_isas else 1
487+
return op.body.body[bidx]

devito/passes/clusters/aliases.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
126126
for mapper in self._generate(cgroup, exclude):
127127
# Clusters -> AliasList
128128
found = collect(mapper.extracted, meta.ispace, self.opt_minstorage)
129+
if not found:
130+
continue
129131
exprs, aliases = self._choose(found, cgroup, mapper)
130132

131133
# AliasList -> Schedule
@@ -271,14 +273,15 @@ def _do_generate(self, exprs, exclude, cbk_search, cbk_compose=None):
271273
free_symbols = i.free_symbols
272274
if {a.function for a in free_symbols} & exclude:
273275
continue
274-
275276
mapper.add(i, make, terms)
276277

277278
return mapper
278279

279280

280281
class CireInvariants(CireTransformerLegacy, Queue):
281282

283+
_q_guards_in_key = True
284+
282285
def __init__(self, sregistry, options, platform):
283286
super().__init__(sregistry, options, platform)
284287

@@ -928,7 +931,7 @@ def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
928931
assert writeto.size == 0
929932

930933
dtype = sympy_dtype(pivot, base=meta.dtype, smin=opt_min_dtype)
931-
obj = Temp(name=name, dtype=dtype)
934+
obj = Temp(name=name, dtype=dtype, is_const=True)
932935
expression = Eq(obj, uxreplace(pivot, subs))
933936

934937
callback = lambda idx: obj # noqa: B023

devito/passes/clusters/cse.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ def cse_dtype(exprdtype, cdtype):
4040
Return the dtype of a CSE temporary given the dtype of the expression to be
4141
captured and the cluster's dtype.
4242
"""
43+
if np.issubdtype(cdtype, np.floating) and np.issubdtype(exprdtype, np.integer):
44+
# Integer expression and floating-point cluster: promote to the floating point
45+
# np.promote_types upcast integers (e.g int32 -> Float64) so we
46+
# need to ensure that the promoted type is not larger than the cluster's dtype
47+
return cdtype
48+
4349
if np.issubdtype(cdtype, np.complexfloating):
4450
return np.promote_types(exprdtype, cdtype(0).real.__class__).type
4551
else:
@@ -97,8 +103,9 @@ def cse(cluster, sregistry=None, options=None, **kwargs):
97103
if cluster.is_fence:
98104
return cluster
99105

100-
make_dtype = lambda e: cse_dtype(e.dtype, dtype)
101-
make = lambda e: CTemp(name=sregistry.make_name(), dtype=make_dtype(e))
106+
def make(e):
107+
edtype = cse_dtype(e.dtype, dtype)
108+
return CTemp(name=sregistry.make_name(), dtype=edtype)
102109

103110
exprs = _cse(cluster, make, min_cost=min_cost, mode=mode)
104111

devito/passes/clusters/misc.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,12 @@ def callback(self, clusters, prefix):
9797

9898
properties = c.properties.filter(key)
9999

100-
# Lifted scalar clusters cannot be guarded
101-
# as they would not be in the scope of the guarded clusters
102-
# unless the guard is for an outer dimension
103-
guards = {} if c.is_scalar and not (prefix[:-1] and c.guards) else c.guards
104-
105-
lifted.append(c.rebuild(ispace=ispace, properties=properties, guards=guards))
100+
# If `c` is made of scalar expressions within guards, then we must keep
101+
# it close to the adjacent Clusters for correctness
102+
if c.is_scalar and c.guards and ispace:
103+
processed.append(c.rebuild(ispace=ispace, properties=properties))
104+
else:
105+
lifted.append(c.rebuild(ispace=ispace, properties=properties))
106106

107107
return lifted + processed
108108

devito/symbolics/search.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,18 @@ def visit_preorder_first_hit(self, expr: Expression) -> Iterator[Expression]:
9696

9797

9898
def search(exprs: Expression | Iterable[Expression],
99-
query: type | Callable[[Any], bool],
99+
query: type | tuple[type, ...] | Callable[[Any], bool],
100100
mode: Mode = 'unique',
101101
visit: Literal['dfs', 'bfs', 'bfs_first_hit'] = 'dfs',
102102
deep: bool = False) -> List | set[Expression]:
103103
"""Interface to Search."""
104104

105105
assert mode in ('all', 'unique'), "Unknown mode"
106106

107-
Q = (lambda obj: isinstance(obj, query)) if isinstance(query, type) else query
107+
if isinstance(query, (type, tuple)):
108+
Q = lambda obj: isinstance(obj, query)
109+
else:
110+
Q = query
108111

109112
# Search doesn't actually use a BFS (rather, a preorder DFS), but the terminology
110113
# is retained in this function's parameters for backwards compatibility

devito/types/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def _rebuild(self, *args, **kwargs):
15891589
comps = [f.func(*args, name=f.name.replace(self.name, newname), **kwargs)
15901590
for f in self.flat()]
15911591
# Rebuild the matrix with the new components
1592-
return self._new(comps)
1592+
return self._new(*self.shape, comps)
15931593

15941594
func = _rebuild
15951595

examples/mpi/overview.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@
460460
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
461461
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
462462
"\n",
463-
" float r0 = 1.0F/h_x;\n",
463+
" const float r0 = 1.0F/h_x;\n",
464464
"\n",
465465
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
466466
" {\n",

examples/performance/00_overview.ipynb

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@
508508
"}\n",
509509
"STOP(section0,timers)\n",
510510
"\n",
511-
"float r1 = 1.0F/h_y;\n",
511+
"const float r1 = 1.0F/h_y;\n",
512512
"\n",
513513
"for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
514514
"{\n",
@@ -572,7 +572,13 @@
572572
"+ u[t1][x + 4][y + 4][z + 4] = (f[x + 1][y + 1][z + 1]*f[x + 1][y + 1][z + 1])*((-6.66666667e-1F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y + 1][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 2][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 4][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 5][z + 4]) + (-8.33333333e-2F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y + 4][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 5][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 7][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 8][z + 4]) + (8.33333333e-2F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 1][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 3][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 4][z + 4]) + (6.66666667e-1F*r0)*(8.33333333e-2F*r0*u[t0][x + 4][y + 3][z + 4] - 6.66666667e-1F*r0*u[t0][x + 4][y + 4][z + 4] + 6.66666667e-1F*r0*u[t0][x + 4][y + 6][z + 4] - 8.33333333e-2F*r0*u[t0][x + 4][y + 7][z + 4]))*sinf(f[x + 1][y + 1][z + 1]);\n",
573573
" }\n",
574574
" }\n",
575-
" }\n",
575+
" }\n"
576+
]
577+
},
578+
{
579+
"name": "stdout",
580+
"output_type": "stream",
581+
"text": [
576582
"\n"
577583
]
578584
}
@@ -652,7 +658,7 @@
652658
},
653659
{
654660
"cell_type": "code",
655-
"execution_count": 13,
661+
"execution_count": 12,
656662
"metadata": {},
657663
"outputs": [
658664
{
@@ -712,7 +718,7 @@
712718
},
713719
{
714720
"cell_type": "code",
715-
"execution_count": 14,
721+
"execution_count": 13,
716722
"metadata": {},
717723
"outputs": [
718724
{
@@ -753,7 +759,14 @@
753759
" }\n",
754760
" }\n",
755761
" STOP(section0,timers)\n",
756-
"}\n"
762+
"}"
763+
]
764+
},
765+
{
766+
"name": "stdout",
767+
"output_type": "stream",
768+
"text": [
769+
"\n"
757770
]
758771
}
759772
],
@@ -772,7 +785,7 @@
772785
},
773786
{
774787
"cell_type": "code",
775-
"execution_count": 15,
788+
"execution_count": 14,
776789
"metadata": {},
777790
"outputs": [
778791
{
@@ -863,7 +876,7 @@
863876
},
864877
{
865878
"cell_type": "code",
866-
"execution_count": 16,
879+
"execution_count": 15,
867880
"metadata": {},
868881
"outputs": [
869882
{
@@ -919,7 +932,7 @@
919932
},
920933
{
921934
"cell_type": "code",
922-
"execution_count": 17,
935+
"execution_count": 16,
923936
"metadata": {},
924937
"outputs": [
925938
{
@@ -976,7 +989,7 @@
976989
},
977990
{
978991
"cell_type": "code",
979-
"execution_count": 18,
992+
"execution_count": 17,
980993
"metadata": {},
981994
"outputs": [],
982995
"source": [
@@ -994,7 +1007,7 @@
9941007
},
9951008
{
9961009
"cell_type": "code",
997-
"execution_count": 19,
1010+
"execution_count": 18,
9981011
"metadata": {},
9991012
"outputs": [
10001013
{
@@ -1044,7 +1057,7 @@
10441057
},
10451058
{
10461059
"cell_type": "code",
1047-
"execution_count": 20,
1060+
"execution_count": 19,
10481061
"metadata": {},
10491062
"outputs": [
10501063
{
@@ -1112,7 +1125,7 @@
11121125
},
11131126
{
11141127
"cell_type": "code",
1115-
"execution_count": 21,
1128+
"execution_count": 20,
11161129
"metadata": {},
11171130
"outputs": [
11181131
{
@@ -1192,7 +1205,7 @@
11921205
" }\n",
11931206
" STOP(section0,timers)\n",
11941207
"\n",
1195-
" float r1 = 1.0F/h_y;\n",
1208+
" const float r1 = 1.0F/h_y;\n",
11961209
"\n",
11971210
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
11981211
" {\n",
@@ -1279,7 +1292,7 @@
12791292
},
12801293
{
12811294
"cell_type": "code",
1282-
"execution_count": 22,
1295+
"execution_count": 21,
12831296
"metadata": {},
12841297
"outputs": [
12851298
{
@@ -1304,7 +1317,7 @@
13041317
"}\n",
13051318
"STOP(section0,timers)\n",
13061319
"\n",
1307-
"float r1 = 1.0F/h_y;\n",
1320+
"const float r1 = 1.0F/h_y;\n",
13081321
"\n",
13091322
"for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
13101323
"{\n",
@@ -1369,7 +1382,7 @@
13691382
},
13701383
{
13711384
"cell_type": "code",
1372-
"execution_count": 23,
1385+
"execution_count": 22,
13731386
"metadata": {},
13741387
"outputs": [
13751388
{
@@ -1404,7 +1417,7 @@
14041417
},
14051418
{
14061419
"cell_type": "code",
1407-
"execution_count": 24,
1420+
"execution_count": 23,
14081421
"metadata": {},
14091422
"outputs": [
14101423
{
@@ -1483,7 +1496,7 @@
14831496
" }\n",
14841497
" STOP(section0,timers)\n",
14851498
"\n",
1486-
" float r1 = 1.0F/h_y;\n",
1499+
" const float r1 = 1.0F/h_y;\n",
14871500
"\n",
14881501
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
14891502
" {\n",
@@ -1546,7 +1559,7 @@
15461559
},
15471560
{
15481561
"cell_type": "code",
1549-
"execution_count": 25,
1562+
"execution_count": 24,
15501563
"metadata": {},
15511564
"outputs": [
15521565
{
@@ -1624,8 +1637,8 @@
16241637
" }\n",
16251638
" STOP(section0,timers)\n",
16261639
"\n",
1627-
" float r1 = 1.0F/h_x;\n",
1628-
" float r2 = 1.0F/h_y;\n",
1640+
" const float r1 = 1.0F/h_x;\n",
1641+
" const float r2 = 1.0F/h_y;\n",
16291642
"\n",
16301643
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
16311644
" {\n",
@@ -1718,7 +1731,7 @@
17181731
"name": "python",
17191732
"nbconvert_exporter": "python",
17201733
"pygments_lexer": "ipython3",
1721-
"version": "3.10.12"
1734+
"version": "3.13.11"
17221735
}
17231736
},
17241737
"nbformat": 4,

examples/performance/01_gpu.ipynb

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,13 @@
142142
"name": "stderr",
143143
"output_type": "stream",
144144
"text": [
145-
"NUMA domain count autodetection failed, assuming 1\n",
146-
"Operator `Kernel` ran in 0.01 s\n",
145+
"NUMA domain count autodetection failed, assuming 1\n"
146+
]
147+
},
148+
{
149+
"name": "stderr",
150+
"output_type": "stream",
151+
"text": [
147152
"Operator `Kernel` ran in 0.01 s\n"
148153
]
149154
}
@@ -292,9 +297,9 @@
292297
" const int x_stride0 = x_fsz0*y_fsz0;\n",
293298
" const int y_stride0 = y_fsz0;\n",
294299
"\n",
295-
" float r0 = 1.0F/dt;\n",
296-
" float r1 = 1.0F/(h_x*h_x);\n",
297-
" float r2 = 1.0F/(h_y*h_y);\n",
300+
" const float r0 = 1.0F/dt;\n",
301+
" const float r1 = 1.0F/(h_x*h_x);\n",
302+
" const float r2 = 1.0F/(h_y*h_y);\n",
298303
"\n",
299304
" for (int time = time_m; time <= time_M; time += 1)\n",
300305
" {\n",
@@ -340,7 +345,7 @@
340345
"name": "python",
341346
"nbconvert_exporter": "python",
342347
"pygments_lexer": "ipython3",
343-
"version": "3.13.5"
348+
"version": "3.13.11"
344349
}
345350
},
346351
"nbformat": 4,

0 commit comments

Comments
 (0)