Skip to content
4 changes: 2 additions & 2 deletions devito/arch/archinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ def supports(self, query, language=None):
warning(f"Couldn't establish if `query={query}` is supported on this "
"system. Assuming it is not.")
return False
elif query == 'async-loads' and cc >= 80:
elif query == 'async-pipe' and cc >= 80:
# Asynchronous pipeline loads -- introduced in Ampere
return True
elif query in ('tma', 'thread-block-cluster') and cc >= 90:
Expand All @@ -1055,7 +1055,7 @@ class Volta(NvidiaDevice):
class Ampere(Volta):

def supports(self, query, language=None):
if query == 'async-loads':
if query == 'async-pipe':
return True
else:
return super().supports(query, language)
Expand Down
10 changes: 10 additions & 0 deletions devito/ir/cgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,16 @@ def _print_Abs(self, expr):
return f"fabs({self._print(arg)})"
return self._print_fmath_func('abs', expr)

def _print_BitwiseNot(self, expr):
    """
    Print a bitwise negation as `~<operand>`.

    Only the first entry of `expr.args` is used. The operand is emitted
    exactly as returned by `self._print`, with no parentheses added here.
    """
    operand = self._print(expr.args[0])
    return '~' + operand

def _print_BitwiseXor(self, expr):
    """
    Print a bitwise exclusive-or as `<lhs> ^ <rhs>`.

    `expr.args` must unpack into exactly two operands. Each operand is
    emitted as returned by `self._print`, with no parentheses added here.
    """
    lhs, rhs = expr.args
    return ' ^ '.join((self._print(lhs), self._print(rhs)))

def _print_Add(self, expr, order=None):
"""
Print an addition.
Expand Down
6 changes: 4 additions & 2 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,11 @@ def __repr__(self):
if not self.is_Reduction:
return super().__repr__()
elif self.operation is OpInc:
return '%s += %s' % (self.lhs, self.rhs)
return f'Inc({self.lhs}, {self.rhs})'
else:
return '%s = %s(%s)' % (self.lhs, self.operation, self.rhs)
return f'Eq({self.lhs}, {self.operation}({self.rhs}))'

__str__ = __repr__

# Pickling support
__reduce_ex__ = Pickable.__reduce_ex__
Expand Down
3 changes: 3 additions & 0 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ def _gen_value(self, obj, mode=1, masked=()):
qualifiers = [v for k, v in self._qualifiers_mapper.items()
if getattr(obj.function, k, False) and v not in masked]

if obj.is_LocalObject and mode == 2:
qualifiers.extend(as_tuple(obj._C_tag))

if (obj._mem_stack or obj._mem_constant) and mode == 1:
strtype = self.ccode(obj._C_typedata)
strshape = ''.join(f'[{self.ccode(i)}]' for i in obj.symbolic_shape)
Expand Down
44 changes: 29 additions & 15 deletions devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sympy import S
import numpy as np

from devito.finite_differences import IndexDerivative
from devito.finite_differences import IndexDerivative, Weights
from devito.ir import Backward, Forward, Interval, IterationSpace, Queue
from devito.passes.clusters.misc import fuse
from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace
Expand Down Expand Up @@ -94,17 +94,39 @@ def _core(expr, c, ispace, weights, reusables, mapper, **kwargs):


@_core.register(Symbol)
@_core.register(BasicWrapperMixin)
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
    """
    Pass-through handler: return `expr` unchanged, paired with an empty list.

    `Indexed` is deliberately not registered here, as it has its own
    dedicated handler which would replace this registration anyway.
    """
    return expr, []


@_core.register(Indexed)
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
    """
    Handler for Indexed objects.

    An Indexed whose underlying function is not a Weights is returned
    unchanged. Otherwise, the abstract Weights array is replaced by a
    concrete one, which is either fetched from `weights` or created and
    stored there. `weights` is keyed by the tuple of the abstract weights.

    Reads `kwargs['sregistry']` and `kwargs['subs']`.

    Returns the (possibly rebuilt) expression paired with an empty list.
    """
    if not isinstance(expr.function, Weights):
        return expr, []

    # Lower or reuse a previously lowered Weights array
    sregistry = kwargs['sregistry']
    subs_user = kwargs['subs']

    w0 = expr.function
    k = tuple(w0.weights)
    try:
        w = weights[k]
    except KeyError:
        # A new name is only requested when a new array is actually created
        name = sregistry.make_name(prefix='w')
        # NOTE(review): the "at least np.float32" floor is said to follow
        # from `w0.dtype` -- confirm
        dtype = infer_dtype([w0.dtype, c.dtype])  # At least np.float32
        # Apply the user-provided substitutions to each weight
        initvalue = tuple(i.subs(subs_user) for i in k)
        w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue)

    # Replace the abstract Weights array with the concrete one
    rebuilt = expr._subs(w0.indexed, w.indexed)

    return rebuilt, []


@_core.register(IndexDerivative)
def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
sregistry = kwargs['sregistry']
options = kwargs['options']
subs_user = kwargs['subs']

try:
cbk0 = deriv_schedule_registry[options['deriv-schedule']]
Expand All @@ -117,18 +139,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):

# Create the concrete Weights array, or reuse an already existing one
# if possible
name = sregistry.make_name(prefix='w')
w0 = ideriv.weights.function
dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32
k = tuple(w0.weights)
try:
w = weights[k]
except KeyError:
initvalue = tuple(i.subs(subs_user) for i in k)
w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue)
w, _ = _core(ideriv.weights, c, ispace, weights, reusables, mapper, **kwargs)

# Replace the abstract Weights array with the concrete one
subs = {w0.indexed: w.indexed}
subs = {ideriv.weights.base: w.base}
init = uxreplace(init, subs)
ideriv = uxreplace(ideriv, subs)

Expand Down Expand Up @@ -158,10 +172,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs):
# NOTE: created before recursing so that we ultimately get a sound ordering
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"recurring" -> "recursing"?

try:
s = reusables.pop()
assert np.can_cast(s.dtype, dtype)
assert np.can_cast(s.dtype, w.dtype)
except KeyError:
name = sregistry.make_name(prefix='r')
s = Symbol(name=name, dtype=dtype)
s = Symbol(name=name, dtype=w.dtype)

# Go inside `expr` and recursively lower any nested IndexDerivatives
expr, processed = _core(expr, c, ispace1, weights, reusables, mapper, **kwargs)
Expand Down
18 changes: 13 additions & 5 deletions devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,25 @@ def _alloc_object_on_low_lat_mem(self, site, obj, storage):
"""
decl = Definition(obj)

