2121 ctypes_to_cstr )
2222from devito .types .basic import (AbstractFunction , AbstractSymbol , Basic , Indexed ,
2323 Symbol )
24- from devito .types .object import AbstractObject , LocalObject
24+ from devito .types .object import AbstractObject , LocalObject , LocalCompositeObject
2525
2626__all__ = ['Node' , 'MultiTraversable' , 'Block' , 'Expression' , 'Callable' ,
2727 'Call' , 'ExprStmt' , 'Conditional' , 'Iteration' , 'List' , 'Section' ,
3030 'Increment' , 'Return' , 'While' , 'ListMajor' , 'ParallelIteration' ,
3131 'ParallelBlock' , 'Dereference' , 'Lambda' , 'SyncSpot' , 'Pragma' ,
3232 'DummyExpr' , 'BlankLine' , 'ParallelTree' , 'BusyWait' , 'UsingNamespace' ,
33- 'Using' , 'CallableBody' , 'Transfer' ]
33+ 'Using' , 'CallableBody' , 'Transfer' , 'Callback' , 'FixedArgsCallable' ]
3434
3535# First-class IET nodes
3636
@@ -759,6 +759,15 @@ def defines(self):
759759 return self .all_parameters
760760
761761
762+ class FixedArgsCallable (Callable ):
763+
764+ """
765+ A Callable class that enforces a fixed function signature.
766+ """
767+
768+ pass
769+
770+
762771class CallableBody (MultiTraversable ):
763772
764773 """
@@ -1037,8 +1046,8 @@ class Dereference(ExprStmt, Node):
10371046 The following cases are supported:
10381047
10391048 * `pointer` is a PointerArray or TempFunction, and `pointee` is an Array.
1040- * `pointer` is an ArrayObject representing a pointer to a C struct, and
1041- `pointee` is a field in `pointer`.
1049+ * `pointer` is an ArrayObject or CCompositeObject representing a pointer
1050+ to a C struct, and `pointee` is a field in `pointer`.
10421051 * `pointer` is a Symbol with its _C_ctype deriving from ct._Pointer, and
10431052 `pointee` is a Symbol representing the dereferenced value.
10441053 """
@@ -1061,7 +1070,8 @@ def functions(self):
10611070 def expr_symbols (self ):
10621071 ret = []
10631072 if self .pointer .is_Symbol :
1064- assert issubclass (self .pointer ._C_ctype , ctypes ._Pointer ), \
1073+ assert (isinstance (self .pointer , LocalCompositeObject ) or
1074+ issubclass (self .pointer ._C_ctype , ctypes ._Pointer )), \
10651075 "Scalar dereference must have a pointer ctype"
10661076 ret .extend ([self .pointer ._C_symbol , self .pointee ._C_symbol ])
10671077 elif self .pointer .is_PointerArray or self .pointer .is_TempFunction :
@@ -1136,6 +1146,45 @@ def defines(self):
11361146 return tuple (self .parameters )
11371147
11381148
1149+ class Callback (Call ):
1150+ """
1151+ Base class for special callback types.
1152+
1153+ Parameters
1154+ ----------
1155+ name : str
1156+ The name of the callback.
1157+ retval : str
1158+ The return type of the callback.
1159+ param_types : str or list of str
1160+ The return type for each argument of the callback.
1161+
1162+ Notes
1163+ -----
1164+ - The reason Callback is an IET type rather than a SymPy type is
1165+ due to the fact that, when represented at the SymPy level, the IET
1166+ engine fails to bind the callback to a specific Call. Consequently,
1167+ errors occur during the creation of the call graph.
1168+ """
1169+ # TODO: Create a common base class for Call and Callback to avoid
1170+ # having arguments=None here
1171+ def __init__ (self , name , retval = None , param_types = None , arguments = None ):
1172+ super ().__init__ (name = name )
1173+ self .retval = retval
1174+ self .param_types = as_tuple (param_types )
1175+
1176+ @property
1177+ def callback_form (self ):
1178+ """
1179+ A string representation of the callback form.
1180+
1181+ Notes
1182+ -----
1183+ To be overridden by subclasses.
1184+ """
1185+ return
1186+
1187+
11391188class Section (List ):
11401189
11411190 """
0 commit comments