3232from patsy .compat import OrderedDict
3333from patsy .util import (repr_pretty_delegate , repr_pretty_impl ,
3434 safe_issubdtype ,
35- no_pickling , assert_no_pickling )
35+ no_pickling , assert_no_pickling , check_pickle_version )
3636from patsy .constraint import linear_constraint
3737from patsy .contrasts import ContrastMatrix
3838from patsy .desc import ModelDesc , Term
39+ from patsy import __version__
3940
4041class FactorInfo (object ):
4142 """A FactorInfo object is a simple class that provides some metadata about
@@ -120,7 +121,17 @@ def __repr__(self):
120121 kwlist .append (("categories" , self .categories ))
121122 repr_pretty_impl (p , self , [], kwlist )
122123
123- __getstate__ = no_pickling
124+ def __eq__ (self , other ):
125+ return self .__dict__ == other .__dict__
126+
127+ def __hash__ (self ):
128+ if not self .categories :
129+ categories = 'NoCategories'
130+ else :
131+ categories = frozenset (self .categories )
132+ return hash ((FactorInfo , str (self .factor ), str (self .type ),
133+ str (self .state ), str (self .num_columns ), categories ))
134+
124135
125136def test_FactorInfo ():
126137 fi1 = FactorInfo ("asdf" , "numerical" , {"a" : 1 }, num_columns = 10 )
@@ -234,7 +245,10 @@ def _repr_pretty_(self, p, cycle):
234245 ("contrast_matrices" , self .contrast_matrices ),
235246 ("num_columns" , self .num_columns )])
236247
237- __getstate__ = no_pickling
248+ def __eq__ (self , other ):
249+ return self .__dict__ == other .__dict__
250+
251+ # __getstate__ = no_pickling
238252
239253def test_SubtermInfo ():
240254 cm = ContrastMatrix (np .ones ((2 , 2 )), ["[1]" , "[2]" ])
@@ -691,16 +705,40 @@ def from_array(cls, array_like, default_column_prefix="column"):
691705 for i in columns ]
692706 return DesignInfo (column_names )
693707
694- __getstate__ = no_pickling
708+ def __getstate__ (self ):
709+ return (0 , self .column_name_indexes , self .factor_infos ,
710+ self .term_codings , self .term_slices , self .term_name_slices )
711+
712+ def __setstate__ (self , pickle ):
713+ (version , column_name_indexes , factor_infos , term_codings ,
714+ term_slices , term_name_slices ) = pickle
715+ check_pickle_version (version , 0 , self .__class__ .__name__ )
716+ self .column_name_indexes = column_name_indexes
717+ self .factor_infos = factor_infos
718+ self .term_codings = term_codings
719+ self .term_slices = term_slices
720+ self .term_name_slices = term_name_slices
721+
722+ def __eq__ (self , other ):
723+ return self .__dict__ == other .__dict__
724+
725+
726+ class _MockFactor (object ):
727+ def __init__ (self , name ):
728+ self ._name = name
729+
730+ def name (self ):
731+ return self ._name
732+
733+ def __eq__ (self , other ):
734+ return self .__dict__ == other .__dict__
735+
736+ def __hash__ (self ):
737+ return hash ((_MockFactor , str (self ._name )))
738+
695739
696740def test_DesignInfo ():
697741 from nose .tools import assert_raises
698- class _MockFactor (object ):
699- def __init__ (self , name ):
700- self ._name = name
701-
702- def name (self ):
703- return self ._name
704742 f_x = _MockFactor ("x" )
705743 f_y = _MockFactor ("y" )
706744 t_x = Term ([f_x ])
@@ -734,8 +772,9 @@ def name(self):
734772
735773 # smoke test
736774 repr (di )
775+ from six .moves import cPickle as pickle
737776
738- assert_no_pickling ( di )
777+ assert di == pickle . loads ( pickle . dumps ( di , pickle . HIGHEST_PROTOCOL ) )
739778
740779 # One without term objects
741780 di = DesignInfo (["a1" , "a2" , "a3" , "b" ])
@@ -756,6 +795,8 @@ def name(self):
756795 assert di .slice ("a3" ) == slice (2 , 3 )
757796 assert di .slice ("b" ) == slice (3 , 4 )
758797
798+ assert di == pickle .loads (pickle .dumps (di , pickle .HIGHEST_PROTOCOL ))
799+
759800 # Check intercept handling in describe()
760801 assert DesignInfo (["Intercept" , "a" , "b" ]).describe () == "1 + a + b"
761802
0 commit comments