if obj._C_init:
definition = (decl, obj._C_init)
init = obj._C_init
if not init:
definition = decl
efuncs = ()
elif init.is_Callable:
definition = Call(init.name, init.parameters,
retobj=obj if init.retval else None)
efuncs = (init,)
else:
definition = (decl)
definition = (decl, init)
efuncs = ()

frees = obj._C_free

if obj.free_symbols - {obj}:
storage.update(obj, site, objs=definition, frees=frees)
storage.update(obj, site, objs=definition, efuncs=efuncs, frees=frees)
else:
storage.update(obj, site, standalones=definition, frees=frees)
storage.update(obj, site, standalones=definition, efuncs=efuncs,
frees=frees)

def _alloc_array_on_low_lat_mem(self, site, obj, storage):
"""
Expand Down
4 changes: 2 additions & 2 deletions devito/symbolics/extended_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from devito.tools.dtypes_lowering import dtype_mapper

__all__ = ['cast', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'BaseCast', # noqa
'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex',
'LONG']
'DOUBLE', 'VOID', 'LONG', 'ULONG', 'NoDeclStruct', 'c_complex',
'c_double_complex']


limits_mapper = {
Expand Down
26 changes: 19 additions & 7 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sympy
from sympy import Expr, Function, Number, Tuple, cacheit, sympify
from sympy.core.decorators import call_highest_priority
from sympy.logic.boolalg import BooleanFunction

from devito.finite_differences.elementary import Min, Max
from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa
Expand All @@ -16,13 +17,13 @@
from devito.types import Symbol
from devito.types.basic import Basic

__all__ = ['CondEq', 'CondNe', 'IntDiv', 'CallFromPointer', # noqa
'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite',
'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction',
'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String',
'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace',
'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit',
'VectorAccess']
__all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'IntDiv', # noqa
'CallFromPointer', 'CallFromComposite', 'FieldFromPointer',
'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer',
'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'ReservedWord',
'Keyword', 'String', 'Macro', 'Class', 'MacroArgument', 'Deref',
'Namespace', 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin',
'ValueLimit', 'VectorAccess']


class CondEq(sympy.Eq):
Expand Down Expand Up @@ -63,6 +64,17 @@ def negated(self):
return CondEq(*self.args, evaluate=False)


class BitwiseNot(BooleanFunction):

    """
    Symbolic bitwise negation.

    A plain BooleanFunction subclass adding no behaviour of its own.
    """


class BitwiseXor(BooleanFunction):

    """
    Symbolic bitwise exclusive-or of two operands.
    """

    def __new__(cls, arg0, arg1, **kwargs):
        # The explicit two-parameter signature enforces exactly two
        # positional operands
        operands = (arg0, arg1)
        return super().__new__(cls, *operands, **kwargs)


class IntDiv(sympy.Expr):

"""
Expand Down
13 changes: 11 additions & 2 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1875,8 +1875,17 @@ def _mem_internal_lazy(self):
return self._liveness == 'lazy'

