@@ -120,7 +120,7 @@ def test_base_process_without_and_with_recording(socket_mod):
120120 BaseProcessorSocket.process() should:
121121 - increment curr_step always,
122122 - when recording, append time/step/frame_time/pose_time,
123- - when save_original=True, store copies of pose arrays.
123+ - when save_original=True, store copies of pose arrays only while recording .
124124 """
125125 BaseProcessorSocket = socket_mod .BaseProcessorSocket
126126 proc = BaseProcessorSocket (bind = ("127.0.0.1" , 0 ), save_original = True )
@@ -136,10 +136,9 @@ def test_base_process_without_and_with_recording(socket_mod):
136136 assert len (proc .step ) == 0
137137 assert len (proc .frame_time ) == 0
138138 assert len (proc .pose_time ) == 0
139- # When not recording, save_original is still respected
139+ # Raw poses must stay aligned with recorded metadata.
140140 assert proc .original_pose is not None
141- assert len (proc .original_pose ) == 1
142- np .testing .assert_allclose (proc .original_pose [0 ], pose )
141+ assert len (proc .original_pose ) == 0
143142
144143 # Start recording and push two frames
145144 proc ._handle_client_message ({"cmd" : "start_recording" })
@@ -150,6 +149,9 @@ def test_base_process_without_and_with_recording(socket_mod):
150149 assert len (proc .step ) == 2
151150 assert len (proc .frame_time ) == 2
152151 assert len (proc .pose_time ) == 2
152+ assert len (proc .original_pose ) == 2
153+ np .testing .assert_allclose (proc .original_pose [0 ], pose )
154+ np .testing .assert_allclose (proc .original_pose [1 ], pose )
153155
154156 # Data snapshot integrity
155157 data = proc .get_data ()
@@ -166,6 +168,128 @@ def test_base_process_without_and_with_recording(socket_mod):
166168 proc .stop ()
167169
168170
171+ def test_save_ignores_pre_recording_original_pose_frames (socket_mod ):
172+ """
173+ save_original data must stay aligned with recorded metadata even if process()
174+ is called before recording starts.
175+ """
176+ BaseProcessorSocket = socket_mod .BaseProcessorSocket
177+ proc = BaseProcessorSocket (bind = ("127.0.0.1" , 0 ), save_original = True )
178+
179+ try :
180+ n_keypoints = 4
181+ bodyparts = _mk_bodyparts (n_keypoints )
182+ proc .set_dlc_cfg ({"metadata" : {"bodyparts" : bodyparts }})
183+
184+ pose = _mk_pose (n_keypoints = n_keypoints )
185+
186+ for _ in range (3 ):
187+ proc .process (pose , frame_time = 0.001 , pose_time = 0.002 )
188+
189+ assert len (proc .original_pose ) == 0
190+ assert len (proc .frame_time ) == 0
191+
192+ proc ._handle_client_message ({"cmd" : "start_recording" })
193+ for _ in range (2 ):
194+ proc .process (pose , frame_time = 0.01 , pose_time = 0.02 )
195+ proc ._handle_client_message ({"cmd" : "stop_recording" })
196+
197+ filename = "unit_test_pre_recording_frames.pkl"
198+ ret = proc .save (filename )
199+ assert ret == 1
200+
201+ data_dir = _module_data_dir (socket_mod )
202+ pkl_path = data_dir / filename
203+ h5_path = data_dir / (Path (filename ).stem + "_DLC.hdf5" )
204+
205+ assert pkl_path .exists ()
206+ assert h5_path .exists ()
207+
208+ with open (pkl_path , "rb" ) as f :
209+ payload = pickle .load (f )
210+
211+ assert len (payload ["frame_time" ]) == 2
212+ assert len (payload ["time_stamp" ]) == 2
213+
214+ pytest .importorskip ("tables" )
215+ df = pd .read_hdf (h5_path , key = "df_with_missing" )
216+ assert df .shape [0 ] == 2
217+ assert list (df ["frame_time" ]) == [0.01 , 0.01 ]
218+ assert list (df ["pose_time" ]) == list (payload ["time_stamp" ])
219+
220+ finally :
221+ proc .stop ()
222+ try :
223+ pkl_path .unlink (missing_ok = True )
224+ h5_path .unlink (missing_ok = True )
225+ except Exception :
226+ pass
227+
228+
229+ @pytest .mark .parametrize (
230+ ("class_name" , "n_keypoints" ),
231+ [
232+ ("ExampleProcessorSocketCalculateMousePose" , 27 ),
233+ ("ExampleProcessorSocketFilterKeypoints" , 10 ),
234+ ],
235+ )
236+ def test_subclass_save_ignores_pre_recording_original_pose_frames (socket_mod , class_name , n_keypoints ):
237+ """
238+ Concrete processors must keep original_pose aligned with recorded metadata
239+ even when process() is called before recording starts.
240+ """
241+ processor_class = getattr (socket_mod , class_name )
242+ proc = processor_class (bind = ("127.0.0.1" , 0 ), save_original = True )
243+
244+ try :
245+ bodyparts = _mk_bodyparts (n_keypoints )
246+ proc .set_dlc_cfg ({"metadata" : {"bodyparts" : bodyparts }})
247+
248+ pose = _mk_pose (n_keypoints = n_keypoints )
249+
250+ for _ in range (4 ):
251+ proc .process (pose , frame_time = 0.001 , pose_time = 0.002 )
252+
253+ assert len (proc .original_pose ) == 0
254+ assert len (proc .frame_time ) == 0
255+
256+ proc ._handle_client_message ({"cmd" : "start_recording" })
257+ for _ in range (3 ):
258+ proc .process (pose , frame_time = 0.01 , pose_time = 0.02 )
259+ proc ._handle_client_message ({"cmd" : "stop_recording" })
260+
261+ filename = f"unit_test_{ class_name } .pkl"
262+ ret = proc .save (filename )
263+ assert ret == 1
264+
265+ data_dir = _module_data_dir (socket_mod )
266+ pkl_path = data_dir / filename
267+ h5_path = data_dir / (Path (filename ).stem + "_DLC.hdf5" )
268+
269+ assert pkl_path .exists ()
270+ assert h5_path .exists ()
271+
272+ with open (pkl_path , "rb" ) as f :
273+ payload = pickle .load (f )
274+
275+ assert len (payload ["frame_time" ]) == 3
276+ assert len (payload ["time_stamp" ]) == 3
277+
278+ pytest .importorskip ("tables" )
279+ df = pd .read_hdf (h5_path , key = "df_with_missing" )
280+ assert df .shape [0 ] == 3
281+ assert list (df ["frame_time" ]) == [0.01 , 0.01 , 0.01 ]
282+ assert list (df ["pose_time" ]) == list (payload ["time_stamp" ])
283+
284+ finally :
285+ proc .stop ()
286+ try :
287+ pkl_path .unlink (missing_ok = True )
288+ h5_path .unlink (missing_ok = True )
289+ except Exception :
290+ pass
291+
292+
169293def test_base_broadcast_handles_bad_connections (socket_mod ):
170294 """
171295 broadcast() must handle failing connections gracefully and drop them.
@@ -270,6 +394,8 @@ def test_save_writes_pkl_and_hdf5_with_labels(socket_mod, caplog):
270394 # frame_time & pose_time columns are present
271395 assert "frame_time" in df .columns
272396 assert "pose_time" in df .columns
397+ assert list (df ["frame_time" ]) == [0.01 , 0.01 , 0.01 ]
398+ assert list (df ["pose_time" ]) == list (payload ["time_stamp" ])
273399
274400 # sanity check values for first row
275401 for i , bp in enumerate (bodyparts ):
0 commit comments