diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index 0272304ee3..9dee3dd793 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -1038,7 +1038,7 @@ def supports(self, query, language=None): warning(f"Couldn't establish if `query={query}` is supported on this " "system. Assuming it is not.") return False - elif query == 'async-loads' and cc >= 80: + elif query == 'async-pipe' and cc >= 80: # Asynchronous pipeline loads -- introduced in Ampere return True elif query in ('tma', 'thread-block-cluster') and cc >= 90: @@ -1055,7 +1055,7 @@ class Volta(NvidiaDevice): class Ampere(Volta): def supports(self, query, language=None): - if query == 'async-loads': + if query == 'async-pipe': return True else: return super().supports(query, language) diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py index 32aa8d6c21..be97df0ca4 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -272,6 +272,16 @@ def _print_Abs(self, expr): return f"fabs({self._print(arg)})" return self._print_fmath_func('abs', expr) + def _print_BitwiseNot(self, expr): + # Unary function, single argument + arg = expr.args[0] + return f'~{self._print(arg)}' + + def _print_BitwiseXor(self, expr): + # Binary function + arg0, arg1 = expr.args + return f'{self._print(arg0)} ^ {self._print(arg1)}' + def _print_Add(self, expr, order=None): """" Print an addition. diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index f83dc39c94..85162b8a5f 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -83,9 +83,11 @@ def __repr__(self): if not self.is_Reduction: return super().__repr__() elif self.operation is OpInc: - return '%s += %s' % (self.lhs, self.rhs) + return f'Inc({self.lhs}, {self.rhs})' else: - return '%s = %s(%s)' % (self.lhs, self.operation, self.rhs) + return f'Eq({self.lhs}, {self.operation}({self.rhs}))' + + __str__ = __repr__ # Pickling support __reduce_ex__ = Pickable.__reduce_ex__ diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 93e7809250..343b97da0a 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -305,6 +305,9 @@ def _gen_value(self, obj, mode=1, masked=()): qualifiers = [v for k, v in self._qualifiers_mapper.items() if getattr(obj.function, k, False) and v not in masked] + if obj.is_LocalObject and mode == 2: + qualifiers.extend(as_tuple(obj._C_tag)) + if (obj._mem_stack or obj._mem_constant) and mode == 1: strtype = self.ccode(obj._C_typedata) strshape = ''.join(f'[{self.ccode(i)}]' for i in obj.symbolic_shape) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 5af92a3208..d58e536500 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -3,7 +3,7 @@ from sympy import S import numpy as np -from devito.finite_differences import IndexDerivative +from devito.finite_differences import IndexDerivative, Weights from devito.ir import Backward, Forward, Interval, IterationSpace, Queue from devito.passes.clusters.misc import fuse from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace @@ -94,17 +94,39 @@ def _core(expr, c, ispace, weights, reusables, mapper, **kwargs): @_core.register(Symbol) -@_core.register(Indexed) @_core.register(BasicWrapperMixin) def _(expr, c, ispace, weights, reusables, mapper, **kwargs): return expr, [] +@_core.register(Indexed) +def _(expr, c, ispace, weights, reusables, mapper, **kwargs): + if not isinstance(expr.function, Weights): + return expr, [] + + # Lower or reuse a previously lowered Weights array + sregistry = kwargs['sregistry'] + subs_user = kwargs['subs'] + + w0 = expr.function + k = tuple(w0.weights) + try: + w = weights[k] + except KeyError: + name = sregistry.make_name(prefix='w') + dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32 + initvalue = tuple(i.subs(subs_user) for i in k) + w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue) + + rebuilt = expr._subs(w0.indexed, w.indexed) + + return rebuilt, [] + + @_core.register(IndexDerivative) def _(expr, c, ispace, weights, reusables, mapper, **kwargs): sregistry = kwargs['sregistry'] options = kwargs['options'] - subs_user = kwargs['subs'] try: cbk0 = deriv_schedule_registry[options['deriv-schedule']] @@ -117,18 +139,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): # Create the concrete Weights array, or reuse an already existing one # if possible - name = sregistry.make_name(prefix='w') - w0 = ideriv.weights.function - dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32 - k = tuple(w0.weights) - try: - w = weights[k] - except KeyError: - initvalue = tuple(i.subs(subs_user) for i in k) - w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue) + w, _ = _core(ideriv.weights, c, ispace, weights, reusables, mapper, **kwargs) # Replace the abstract Weights array with the concrete one - subs = {w0.indexed: w.indexed} + subs = {ideriv.weights.base: w.base} init = uxreplace(init, subs) ideriv = uxreplace(ideriv, subs) @@ -158,10 +172,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): # NOTE: created before recurring so that we ultimately get a sound ordering try: s = reusables.pop() - assert np.can_cast(s.dtype, dtype) + assert np.can_cast(s.dtype, w.dtype) except KeyError: name = sregistry.make_name(prefix='r') - s = Symbol(name=name, dtype=dtype) + s = Symbol(name=name, dtype=w.dtype) # Go inside `expr` and recursively lower any nested IndexDerivatives expr, processed = _core(expr, c, ispace1, weights, reusables, mapper, **kwargs) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index deb3933f68..1f81dc4daa 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -97,17 +97,25 @@ def _alloc_object_on_low_lat_mem(self, site, obj, storage): """ decl = Definition(obj) - if obj._C_init: - definition = (decl, obj._C_init) + init = obj._C_init + if not init: + definition = decl + efuncs = () + elif init.is_Callable: + definition = Call(init.name, init.parameters, + retobj=obj if init.retval else None) + efuncs = (init,) else: - definition = (decl) + definition = (decl, init) + efuncs = () frees = obj._C_free if obj.free_symbols - {obj}: - storage.update(obj, site, objs=definition, frees=frees) + storage.update(obj, site, objs=definition, efuncs=efuncs, frees=frees) else: - storage.update(obj, site, standalones=definition, frees=frees) + storage.update(obj, site, standalones=definition, efuncs=efuncs, + frees=frees) def _alloc_array_on_low_lat_mem(self, site, obj, storage): """ diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index eda71a0b74..23f5c33bf0 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -7,8 +7,8 @@ from devito.tools.dtypes_lowering import dtype_mapper __all__ = ['cast', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'BaseCast', # noqa - 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', - 'LONG'] + 'DOUBLE', 'VOID', 'LONG', 'ULONG', 'NoDeclStruct', 'c_complex', + 'c_double_complex'] limits_mapper = { diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 4a8d2df206..70a904180c 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -7,6 +7,7 @@ import sympy from sympy import Expr, Function, Number, Tuple, cacheit, sympify from sympy.core.decorators import call_highest_priority +from sympy.logic.boolalg import BooleanFunction from devito.finite_differences.elementary import Min, Max from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa @@ -16,13 +17,13 @@ from devito.types import Symbol from devito.types.basic import Basic -__all__ = ['CondEq', 'CondNe', 'IntDiv', 'CallFromPointer', # noqa - 'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite', - 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', - 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', - 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace', - 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit', - 'VectorAccess'] +__all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'IntDiv', # noqa + 'CallFromPointer', 'CallFromComposite', 'FieldFromPointer', + 'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', + 'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'ReservedWord', + 'Keyword', 'String', 'Macro', 'Class', 'MacroArgument', 'Deref', + 'Namespace', 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', + 'ValueLimit', 'VectorAccess'] class CondEq(sympy.Eq): @@ -63,6 +64,17 @@ def negated(self): return CondEq(*self.args, evaluate=False) +class BitwiseNot(BooleanFunction): + pass + + +class BitwiseXor(BooleanFunction): + + # Enforce two args + def __new__(cls, arg0, arg1, **kwargs): + return super().__new__(cls, arg0, arg1, **kwargs) + + class IntDiv(sympy.Expr): """ diff --git a/devito/types/basic.py b/devito/types/basic.py index 8c7e960fb2..6fefb7db26 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1875,8 +1875,17 @@ def _mem_internal_lazy(self): return self._liveness == 'lazy' """ - A modifier added to the subclass C declaration when it appears - in a function signature. For example, a subclass might define `_C_modifier = '&'` + A modifier added to the declaration of the LocalType when it appears in a + function signature. For example, a subclass might define `_C_modifier = '&'` to impose pass-by-reference semantics. """ _C_modifier = None + + """ + One or more optional keywords added to the declaration of the LocalType + in between the type and the variable name when it appears in a function + signature. For example, some languages support these to modify the way + the compiler generates code for passing the parameter and how the + runtime accesses it. + """ + _C_tag = None diff --git a/devito/types/misc.py b/devito/types/misc.py index 47fa6d601a..e5d178fc57 100644 --- a/devito/types/misc.py +++ b/devito/types/misc.py @@ -345,6 +345,30 @@ def closing(self): """ +class FunctionMap(LocalObject): + + """ + Wrap a Function in a LocalObject. + """ + + __rargs__ = ('name', 'tensor') + + def __init__(self, name, tensor, **kwargs): + super().__init__(name, **kwargs) + self.tensor = tensor + + def _hashable_content(self): + return super()._hashable_content() + (self.tensor,) + + @property + def free_symbols(self): + """ + The free symbols of a FunctionMap are the free symbols of the + underlying Function. + """ + return super().free_symbols | {self.tensor} + + # *** C/CXX support types size_t = CustomDtype('size_t') diff --git a/devito/types/object.py b/devito/types/object.py index 032bca303a..1877a754ed 100644 --- a/devito/types/object.py +++ b/devito/types/object.py @@ -176,10 +176,10 @@ class LocalObject(AbstractObject, LocalType): """ __rargs__ = ('name',) - __rkwargs__ = ('cargs', 'initvalue', 'liveness', 'is_global') + __rkwargs__ = ('cargs', 'initvalue', 'liveness', 'scope') def __init__(self, name, cargs=None, initvalue=None, liveness='lazy', - is_global=False, **kwargs): + scope='stack', **kwargs): self.name = name self.cargs = as_tuple(cargs) self.initvalue = initvalue or self.default_initvalue @@ -187,16 +187,17 @@ def __init__(self, name, cargs=None, initvalue=None, liveness='lazy', assert liveness in ['eager', 'lazy'] self._liveness = liveness - self._is_global = is_global + assert scope in ['stack', 'shared', 'global'] + self._scope = scope def _hashable_content(self): return (super()._hashable_content() + self.cargs + - (self.initvalue, self.liveness, self.is_global)) + (self.initvalue, self.liveness, self.scope)) @property - def is_global(self): - return self._is_global + def scope(self): + return self._scope @property def free_symbols(self): @@ -232,6 +233,10 @@ def _C_free(self): """ return None + @property + def _mem_shared(self): + return self._scope == 'shared' + @property def _mem_global(self): - return self._is_global + return self._scope == 'global' diff --git a/tests/test_iet.py b/tests/test_iet.py index e21e8f58bf..48ddd3e723 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -15,9 +15,10 @@ from devito.passes.iet.engine import Graph from devito.passes.iet.languages.C import CDataManager from devito.symbolics import (Byref, FieldFromComposite, InlineIf, Macro, Class, - FLOAT) + FLOAT, ListInitializer, SizeOf) from devito.tools import CustomDtype, as_tuple, dtype_to_ctype from devito.types import Array, LocalObject, Symbol +from devito.types.misc import FunctionMap @pytest.fixture @@ -296,6 +297,52 @@ def _C_free(self): }""" +def test_make_cuda_tensor_map(): + + class CUTensorMap(FunctionMap): + + dtype = CustomDtype('CUtensorMap') + + @property + def _C_init(self): + symsizes = list(reversed(self.tensor.symbolic_shape)) + sizeof_dtype = SizeOf(self.tensor.dmap._C_typedata) + + sizes = ListInitializer(symsizes) + strides = ListInitializer([ + np.prod(symsizes[:i])*sizeof_dtype for i in range(1, len(symsizes)) + ]) + + arguments = [ + Byref(self), + Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'), + 4, self.tensor.dmap, sizes, strides, + ] + call = Call('cuTensorMapEncodeTiled', arguments) + + return call + + grid = Grid(shape=(10, 10, 10)) + + u = TimeFunction(name='u', grid=grid) + + tmap = CUTensorMap('tmap', u) + + iet = Call('foo', tmap) + iet = ElementalFunction('foo', iet, parameters=()) + dm = CDataManager(sregistry=None) + iet = CDataManager.place_definitions.__wrapped__(dm, iet)[0] + + assert str(iet) == """\ +static void foo() +{ + CUtensorMap tmap; + cuTensorMapEncodeTiled(&tmap,CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]}); + + foo(tmap); +}""" # noqa + + def test_cpp_local_object(): """ Test C++ support for LocalObjects. @@ -308,7 +355,7 @@ class MyObject(LocalObject): lo0 = MyObject('obj0') # Globally-scoped objects must not be declared in the function body - lo1 = MyObject('obj1', is_global=True) + lo1 = MyObject('obj1', scope='global') # A LocalObject using both a template and a modifier class SpecialObject(LocalObject):