Skip to content

Commit aea9128

Browse files
committed
Adding __getstate__ and __setstate__ on relevant factors. Adding __eq__ as appropriate to make assert tests work.
1 parent 83f1b80 commit aea9128

14 files changed

Lines changed: 342 additions & 60 deletions

patsy/categorical.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,37 @@
4646
pandas_Categorical_categories,
4747
pandas_Categorical_codes,
4848
safe_issubdtype,
49-
no_pickling, assert_no_pickling)
49+
no_pickling, assert_no_pickling, check_pickle_version)
50+
from patsy.state import StatefulTransform
5051

5152
if have_pandas:
5253
import pandas
5354

5455
# Objects of this type will always be treated as categorical, with the
5556
# specified levels and contrast (if given).
57+
5658
class _CategoricalBox(object):
5759
def __init__(self, data, contrast, levels):
5860
self.data = data
5961
self.contrast = contrast
6062
self.levels = levels
6163

62-
__getstate__ = no_pickling
64+
def __getstate__(self):
65+
data = getattr(self, 'data')
66+
contrast = getattr(self, 'contrast')
67+
levels = getattr(self, 'levels')
68+
return (0, data, contrast, levels)
69+
70+
def __setstate__(self, pickle):
71+
version, data, contrast, levels = pickle
72+
check_pickle_version(version, 0, name=self.__class__.__name__)
73+
self.data = data
74+
self.contrast = contrast
75+
self.levels = levels
76+
77+
def __eq__(self, other):
78+
return self.__dict__ == other.__dict__
79+
6380

6481
def C(data, contrast=None, levels=None):
6582
"""
@@ -120,7 +137,20 @@ def test_C():
120137
assert c4.contrast == "NEW CONTRAST"
121138
assert c4.levels == "LEVELS"
122139

123-
assert_no_pickling(c4)
140+
# assert_no_pickling(c4)
141+
142+
143+
def test_C_pickle():
144+
from six.moves import cPickle as pickle
145+
c1 = C("asdf")
146+
assert c1 == pickle.loads(pickle.dumps(c1))
147+
c2 = C("DATA", "CONTRAST", "LEVELS")
148+
assert c2 == pickle.loads(pickle.dumps(c2))
149+
c3 = C(c2, levels="NEW LEVELS")
150+
assert c3 == pickle.loads(pickle.dumps(c3))
151+
c4 = C(c2, "NEW CONTRAST")
152+
assert c4 == pickle.loads(pickle.dumps(c4))
153+
124154

125155
def guess_categorical(data):
126156
if safe_is_pandas_categorical(data):
@@ -217,7 +247,7 @@ def sniff(self, data):
217247
# would be too. Otherwise we need to keep looking.
218248
return self._level_set == set([True, False])
219249

220-
__getstate__ = no_pickling
250+
# __getstate__ = no_pickling
221251

222252
def test_CategoricalSniffer():
223253
from patsy.missing import NAAction

patsy/constraint.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ def _repr_pretty_(self, p, cycle):
6565
return repr_pretty_impl(p, self,
6666
[self.variable_names, self.coefs, self.constants])
6767

68-
__getstate__ = no_pickling
68+
def __eq__(self, other):
69+
return self.__dict__ == other.__dict__
70+
71+
# __getstate__ = no_pickling
6972

7073
@classmethod
7174
def combine(cls, constraints):
@@ -118,7 +121,7 @@ def test_LinearConstraint():
118121
assert_raises(ValueError, LinearConstraint, ["a", "b"],
119122
np.zeros((0, 2)))
120123

121-
assert_no_pickling(lc)
124+
# assert_no_pickling(lc)
122125

123126
def test_LinearConstraint_combine():
124127
comb = LinearConstraint.combine([LinearConstraint(["a", "b"], [1, 0]),

patsy/contrasts.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from patsy import PatsyError
1717
from patsy.util import (repr_pretty_delegate, repr_pretty_impl,
1818
safe_issubdtype,
19-
no_pickling, assert_no_pickling)
19+
no_pickling, assert_no_pickling, check_pickle_version)
2020

2121
class ContrastMatrix(object):
2222
"""A simple container for a matrix used for coding categorical factors.
@@ -47,7 +47,23 @@ def __init__(self, matrix, column_suffixes):
4747
def _repr_pretty_(self, p, cycle):
4848
repr_pretty_impl(p, self, [self.matrix, self.column_suffixes])
4949

