@@ -121,16 +121,18 @@ def __repr__(self):
121121 kwlist .append (("categories" , self .categories ))
122122 repr_pretty_impl (p , self , [], kwlist )
123123
124- def __eq__ (self , other ):
125- return self .__dict__ == other .__dict__
124+ def __getstate__ (self ):
125+ return {'version' : 0 , 'factor' : self .factor , 'type' : self .type ,
126+ 'state' : self .state , 'num_columns' : self .num_columns ,
127+ 'categories' : self .categories }
126128
127- def __hash__ (self ):
128- if not self .categories :
129- categories = 'NoCategories'
130- else :
131- categories = frozenset ( self . categories )
132- return hash (( FactorInfo , str ( self .factor ), str ( self . type ),
133- str ( self .state ), str ( self . num_columns ), categories ))
129+ def __setstate__ (self , pickle ):
130+ check_pickle_version ( pickle [ 'version' ], 0 , self .__class__ . __name__ )
131+ self . factor = pickle [ 'factor' ]
132+ self . type = pickle [ 'type' ]
133+ self . state = pickle [ 'state' ]
134+ self .num_columns = pickle [ 'num_columns' ]
135+ self .categories = pickle [ ' categories' ]
134136
135137
136138def test_FactorInfo ():
@@ -245,10 +247,17 @@ def _repr_pretty_(self, p, cycle):
245247 ("contrast_matrices" , self .contrast_matrices ),
246248 ("num_columns" , self .num_columns )])
247249
248- def __eq__ (self , other ):
249- return self .__dict__ == other .__dict__
250+ def __getstate__ (self ):
251+ return {'version' : 0 , 'factors' : self .factors ,
252+ 'contrast_matrices' : self .contrast_matrices ,
253+ 'num_columns' : self .num_columns }
254+
255+ def __setstate__ (self , pickle ):
256+ check_pickle_version (pickle ['version' ], 0 , self .__class__ .__name__ )
257+ self .factors = pickle ['factors' ]
258+ self .contrast_matrices = pickle ['contrast_matrices' ]
259+ self .num_columns = pickle ['num_columns' ]
250260
251- # __getstate__ = no_pickling
252261
253262def test_SubtermInfo ():
254263 cm = ContrastMatrix (np .ones ((2 , 2 )), ["[1]" , "[2]" ])
@@ -706,21 +715,19 @@ def from_array(cls, array_like, default_column_prefix="column"):
706715 return DesignInfo (column_names )
707716
708717 def __getstate__ (self ):
709- return (0 , self .column_name_indexes , self .factor_infos ,
710- self .term_codings , self .term_slices , self .term_name_slices )
718+ return {'version' : 0 , 'column_name_indexes' : self .column_name_indexes ,
719+ 'factor_infos' : self .factor_infos ,
720+ 'term_codings' : self .term_codings ,
721+ 'term_slices' : self .term_slices ,
722+ 'term_name_slices' : self .term_name_slices }
711723
712724 def __setstate__ (self , pickle ):
713- (version , column_name_indexes , factor_infos , term_codings ,
714- term_slices , term_name_slices ) = pickle
715- check_pickle_version (version , 0 , self .__class__ .__name__ )
716- self .column_name_indexes = column_name_indexes
717- self .factor_infos = factor_infos
718- self .term_codings = term_codings
719- self .term_slices = term_slices
720- self .term_name_slices = term_name_slices
721-
722- def __eq__ (self , other ):
723- return self .__dict__ == other .__dict__
725+ check_pickle_version (pickle ['version' ], 0 , self .__class__ .__name__ )
726+ self .column_name_indexes = pickle ['column_name_indexes' ]
727+ self .factor_infos = pickle ['factor_infos' ]
728+ self .term_codings = pickle ['term_codings' ]
729+ self .term_slices = pickle ['term_slices' ]
730+ self .term_name_slices = pickle ['term_name_slices' ]
724731
725732
726733class _MockFactor (object ):
@@ -772,9 +779,12 @@ def test_DesignInfo():
772779
773780 # smoke test
774781 repr (di )
775- from six .moves import cPickle as pickle
776782
777- assert di == pickle .loads (pickle .dumps (di , pickle .HIGHEST_PROTOCOL ))
783+ # Pickling check
784+ from six .moves import cPickle as pickle
785+ from patsy .util import assert_pickled_equals
786+ di2 = pickle .loads (pickle .dumps (di , pickle .HIGHEST_PROTOCOL ))
787+ assert_pickled_equals (di , di2 )
778788
779789 # One without term objects
780790 di = DesignInfo (["a1" , "a2" , "a3" , "b" ])
@@ -795,7 +805,8 @@ def test_DesignInfo():
795805 assert di .slice ("a3" ) == slice (2 , 3 )
796806 assert di .slice ("b" ) == slice (3 , 4 )
797807
798- assert di == pickle .loads (pickle .dumps (di , pickle .HIGHEST_PROTOCOL ))
808+ di2 = pickle .loads (pickle .dumps (di , pickle .HIGHEST_PROTOCOL ))
809+ assert_pickled_equals (di , di2 )
799810
800811 # Check intercept handling in describe()
801812 assert DesignInfo (["Intercept" , "a" , "b" ]).describe () == "1 + a + b"
0 commit comments