2222
2323# NOTE: the following is inspired by pyopencl.cltypes
2424
25- mapper = {
26- "half" : np .float16 ,
25+ dtype_mapper = {
2726 "int" : np .int32 ,
2827 "float" : np .float32 ,
2928 "double" : np .float64
3029}
3130
3231
33- def build_dtypes_vector (field_names , counts ):
32+ def build_dtypes_vector (field_names , counts , mapper = None ):
3433 ret = {}
34+ mapper = mapper or dtype_mapper
3535 for base_name , base_dtype in mapper .items ():
3636 for count in counts :
3737 name = "%s%d" % (base_name , count )
@@ -95,7 +95,7 @@ def get_base_dtype(self, v, default=None):
9595# Standard vector dtypes
9696dtypes_vector_mapper .update (build_dtypes_vector (field_names , counts ))
9797# Fallbacks
98- dtypes_vector_mapper .update ({(v , 1 ): v for v in mapper .values ()})
98+ dtypes_vector_mapper .update ({(v , 1 ): v for v in dtype_mapper .values ()})
9999
100100
101101# *** Custom types escaping both the numpy and ctypes namespaces
@@ -181,21 +181,25 @@ def infer_datasize(dtype, shape):
181181 return np .ctypeslib .as_ctypes_type (dtype ), datasize
182182
183183
184+ mpi_mapper = {
185+ np .ubyte : 'MPI_BYTE' ,
186+ np .ushort : 'MPI_UNSIGNED_SHORT' ,
187+ np .int32 : 'MPI_INT' ,
188+ np .float32 : 'MPI_FLOAT' ,
189+ np .int64 : 'MPI_LONG' ,
190+ np .float64 : 'MPI_DOUBLE' ,
191+ np .complex64 : 'MPI_C_COMPLEX' ,
192+ np .complex128 : 'MPI_C_DOUBLE_COMPLEX'
193+ }
194+
195+
184196def dtype_to_mpitype (dtype ):
185197 """Map numpy types to MPI datatypes."""
186198
187199 # Resolve vector dtype if necessary
188200 dtype = dtypes_vector_mapper .get_base_dtype (dtype )
189201
190- return {
191- np .ubyte : 'MPI_BYTE' ,
192- np .ushort : 'MPI_UNSIGNED_SHORT' ,
193- np .int32 : 'MPI_INT' ,
194- np .float32 : 'MPI_FLOAT' ,
195- np .int64 : 'MPI_LONG' ,
196- np .float64 : 'MPI_DOUBLE' ,
197- np .float16 : 'MPI_UNSIGNED_SHORT'
198- }[dtype ]
202+ return mpi_mapper [dtype ]
199203
200204
201205def dtype_to_mpidtype (dtype ):
@@ -226,9 +230,7 @@ class c_restrict_void_p(ctypes.c_void_p):
226230
227231
228232ctypes_vector_mapper = {}
229- for base_name , base_dtype in mapper .items ():
230- if base_dtype is np .float16 :
231- continue
233+ for base_name , base_dtype in dtype_mapper .items ():
232234 base_ctype = dtype_to_ctype (base_dtype )
233235
234236 for count in counts :
@@ -304,11 +306,6 @@ def ctypes_to_cstr(ctype, toarray=None):
304306 return retval
305307
306308
307- known_ctypes = {
308- 'vector_types.h' : list (ctypes_vector_mapper .values ()),
309- }
310-
311-
312309def is_external_ctype (ctype , includes ):
313310 """
314311 True if `ctype` is known to be declared in one of the given `includes`
@@ -321,9 +318,8 @@ def is_external_ctype(ctype, includes):
321318 if issubclass (ctype , ctypes ._SimpleCData ):
322319 return False
323320
324- for k , v in known_ctypes .items ():
325- if ctype in v :
326- return True
321+ if ctype in ctypes_vector_mapper .values ():
322+ return True
327323
328324 return False
329325
0 commit comments