"""
A modifier added to the subclass C declaration when it appears
in a function signature. For example, a subclass might define `_C_modifier = '&'`
A modifier added to the declaration of the LocalType when it appears in a
function signature. For example, a subclass might define `_C_modifier = '&'`
to impose pass-by-reference semantics.
"""
_C_modifier = None

"""
One or more optional keywords added to the declaration of the LocalType
in between the type and the variable name when it appears in a function
signature. For example, some languages support these to modify the way
the compiler generates code for passing the parameter and how the
runtime accesses it.
"""
_C_tag = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is something pro specific?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

used by PRO yes

24 changes: 24 additions & 0 deletions devito/types/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,30 @@ def closing(self):
"""


class FunctionMap(LocalObject):

    """
    A LocalObject wrapping a Function, stored as `tensor`.
    """

    __rargs__ = ('name', 'tensor')

    def __init__(self, name, tensor, **kwargs):
        # All remaining keyword arguments are handled by the parent class
        super().__init__(name, **kwargs)
        self.tensor = tensor

    def _hashable_content(self):
        # The wrapped object contributes to the hashable content
        inherited = super()._hashable_content()
        return inherited + (self.tensor,)

    @property
    def free_symbols(self):
        """
        The free symbols inherited from the parent class, plus the
        wrapped `tensor` itself.
        """
        inherited = super().free_symbols
        return inherited | {self.tensor}


# *** C/CXX support types

size_t = CustomDtype('size_t')
Expand Down
19 changes: 12 additions & 7 deletions devito/types/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,27 +176,28 @@ class LocalObject(AbstractObject, LocalType):
"""

__rargs__ = ('name',)
__rkwargs__ = ('cargs', 'initvalue', 'liveness', 'is_global')
__rkwargs__ = ('cargs', 'initvalue', 'liveness', 'scope')

def __init__(self, name, cargs=None, initvalue=None, liveness='lazy',
is_global=False, **kwargs):
scope='stack', **kwargs):
self.name = name
self.cargs = as_tuple(cargs)
self.initvalue = initvalue or self.default_initvalue

assert liveness in ['eager', 'lazy']
self._liveness = liveness

self._is_global = is_global
assert scope in ['stack', 'shared', 'global']
self._scope = scope

def _hashable_content(self):
return (super()._hashable_content() +
self.cargs +
(self.initvalue, self.liveness, self.is_global))
(self.initvalue, self.liveness, self.scope))

@property
def is_global(self):
return self._is_global
def scope(self):
return self._scope

@property
def free_symbols(self):
Expand Down Expand Up @@ -232,6 +233,10 @@ def _C_free(self):
"""
return None

@property
def _mem_shared(self):
    """True if the scope of this object is 'shared', False otherwise."""
    in_shared_scope = self._scope == 'shared'
    return in_shared_scope

@property
def _mem_global(self):
return self._is_global
return self._scope == 'global'
Loading
Loading