Skip to content

Commit 5b39040

Browse files
committed
fix processor saving the original pose
1 parent 4a8205f commit 5b39040

2 files changed

Lines changed: 138 additions & 15 deletions

File tree

dlclivegui/processors/dlc_processor_socket.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -387,12 +387,11 @@ def broadcast(self, payload):
387387
def process(self, pose, **kwargs):
388388
curr_time = self.timing_func()
389389

390-
if self.save_original:
391-
self.original_pose.append(pose.copy())
392-
393390
self.curr_step += 1
394391

395392
if self.recording:
393+
if self.save_original and self.original_pose is not None:
394+
self.original_pose.append(pose.copy())
396395
self.time_stamp.append(curr_time)
397396
self.step.append(self.curr_step)
398397
self.frame_time.append(kwargs.get("frame_time", -1))
@@ -568,9 +567,6 @@ def _initialize_filters(self, vals):
568567
logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}")
569568

570569
def process(self, pose, **kwargs):
571-
if self.save_original:
572-
self.original_pose.append(pose.copy())
573-
574570
# Extract keypoints and confidence
575571
xy = pose[:, :2]
576572
conf = pose[:, 2]
@@ -623,6 +619,8 @@ def process(self, pose, **kwargs):
623619

624620
# Store processed data (only if recording)
625621
if self.recording:
622+
if self.save_original and self.original_pose is not None:
623+
self.original_pose.append(pose.copy())
626624
self.center_x.append(vals[0])
627625
self.center_y.append(vals[1])
628626
self.heading_direction.append(vals[2])
@@ -680,7 +678,7 @@ class ExampleProcessorSocketFilterKeypoints(BaseProcessorSocket): # pragma: no
680678
},
681679
"save_original": {
682680
"type": "bool",
683-
"default": False,
681+
"default": True,
684682
"description": "Save raw pose arrays for analysis",
685683
},
686684
}
@@ -692,7 +690,7 @@ def __init__(
692690
use_perf_counter=False,
693691
use_filter=False,
694692
filter_kwargs: dict | None = None,
695-
save_original=False,
693+
save_original=True,
696694
p_cutoff=0.4,
697695
):
698696
super().__init__(
@@ -731,9 +729,6 @@ def _initialize_filters(self, vals):
731729
logger.debug(f"Initialized One-Euro filters with parameters: {self.filter_kwargs}")
732730

733731
def process(self, pose, **kwargs):
734-
if self.save_original:
735-
self.original_pose.append(pose.copy())
736-
737732
# Extract keypoints and confidence
738733
xy = pose[:, :2]
739734
conf = pose[:, 2]
@@ -791,6 +786,8 @@ def process(self, pose, **kwargs):
791786

792787
# Store processed data (only if recording)
793788
if self.recording:
789+
if self.save_original and self.original_pose is not None:
790+
self.original_pose.append(pose.copy())
794791
self.center_x.append(vals[0])
795792
self.center_y.append(vals[1])
796793
self.heading_direction.append(vals[2])

tests/custom_processors/test_base_processor.py

Lines changed: 130 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_base_process_without_and_with_recording(socket_mod):
120120
BaseProcessorSocket.process() should:
121121
- increment curr_step always,
122122
- when recording, append time/step/frame_time/pose_time,
123-
- when save_original=True, store copies of pose arrays.
123+
- when save_original=True, store copies of pose arrays only while recording.
124124
"""
125125
BaseProcessorSocket = socket_mod.BaseProcessorSocket
126126
proc = BaseProcessorSocket(bind=("127.0.0.1", 0), save_original=True)
@@ -136,10 +136,9 @@ def test_base_process_without_and_with_recording(socket_mod):
136136
assert len(proc.step) == 0
137137
assert len(proc.frame_time) == 0
138138
assert len(proc.pose_time) == 0
139-
# When not recording, save_original is still respected
139+
# Raw poses must stay aligned with recorded metadata.
140140
assert proc.original_pose is not None
141-
assert len(proc.original_pose) == 1
142-
np.testing.assert_allclose(proc.original_pose[0], pose)
141+
assert len(proc.original_pose) == 0
143142

144143
# Start recording and push two frames
145144
proc._handle_client_message({"cmd": "start_recording"})
@@ -150,6 +149,9 @@ def test_base_process_without_and_with_recording(socket_mod):
150149
assert len(proc.step) == 2
151150
assert len(proc.frame_time) == 2
152151
assert len(proc.pose_time) == 2
152+
assert len(proc.original_pose) == 2
153+
np.testing.assert_allclose(proc.original_pose[0], pose)
154+
np.testing.assert_allclose(proc.original_pose[1], pose)
153155

154156
# Data snapshot integrity
155157
data = proc.get_data()
@@ -166,6 +168,128 @@ def test_base_process_without_and_with_recording(socket_mod):
166168
proc.stop()
167169

168170

171+
def test_save_ignores_pre_recording_original_pose_frames(socket_mod):
172+
"""
173+
save_original data must stay aligned with recorded metadata even if process()
174+
is called before recording starts.
175+
"""
176+
BaseProcessorSocket = socket_mod.BaseProcessorSocket
177+
proc = BaseProcessorSocket(bind=("127.0.0.1", 0), save_original=True)
178+
179+
try:
180+
n_keypoints = 4
181+
bodyparts = _mk_bodyparts(n_keypoints)
182+
proc.set_dlc_cfg({"metadata": {"bodyparts": bodyparts}})
183+
184+
pose = _mk_pose(n_keypoints=n_keypoints)
185+
186+
for _ in range(3):
187+
proc.process(pose, frame_time=0.001, pose_time=0.002)
188+
189+
assert len(proc.original_pose) == 0
190+
assert len(proc.frame_time) == 0
191+
192+
proc._handle_client_message({"cmd": "start_recording"})
193+
for _ in range(2):
194+
proc.process(pose, frame_time=0.01, pose_time=0.02)
195+
proc._handle_client_message({"cmd": "stop_recording"})
196+
197+
filename = "unit_test_pre_recording_frames.pkl"
198+
ret = proc.save(filename)
199+
assert ret == 1
200+
201+
data_dir = _module_data_dir(socket_mod)
202+
pkl_path = data_dir / filename
203+
h5_path = data_dir / (Path(filename).stem + "_DLC.hdf5")
204+
205+
assert pkl_path.exists()
206+
assert h5_path.exists()
207+
208+
with open(pkl_path, "rb") as f:
209+
payload = pickle.load(f)
210+
211+
assert len(payload["frame_time"]) == 2
212+
assert len(payload["time_stamp"]) == 2
213+
214+
pytest.importorskip("tables")
215+
df = pd.read_hdf(h5_path, key="df_with_missing")
216+
assert df.shape[0] == 2
217+
assert list(df["frame_time"]) == [0.01, 0.01]
218+
assert list(df["pose_time"]) == list(payload["time_stamp"])
219+
220+
finally:
221+
proc.stop()
222+
try:
223+
pkl_path.unlink(missing_ok=True)
224+
h5_path.unlink(missing_ok=True)
225+
except Exception:
226+
pass
227+
228+
229+
@pytest.mark.parametrize(
230+
("class_name", "n_keypoints"),
231+
[
232+
("ExampleProcessorSocketCalculateMousePose", 27),
233+
("ExampleProcessorSocketFilterKeypoints", 10),
234+
],
235+
)
236+
def test_subclass_save_ignores_pre_recording_original_pose_frames(socket_mod, class_name, n_keypoints):
237+
"""
238+
Concrete processors must keep original_pose aligned with recorded metadata
239+
even when process() is called before recording starts.
240+
"""
241+
processor_class = getattr(socket_mod, class_name)
242+
proc = processor_class(bind=("127.0.0.1", 0), save_original=True)
243+
244+
try:
245+
bodyparts = _mk_bodyparts(n_keypoints)
246+
proc.set_dlc_cfg({"metadata": {"bodyparts": bodyparts}})
247+
248+
pose = _mk_pose(n_keypoints=n_keypoints)
249+
250+
for _ in range(4):
251+
proc.process(pose, frame_time=0.001, pose_time=0.002)
252+
253+
assert len(proc.original_pose) == 0
254+
assert len(proc.frame_time) == 0
255+
256+
proc._handle_client_message({"cmd": "start_recording"})
257+
for _ in range(3):
258+
proc.process(pose, frame_time=0.01, pose_time=0.02)
259+
proc._handle_client_message({"cmd": "stop_recording"})
260+
261+
filename = f"unit_test_{class_name}.pkl"
262+
ret = proc.save(filename)
263+
assert ret == 1
264+
265+
data_dir = _module_data_dir(socket_mod)
266+
pkl_path = data_dir / filename
267+
h5_path = data_dir / (Path(filename).stem + "_DLC.hdf5")
268+
269+
assert pkl_path.exists()
270+
assert h5_path.exists()
271+
272+
with open(pkl_path, "rb") as f:
273+
payload = pickle.load(f)
274+
275+
assert len(payload["frame_time"]) == 3
276+
assert len(payload["time_stamp"]) == 3
277+
278+
pytest.importorskip("tables")
279+
df = pd.read_hdf(h5_path, key="df_with_missing")
280+
assert df.shape[0] == 3
281+
assert list(df["frame_time"]) == [0.01, 0.01, 0.01]
282+
assert list(df["pose_time"]) == list(payload["time_stamp"])
283+
284+
finally:
285+
proc.stop()
286+
try:
287+
pkl_path.unlink(missing_ok=True)
288+
h5_path.unlink(missing_ok=True)
289+
except Exception:
290+
pass
291+
292+
169293
def test_base_broadcast_handles_bad_connections(socket_mod):
170294
"""
171295
broadcast() must handle failing connections gracefully and drop them.
@@ -270,6 +394,8 @@ def test_save_writes_pkl_and_hdf5_with_labels(socket_mod, caplog):
270394
# frame_time & pose_time columns are present
271395
assert "frame_time" in df.columns
272396
assert "pose_time" in df.columns
397+
assert list(df["frame_time"]) == [0.01, 0.01, 0.01]
398+
assert list(df["pose_time"]) == list(payload["time_stamp"])
273399

274400
# sanity check values for first row
275401
for i, bp in enumerate(bodyparts):

0 commit comments

Comments
 (0)