Commit 7ec008f
Add comprehensive tests for Assembler and Assembly classes
Introduces detailed unit tests for the Assembler, Assembly, Joint, and Link classes in dlclive.core.inferenceutils. The new tests cover metadata parsing, detection flattening, link extraction, assembly building, Mahalanobis distance calculation, I/O helpers, and various Assembly operations, improving test coverage and reliability.
1 parent 7b8113e commit 7ec008f
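The new test files themselves are not shown in this diff; only the cleanup of dlclive/core/inferenceutils.py is. As a rough illustration, tests of the kind the message describes might look like the sketch below, which relies only on the Joint, Link, and parse_metadata behavior visible in the diff (the test names and fixtures are hypothetical):

# Hypothetical sketch -- the actual tests added by this commit are not shown here.
import numpy as np

from dlclive.core.inferenceutils import Joint, Link, parse_metadata


def test_parse_metadata_defaults_paf_inds():
    data = {
        "metadata": {"all_joints_names": ["nose", "tail"], "PAFgraph": [[0, 1]]},
        "frame000.png": None,
    }
    params = parse_metadata(data)
    assert params["num_joints"] == 2
    assert params["imnames"] == ["frame000.png"]
    # Without "PAFinds", the PAF indices default to arange(num_joints).
    np.testing.assert_array_equal(params["paf"], np.arange(2))


def test_link_length_and_repr():
    # Field order (pos, confidence, label, idx, group) is inferred from the
    # Joint(tuple(xy), p.item(), i, ind, g) call site in the diff below.
    j1 = Joint((0.0, 0.0), 1.0, 0, 0, -1)
    j2 = Joint((3.0, 4.0), 1.0, 1, 1, -1)
    link = Link(j1, j2, affinity=0.5)
    assert link.length == 5.0
    assert "affinity=0.50" in repr(link)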

3 files changed: 838 additions & 98 deletions

File tree

dlclive/core/inferenceutils.py: 51 additions & 98 deletions
@@ -8,6 +8,9 @@
 #
 # Licensed under GNU Lesser General Public License v3.0
 #
+
+
+# NOTE DUPLICATED from deeplabcut/core/inferenceutils.py
 from __future__ import annotations
 
 import heapq
@@ -17,9 +20,10 @@
 import pickle
 import warnings
 from collections import defaultdict
+from collections.abc import Iterable
 from dataclasses import dataclass
 from math import erf, sqrt
-from typing import Any, Iterable, Tuple
+from typing import Any
 
 import networkx as nx
 import numpy as np
@@ -41,7 +45,7 @@ def _conv_square_to_condensed_indices(ind_row, ind_col, n):
     return n * ind_col - ind_col * (ind_col + 1) // 2 + ind_row - 1 - ind_col
 
 
-Position = Tuple[float, float]
+Position = tuple[float, float]
 
 
 @dataclass(frozen=True)
@@ -61,9 +65,7 @@ def __init__(self, j1, j2, affinity=1):
         self._length = sqrt((j1.pos[0] - j2.pos[0]) ** 2 + (j1.pos[1] - j2.pos[1]) ** 2)
 
     def __repr__(self):
-        return (
-            f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}"
-        )
+        return f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}"
 
     @property
     def confidence(self):
@@ -155,7 +157,7 @@ def soft_identity(self):
         unq, idx, cnt = np.unique(data[:, 3], return_inverse=True, return_counts=True)
         avg = np.bincount(idx, weights=data[:, 2]) / cnt
         soft = softmax(avg)
-        return dict(zip(unq.astype(int), soft))
+        return dict(zip(unq.astype(int), soft, strict=False))
 
     @property
     def affinity(self):
@@ -261,9 +263,7 @@ def __init__(
         self.max_overlap = max_overlap
         self._has_identity = "identity" in self[0]
         if identity_only and not self._has_identity:
-            warnings.warn(
-                "The network was not trained with identity; setting `identity_only` to False."
-            )
+            warnings.warn("The network was not trained with identity; setting `identity_only` to False.", stacklevel=2)
         self.identity_only = identity_only & self._has_identity
         self.nan_policy = nan_policy
         self.force_fusion = force_fusion
@@ -344,15 +344,15 @@ def calibrate(self, train_data_file):
             pass
         n_bpts = len(df.columns.get_level_values("bodyparts").unique())
         if n_bpts == 1:
-            warnings.warn("There is only one keypoint; skipping calibration...")
+            warnings.warn("There is only one keypoint; skipping calibration...", stacklevel=2)
             return
 
         xy = df.to_numpy().reshape((-1, n_bpts, 2))
         frac_valid = np.mean(~np.isnan(xy), axis=(1, 2))
         # Only keeps skeletons that are more than 90% complete
         xy = xy[frac_valid >= 0.9]
         if not xy.size:
-            warnings.warn("No complete poses were found. Skipping calibration...")
+            warnings.warn("No complete poses were found. Skipping calibration...", stacklevel=2)
             return
 
         # TODO Normalize dists by longest length?
@@ -368,13 +368,9 @@
             self.safe_edge = True
         except np.linalg.LinAlgError:
             # Covariance matrix estimation fails due to numerical singularities
-            warnings.warn(
-                "The assembler could not be robustly calibrated. Continuing without it..."
-            )
+            warnings.warn("The assembler could not be robustly calibrated. Continuing without it...", stacklevel=2)
 
-    def calc_assembly_mahalanobis_dist(
-        self, assembly, return_proba=False, nan_policy="little"
-    ):
+    def calc_assembly_mahalanobis_dist(self, assembly, return_proba=False, nan_policy="little"):
         if self._kde is None:
             raise ValueError("Assembler should be calibrated first with training data.")
 
@@ -428,10 +424,10 @@ def _flatten_detections(data_dict):
             ids = [np.ones(len(arr), dtype=int) * -1 for arr in confidence]
         else:
             ids = [arr.argmax(axis=1) for arr in ids]
-        for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids)):
+        for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids, strict=False)):
            if not np.any(coords):
                continue
-            for xy, p, g in zip(coords, conf, id_):
+            for xy, p, g in zip(coords, conf, id_, strict=False):
                joint = Joint(tuple(xy), p.item(), i, ind, g)
                ind += 1
                yield joint
@@ -453,9 +449,7 @@ def extract_best_links(self, joints_dict, costs, trees=None):
         aff[np.isnan(aff)] = 0
 
         if trees:
-            vecs = np.vstack(
-                [[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t]
-            )
+            vecs = np.vstack([[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t])
             dists = []
             for n, tree in enumerate(trees, start=1):
                 d, _ = tree.query(vecs)
@@ -464,45 +458,34 @@
             aff *= w.reshape(aff.shape)
 
         if self.greedy:
-            conf = np.asarray(
-                [
-                    [det_s.confidence * det_t.confidence for det_t in dets_t]
-                    for det_s in dets_s
-                ]
-            )
-            rows, cols = np.where(
-                (conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity)
-            )
+            conf = np.asarray([[det_s.confidence * det_t.confidence for det_t in dets_t] for det_s in dets_s])
+            rows, cols = np.where((conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity))
             candidates = sorted(
-                zip(rows, cols, aff[rows, cols], lengths[rows, cols]),
+                zip(rows, cols, aff[rows, cols], lengths[rows, cols], strict=False),
                 key=lambda x: x[2],
                 reverse=True,
             )
             i_seen = set()
             j_seen = set()
-            for i, j, w, l in candidates:
+            for i, j, w, _l in candidates:
                 if i not in i_seen and j not in j_seen:
                     i_seen.add(i)
                     j_seen.add(j)
                     links.append(Link(dets_s[i], dets_t[j], w))
                 if len(i_seen) == self.max_n_individuals:
                     break
         else:  # Optimal keypoint pairing
-            inds_s = sorted(
-                range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True
-            )[: self.max_n_individuals]
-            inds_t = sorted(
-                range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True
-            )[: self.max_n_individuals]
-            keep_s = [
-                ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff
+            inds_s = sorted(range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True)[
+                : self.max_n_individuals
             ]
-            keep_t = [
-                ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff
+            inds_t = sorted(range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True)[
+                : self.max_n_individuals
             ]
+            keep_s = [ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff]
+            keep_t = [ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff]
             aff = aff[np.ix_(keep_s, keep_t)]
             rows, cols = linear_sum_assignment(aff, maximize=True)
-            for row, col in zip(rows, cols):
+            for row, col in zip(rows, cols, strict=False):
                 w = aff[row, col]
                 if w >= self.min_affinity:
                     links.append(Link(dets_s[keep_s[row]], dets_t[keep_t[col]], w))
@@ -538,19 +521,17 @@ def push_to_stack(i):
             if new_ind in assembled:
                 continue
             if safe_edge:
-                d_old = self.calc_assembly_mahalanobis_dist(
-                    assembly, nan_policy=nan_policy
-                )
+                d_old = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy)
                 success = assembly.add_link(best, store_dict=True)
                 if not success:
                     assembly._dict = dict()
                     continue
                 d = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy)
                 if d < d_old:
                     push_to_stack(new_ind)
-                if tabu:
-                    _, _, link = heapq.heappop(tabu)
-                    heapq.heappush(stack, (-link.affinity, next(counter), link))
+                    if tabu:
+                        _, _, link = heapq.heappop(tabu)
+                        heapq.heappush(stack, (-link.affinity, next(counter), link))
                 else:
                     heapq.heappush(tabu, (d - d_old, next(counter), best))
         assembly.__dict__.update(assembly._dict)
@@ -593,9 +574,7 @@ def build_assemblies(self, links):
                 continue
             assembly = Assembly(self.n_multibodyparts)
             assembly.add_link(link)
-            self._fill_assembly(
-                assembly, lookup, assembled, self.safe_edge, self.nan_policy
-            )
+            self._fill_assembly(assembly, lookup, assembled, self.safe_edge, self.nan_policy)
             for assembly_link in assembly._links:
                 i, j = assembly_link.idx
                 lookup[i].pop(j)
@@ -607,10 +586,7 @@
         n_extra = len(assemblies) - self.max_n_individuals
         if n_extra > 0:
             if self.safe_edge:
-                ds_old = [
-                    self.calc_assembly_mahalanobis_dist(assembly)
-                    for assembly in assemblies
-                ]
+                ds_old = [self.calc_assembly_mahalanobis_dist(assembly) for assembly in assemblies]
                 while len(assemblies) > self.max_n_individuals:
                     ds = []
                     for i, j in itertools.combinations(range(len(assemblies)), 2):
@@ -665,7 +641,7 @@
                     for idx in store[j]._idx:
                         store[idx] = store[i]
                 except KeyError:
-                    # Some links may reference indices that were never added to `store`;
+                    # Some links may reference indices that were never added to `store`;
                     # in that case we intentionally skip merging for this link
                     pass
 
@@ -742,10 +718,7 @@ def _assemble(self, data_dict, ind_frame):
         for _, group in groups:
             ass = Assembly(self.n_multibodyparts)
             for joint in sorted(group, key=lambda x: x.confidence, reverse=True):
-                if (
-                    joint.confidence >= self.pcutoff
-                    and joint.label < self.n_multibodyparts
-                ):
+                if joint.confidence >= self.pcutoff and joint.label < self.n_multibodyparts:
                     ass.add_joint(joint)
             if len(ass):
                 assemblies.append(ass)
@@ -774,24 +747,18 @@
         assembled.update(assembled_)
 
         # Remove invalid assemblies
-        discarded = set(
-            joint
-            for joint in joints
-            if joint.idx not in assembled and np.isfinite(joint.confidence)
-        )
+        discarded = set(joint for joint in joints if joint.idx not in assembled and np.isfinite(joint.confidence))
         for assembly in assemblies[::-1]:
             if 0 < assembly.n_links < self.min_n_links or not len(assembly):
                 for link in assembly._links:
                     discarded.update((link.j1, link.j2))
                 assemblies.remove(assembly)
         if 0 < self.max_overlap < 1:  # Non-maximum pose suppression
             if self._kde is not None:
-                scores = [
-                    -self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies
-                ]
+                scores = [-self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies]
             else:
                 scores = [ass._affinity for ass in assemblies]
-            lst = list(zip(scores, assemblies))
+            lst = list(zip(scores, assemblies, strict=False))
             assemblies = []
             while lst:
                 temp = max(lst, key=lambda x: x[0])
@@ -857,9 +824,7 @@ def wrapped(i):
         n_frames = len(self.metadata["imnames"])
         with multiprocessing.Pool(n_processes) as p:
             with tqdm(total=n_frames) as pbar:
-                for i, (assemblies, unique) in p.imap_unordered(
-                    wrapped, range(n_frames), chunksize=chunk_size
-                ):
+                for i, (assemblies, unique) in p.imap_unordered(wrapped, range(n_frames), chunksize=chunk_size):
                     if assemblies:
                         self.assemblies[i] = assemblies
                     if unique is not None:
@@ -878,9 +843,7 @@ def parse_metadata(data):
     params["joint_names"] = data["metadata"]["all_joints_names"]
     params["num_joints"] = len(params["joint_names"])
     params["paf_graph"] = data["metadata"]["PAFgraph"]
-    params["paf"] = data["metadata"].get(
-        "PAFinds", np.arange(len(params["joint_names"]))
-    )
+    params["paf"] = data["metadata"].get("PAFinds", np.arange(len(params["joint_names"])))
     params["bpts"] = params["ibpts"] = range(params["num_joints"])
     params["imnames"] = [fn for fn in list(data) if fn != "metadata"]
     return params
@@ -970,11 +933,7 @@ def calc_object_keypoint_similarity(
     else:
         oks = []
         xy_preds = [xy_pred]
-        combos = (
-            pair
-            for l in range(len(symmetric_kpts))
-            for pair in itertools.combinations(symmetric_kpts, l + 1)
-        )
+        combos = (pair for l in range(len(symmetric_kpts)) for pair in itertools.combinations(symmetric_kpts, l + 1))
         for pairs in combos:
             # Swap corresponding keypoints
             tmp = xy_pred.copy()
@@ -1011,9 +970,7 @@ def match_assemblies(
     num_ground_truth = len(ground_truth)
 
     # Sort predictions by score
-    inds_pred = np.argsort(
-        [ins.affinity if ins.n_links else ins.confidence for ins in predictions]
-    )[::-1]
+    inds_pred = np.argsort([ins.affinity if ins.n_links else ins.confidence for ins in predictions])[::-1]
     predictions = np.asarray(predictions)[inds_pred]
 
     # indices of unmatched ground truth assemblies
@@ -1074,7 +1031,7 @@
             if ~np.isnan(oks):
                 mat[i, j] = oks
     rows, cols = linear_sum_assignment(mat, maximize=True)
-    for row, col in zip(rows, cols):
+    for row, col in zip(rows, cols, strict=False):
        matched[row].ground_truth = ground_truth[col]
        matched[row].oks = mat[row, col]
        _ = inds_true.remove(col)
@@ -1087,7 +1044,7 @@ def parse_ground_truth_data_file(h5_file):
     try:
         df.drop("single", axis=1, level="individuals", inplace=True)
     except KeyError:
-        # Ignore if the "single" individual column is absent
+        # Ignore if the "single" individual column is absent
         pass
     # Cast columns of dtype 'object' to float to avoid TypeError
     # further down in _parse_ground_truth_data.
@@ -1120,15 +1077,13 @@ def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)):
         raise ValueError(f"Invalid criterion {criterion}.")
 
     if len(qs) != 2:
-        raise ValueError(
-            "Two percentiles (for lower and upper bounds) should be given."
-        )
+        raise ValueError("Two percentiles (for lower and upper bounds) should be given.")
 
     tuples = []
     for frame_ind, assemblies in dict_of_assemblies.items():
         for assembly in assemblies:
             tuples.append((frame_ind, getattr(assembly, criterion)))
-    frame_inds, vals = zip(*tuples)
+    frame_inds, vals = zip(*tuples, strict=False)
     vals = np.asarray(vals)
     lo, up = np.percentile(vals, qs, interpolation="nearest")
     inds = np.flatnonzero((vals < lo) | (vals > up)).tolist()
@@ -1226,9 +1181,7 @@ def evaluate_assembly_greedy(
         oks = np.asarray([match.oks for match in all_matched])[sorted_pred_indices]
 
         # Compute prediction and recall
-        p, r = _compute_precision_and_recall(
-            total_gt_assemblies, oks, oks_t, recall_thresholds
-        )
+        p, r = _compute_precision_and_recall(total_gt_assemblies, oks, oks_t, recall_thresholds)
         precisions.append(p)
         recalls.append(r)
 
@@ -1246,12 +1199,14 @@ def evaluate_assembly(
     ass_pred_dict,
     ass_true_dict,
     oks_sigma=0.072,
-    oks_thresholds=np.linspace(0.5, 0.95, 10),
+    oks_thresholds=None,
     margin=0,
     symmetric_kpts=None,
     greedy_matching=False,
     with_tqdm: bool = True,
 ):
+    if oks_thresholds is None:
+        oks_thresholds = np.linspace(0.5, 0.95, 10)
     if greedy_matching:
         return evaluate_assembly_greedy(
             ass_true_dict,
@@ -1299,9 +1254,7 @@
     precisions = []
     recalls = []
     for t in oks_thresholds:
-        p, r = _compute_precision_and_recall(
-            total_gt_assemblies, oks, t, recall_thresholds
-        )
+        p, r = _compute_precision_and_recall(total_gt_assemblies, oks, t, recall_thresholds)
        precisions.append(p)
        recalls.append(r)
 