Skip to content

Commit 83737b0

Browse files
committed
compiler: make dtype lowering more flexible
1 parent e9fa6ec commit 83737b0

3 files changed

Lines changed: 23 additions & 25 deletions

File tree

.github/workflows/docker-bases.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,4 +255,6 @@ jobs:
255255
file: './docker/Dockerfile.amd'
256256
push: true
257257
target: 'hip'
258+
build-args: |
259+
ROCM_VERSION=6.3.4
258260
tags: devitocodes/bases:amd-hip

devito/symbolics/extended_dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from devito.symbolics.extended_sympy import ReservedWord, Cast, ValueLimit
55
from devito.tools import (Bunch, float2, float3, float4, double2, double3, double4, # noqa
66
int2, int3, int4, ctypes_vector_mapper)
7-
from devito.tools.dtypes_lowering import mapper as dtype_mapper
7+
from devito.tools.dtypes_lowering import dtype_mapper
88

99
__all__ = ['cast_mapper', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'BaseCast', # noqa
1010
'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex']

devito/tools/dtypes_lowering.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@
2222

2323
# NOTE: the following is inspired by pyopencl.cltypes
2424

25-
mapper = {
26-
"half": np.float16,
25+
dtype_mapper = {
2726
"int": np.int32,
2827
"float": np.float32,
2928
"double": np.float64
3029
}
3130

3231

33-
def build_dtypes_vector(field_names, counts):
32+
def build_dtypes_vector(field_names, counts, mapper=None):
3433
ret = {}
34+
mapper = mapper or dtype_mapper
3535
for base_name, base_dtype in mapper.items():
3636
for count in counts:
3737
name = "%s%d" % (base_name, count)
@@ -95,7 +95,7 @@ def get_base_dtype(self, v, default=None):
9595
# Standard vector dtypes
9696
dtypes_vector_mapper.update(build_dtypes_vector(field_names, counts))
9797
# Fallbacks
98-
dtypes_vector_mapper.update({(v, 1): v for v in mapper.values()})
98+
dtypes_vector_mapper.update({(v, 1): v for v in dtype_mapper.values()})
9999

100100

101101
# *** Custom types escaping both the numpy and ctypes namespaces
@@ -181,21 +181,25 @@ def infer_datasize(dtype, shape):
181181
return np.ctypeslib.as_ctypes_type(dtype), datasize
182182

183183

184+
mpi_mapper = {
185+
np.ubyte: 'MPI_BYTE',
186+
np.ushort: 'MPI_UNSIGNED_SHORT',
187+
np.int32: 'MPI_INT',
188+
np.float32: 'MPI_FLOAT',
189+
np.int64: 'MPI_LONG',
190+
np.float64: 'MPI_DOUBLE',
191+
np.complex64: 'MPI_C_COMPLEX',
192+
np.complex128: 'MPI_C_DOUBLE_COMPLEX'
193+
}
194+
195+
184196
def dtype_to_mpitype(dtype):
185197
"""Map numpy types to MPI datatypes."""
186198

187199
# Resolve vector dtype if necessary
188200
dtype = dtypes_vector_mapper.get_base_dtype(dtype)
189201

190-
return {
191-
np.ubyte: 'MPI_BYTE',
192-
np.ushort: 'MPI_UNSIGNED_SHORT',
193-
np.int32: 'MPI_INT',
194-
np.float32: 'MPI_FLOAT',
195-
np.int64: 'MPI_LONG',
196-
np.float64: 'MPI_DOUBLE',
197-
np.float16: 'MPI_UNSIGNED_SHORT'
198-
}[dtype]
202+
return mpi_mapper[dtype]
199203

200204

201205
def dtype_to_mpidtype(dtype):
@@ -226,9 +230,7 @@ class c_restrict_void_p(ctypes.c_void_p):
226230

227231

228232
ctypes_vector_mapper = {}
229-
for base_name, base_dtype in mapper.items():
230-
if base_dtype is np.float16:
231-
continue
233+
for base_name, base_dtype in dtype_mapper.items():
232234
base_ctype = dtype_to_ctype(base_dtype)
233235

234236
for count in counts:
@@ -304,11 +306,6 @@ def ctypes_to_cstr(ctype, toarray=None):
304306
return retval
305307

306308

307-
known_ctypes = {
308-
'vector_types.h': list(ctypes_vector_mapper.values()),
309-
}
310-
311-
312309
def is_external_ctype(ctype, includes):
313310
"""
314311
True if `ctype` is known to be declared in one of the given `includes`
@@ -321,9 +318,8 @@ def is_external_ctype(ctype, includes):
321318
if issubclass(ctype, ctypes._SimpleCData):
322319
return False
323320

324-
for k, v in known_ctypes.items():
325-
if ctype in v:
326-
return True
321+
if ctype in ctypes_vector_mapper.values():
322+
return True
327323

328324
return False
329325

0 commit comments

Comments
 (0)