50-
__getstate__ = no_pickling
50+
51+
def __getstate__(self):
52+
return (0, self.matrix, self.column_suffixes)
53+
54+
def __setstate__(self, pickle):
55+
version, matrix, column_suffixes = pickle
56+
check_pickle_version(version, 0, name=self.__class__.__name__)
57+
self.matrix = matrix
58+
self.column_suffixes = column_suffixes
59+
60+
def __eq__(self, other):
61+
if self.column_suffixes != other.column_suffixes:
62+
return False
63+
if not np.array_equal(self.matrix, other.matrix):
64+
return False
65+
return True
66+
5167

5268
def test_ContrastMatrix():
5369
cm = ContrastMatrix([[1, 0], [0, 1]], ["a", "b"])
@@ -59,7 +75,7 @@ def test_ContrastMatrix():
5975
from nose.tools import assert_raises
6076
assert_raises(PatsyError, ContrastMatrix, [[1], [0]], ["a", "b"])
6177

62-
assert_no_pickling(cm)
78+
# assert_no_pickling(cm)
6379

6480
# This always produces an object of the type that Python calls 'str' (whether
6581
# that be a Python 2 string-of-bytes or a Python 3 string-of-unicode). It does

patsy/desc.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from patsy.eval import EvalEnvironment, EvalFactor
1515
from patsy.util import uniqueify_list
1616
from patsy.util import repr_pretty_delegate, repr_pretty_impl
17-
from patsy.util import no_pickling, assert_no_pickling
17+
from patsy.util import no_pickling, assert_no_pickling, check_pickle_version
1818

1919
# These are made available in the patsy.* namespace
2020
__all__ = ["Term", "ModelDesc", "INTERCEPT"]
@@ -65,17 +65,33 @@ def name(self):
6565
else:
6666
return "Intercept"
6767

68-
__getstate__ = no_pickling
68+
def __getstate__(self):
69+
return (0, self.factors)
70+
71+
def __setstate__(self, pickle):
72+
version, factors = pickle
73+
check_pickle_version(version, 0, name=self.__class__.__name__)
74+
self.factors = factors
75+
76+
# __getstate__ = no_pickling
6977

7078
INTERCEPT = Term([])
7179

80+
7281
class _MockFactor(object):
7382
def __init__(self, name):
7483
self._name = name
7584

7685
def name(self):
7786
return self._name
7887

88+
def __eq__(self, other):
89+
return self.__dict__ == other.__dict__
90+
91+
def __hash__(self):
92+
return hash((_MockFactor, str(self._name)))
93+
94+
7995
def test_Term():
8096
assert Term([1, 2, 1]).factors == (1, 2)
8197
assert Term([1, 2]) == Term([2, 1])
@@ -86,7 +102,11 @@ def test_Term():
86102
assert Term([f2, f1]).name() == "b:a"
87103
assert Term([]).name() == "Intercept"
88104

89-
assert_no_pickling(Term([]))
105+
# assert_no_pickling(Term([]))
106+
107+
from six.moves import cPickle as pickle
108+
t = Term([f1, f2])
109+
assert t == pickle.loads(pickle.dumps(t, pickle.HIGHEST_PROTOCOL))
90110

91111
class ModelDesc(object):
92112
"""A simple container representing the termlists parsed from a formula.
@@ -166,7 +186,10 @@ def from_formula(cls, tree_or_string):
166186
assert isinstance(value, cls)
167187
return value
168188

169-
__getstate__ = no_pickling
189+
def __eq__(self, other):
190+
return self.__dict__ == other.__dict__
191+
192+
# __getstate__ = no_pickling
170193

171194
def test_ModelDesc():
172195
f1 = _MockFactor("a")
@@ -177,7 +200,9 @@ def test_ModelDesc():
177200
print(m.describe())
178201
assert m.describe() == "1 + a ~ 0 + a + a:b"
179202

180-
assert_no_pickling(m)
203+
# assert_no_pickling(m)
204+
from six.moves import cPickle as pickle
205+
assert m == pickle.loads(pickle.dumps(m, pickle.HIGHEST_PROTOCOL))
181206

182207
assert ModelDesc([], []).describe() == "~ 0"
183208
assert ModelDesc([INTERCEPT], []).describe() == "1 ~ 0"
@@ -209,7 +234,7 @@ def _pretty_repr_(self, p, cycle): # pragma: no cover
209234
[self.intercept, self.intercept_origin,
210235
self.intercept_removed, self.terms])
211236

212-
__getstate__ = no_pickling
237+
# __getstate__ = no_pickling
213238

214239
def _maybe_add_intercept(doit, terms):
215240
if doit:

patsy/design_info.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@
3232
from patsy.compat import OrderedDict
3333
from patsy.util import (repr_pretty_delegate, repr_pretty_impl,
3434
safe_issubdtype,
35-
no_pickling, assert_no_pickling)
35+
no_pickling, assert_no_pickling, check_pickle_version)
3636
from patsy.constraint import linear_constraint
3737
from patsy.contrasts import ContrastMatrix
3838
from patsy.desc import ModelDesc, Term
39+
from patsy import __version__
3940

4041
class FactorInfo(object):
4142
"""A FactorInfo object is a simple class that provides some metadata about
@@ -120,7 +121,17 @@ def __repr__(self):
120121
kwlist.append(("categories", self.categories))
121122
repr_pretty_impl(p, self, [], kwlist)
122123

