Skip to content

Commit 58f1c70

Browse files
committed
compiler: Fix ComponentAccess pickling
1 parent 6b904d9 commit 58f1c70

3 files changed

Lines changed: 41 additions & 4 deletions

File tree

devito/types/array.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
from sympy import Expr, cacheit
66

7-
from devito.tools import (Reconstructable, as_tuple, c_restrict_void_p,
7+
from devito.tools import (Pickable, as_tuple, c_restrict_void_p,
88
dtype_to_ctype, dtypes_vector_mapper, is_integer)
99
from devito.types.basic import AbstractFunction, LocalType
1010
from devito.types.utils import CtypesFactory, DimensionTuple
@@ -518,10 +518,11 @@ def handles(self):
518518
return self.components
519519

520520

521-
class ComponentAccess(Expr, Reconstructable):
521+
class ComponentAccess(Expr, Pickable):
522522

523523
_component_names = ('x', 'y', 'z', 'w')
524524

525+
__rargs__ = ('arg',)
525526
__rkwargs__ = ('index',)
526527

527528
def __new__(cls, arg, index=0, **kwargs):
@@ -543,7 +544,7 @@ def __str__(self):
543544

544545
__repr__ = __str__
545546

546-
func = Reconstructable._rebuild
547+
func = Pickable._rebuild
547548

548549
def _sympystr(self, printer):
549550
return str(self)
@@ -552,6 +553,10 @@ def _sympystr(self, printer):
552553
def base(self):
553554
return self.args[0]
554555

556+
@property
557+
def arg(self):
558+
return self.base
559+
555560
@property
556561
def index(self):
557562
return self._index

tests/test_pickle.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar,
2121
PointerArray, Lock, PThreadArray, SharedData, Timer,
2222
DeviceID, NPThreads, ThreadID, TempFunction, Indirection,
23-
FIndexed, StencilDimension)
23+
FIndexed, ComponentAccess, StencilDimension)
2424
from devito.types.basic import BoundSymbol, AbstractSymbol
2525
from devito.tools import EnrichedTuple
2626
from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer,
@@ -417,6 +417,20 @@ def test_findexed(self, pickle):
417417
assert new_fi.indices == (x+1, y, z-2)
418418
assert new_fi.strides_map == fi.strides_map
419419

420+
def test_component_access(self, pickle):
421+
grid = Grid(shape=(3, 3, 3))
422+
x, y, z = grid.dimensions
423+
424+
f = Function(name='f', grid=grid)
425+
426+
ca = ComponentAccess(f.indexify(), 1)
427+
428+
pkl_ca = pickle.dumps(ca)
429+
new_ca = pickle.loads(pkl_ca)
430+
431+
assert new_ca.index == 1
432+
assert new_ca.function.name == f.name
433+
420434
def test_weights_to_array(self, pickle):
421435
grid = Grid(shape=(3, 3, 3))
422436
x, y, z = grid.dimensions

tests/test_symbolics.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,24 @@ def test_findexed():
483483
assert new_fi.strides_map == strides_map
484484

485485

486+
def test_component_access():
487+
grid = Grid(shape=(3, 3, 3))
488+
x, y, z = grid.dimensions
489+
490+
f = Function(name='f', grid=grid)
491+
492+
cf0 = ComponentAccess(f.indexify(), 0)
493+
cf1 = ComponentAccess(f.indexify(), 1)
494+
495+
assert ccode(cf0) == 'f[x][y][z].x'
496+
assert ccode(cf1) == 'f[x][y][z].y'
497+
498+
# Reconstruction
499+
cf2 = cf1.func(*cf1.args)
500+
assert cf2.index == cf1.index
501+
assert cf2 == cf1
502+
503+
486504
def test_canonical_ordering_of_weights():
487505
grid = Grid(shape=(3, 3, 3))
488506
x, y, z = grid.dimensions

0 commit comments

Comments
 (0)