|
4 | 4 | import numpy as np |
5 | 5 | from sympy import Expr, cacheit |
6 | 6 |
|
| 7 | +from devito.data import FULL |
7 | 8 | from devito.tools import (Pickable, as_tuple, c_restrict_void_p, |
8 | 9 | dtype_to_ctype, dtypes_vector_mapper, is_integer) |
9 | 10 | from devito.types.basic import AbstractFunction, LocalType |
@@ -60,6 +61,13 @@ def shape_allocated(self): |
60 | 61 | def is_const(self): |
61 | 62 | return self._is_const |
62 | 63 |
|
| 64 | + @property |
| 65 | + def c0(self): |
| 66 | + # ArrayBasic can be used as a base class for tensorial objects (that is, |
| 67 | + # arrays whose components are AbstractFunctions). This property enables |
| 68 | + # treating the two cases uniformly in some lowering passes |
| 69 | + return self |
| 70 | + |
63 | 71 |
|
64 | 72 | class Array(ArrayBasic): |
65 | 73 |
|
@@ -425,7 +433,6 @@ def __halo_setup__(self, components=(), **kwargs): |
425 | 433 |
|
426 | 434 | @property |
427 | 435 | def c0(self): |
428 | | - # Shortcut for self.components[0] |
429 | 436 | return self.components[0] |
430 | 437 |
|
431 | 438 | # Class attributes overrides |
@@ -464,18 +471,26 @@ def ncomp(self): |
464 | 471 | def initvalue(self): |
465 | 472 | return None |
466 | 473 |
|
467 | | - # Overrides defaulting to self.c0's behaviour |
468 | | - |
| 474 | + # Defaulting to self.c0's behaviour |
469 | 475 | for i in ('_mem_internal_eager', '_mem_internal_lazy', '_mem_local', |
470 | 476 | '_mem_mapped', '_mem_host', '_mem_stack', '_mem_constant', |
471 | 477 | '_mem_shared', '_mem_shared_remote', '__padding_dtype__', |
472 | 478 | '_size_domain', '_size_halo', '_size_owned', '_size_padding', |
473 | | - '_size_nopad', '_size_nodomain', '_offset_domain', |
474 | | - '_offset_halo', '_offset_owned', '_dist_dimensions', |
475 | | - '_C_get_field', 'grid', 'symbolic_shape', |
| 479 | + '_size_nopad', '_size_nodomain', '_offset_domain', '_offset_halo', |
| 480 | + '_offset_owned', '_dist_dimensions', '_C_get_field', 'grid', |
476 | 481 | *AbstractFunction.__properties__): |
477 | 482 | locals()[i] = property(lambda self, v=i: getattr(self.c0, v)) |
478 | 483 |
|
| 484 | + # Other overrides |
| 485 | + |
| 486 | + @cached_property |
| 487 | + def symbolic_shape(self): |
| 488 | + from devito.symbolics import FieldFromPointer, IndexedPointer # noqa |
| 489 | + ffp = FieldFromPointer(self._C_field_shape, self._C_symbol) |
| 490 | + ret = [s if is_integer(s) else IndexedPointer(ffp, i) |
| 491 | + for i, s in enumerate(self.shape)] |
| 492 | + return tuple(ret) |
| 493 | + |
479 | 494 | @property |
480 | 495 | def _mem_heap(self): |
481 | 496 | return not any([self._mem_stack, self._mem_shared, self._mem_shared_remote]) |
|
0 commit comments