123-
__getstate__ = no_pickling
124+
def __eq__(self, other):
125+
return self.__dict__ == other.__dict__
126+
127+
def __hash__(self):
128+
if not self.categories:
129+
categories = 'NoCategories'
130+
else:
131+
categories = frozenset(self.categories)
132+
return hash((FactorInfo, str(self.factor), str(self.type),
133+
str(self.state), str(self.num_columns), categories))
134+
124135

125136
def test_FactorInfo():
126137
fi1 = FactorInfo("asdf", "numerical", {"a": 1}, num_columns=10)
@@ -234,7 +245,10 @@ def _repr_pretty_(self, p, cycle):
234245
("contrast_matrices", self.contrast_matrices),
235246
("num_columns", self.num_columns)])
236247

237-
__getstate__ = no_pickling
248+
def __eq__(self, other):
249+
return self.__dict__ == other.__dict__
250+
251+
# __getstate__ = no_pickling
238252

239253
def test_SubtermInfo():
240254
cm = ContrastMatrix(np.ones((2, 2)), ["[1]", "[2]"])
@@ -691,16 +705,40 @@ def from_array(cls, array_like, default_column_prefix="column"):
691705
for i in columns]
692706
return DesignInfo(column_names)
693707

694-
__getstate__ = no_pickling
708+
def __getstate__(self):
709+
return (0, self.column_name_indexes, self.factor_infos,
710+
self.term_codings, self.term_slices, self.term_name_slices)
711+
712+
def __setstate__(self, pickle):
713+
(version, column_name_indexes, factor_infos, term_codings,
714+
term_slices, term_name_slices) = pickle
715+
check_pickle_version(version, 0, self.__class__.__name__)
716+
self.column_name_indexes = column_name_indexes
717+
self.factor_infos = factor_infos
718+
self.term_codings = term_codings
719+
self.term_slices = term_slices
720+
self.term_name_slices = term_name_slices
721+
722+
def __eq__(self, other):
723+
return self.__dict__ == other.__dict__
724+
725+
726+
class _MockFactor(object):
727+
def __init__(self, name):
728+
self._name = name
729+
730+
def name(self):
731+
return self._name
732+
733+
def __eq__(self, other):
734+
return self.__dict__ == other.__dict__
735+
736+
def __hash__(self):
737+
return hash((_MockFactor, str(self._name)))
738+
695739

696740
def test_DesignInfo():
697741
from nose.tools import assert_raises
698-
class _MockFactor(object):
699-
def __init__(self, name):
700-
self._name = name
701-
702-
def name(self):
703-
return self._name
704742
f_x = _MockFactor("x")
705743
f_y = _MockFactor("y")
706744
t_x = Term([f_x])
@@ -734,8 +772,9 @@ def name(self):
734772

735773
# smoke test
736774
repr(di)
775+
from six.moves import cPickle as pickle
737776

738-
assert_no_pickling(di)
777+
assert di == pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL))
739778

740779
# One without term objects
741780
di = DesignInfo(["a1", "a2", "a3", "b"])
@@ -756,6 +795,8 @@ def name(self):
756795
assert di.slice("a3") == slice(2, 3)
757796
assert di.slice("b") == slice(3, 4)
758797

798+
assert di == pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL))
799+
759800
# Check intercept handling in describe()
760801
assert DesignInfo(["Intercept", "a", "b"]).describe() == "1 + a + b"
761802

0 commit comments

Comments
 (0)