88#
99# Licensed under GNU Lesser General Public License v3.0
1010#
11+
12+
13+ # NOTE DUPLICATED from deeplabcut/core/inferenceutils.py
1114from __future__ import annotations
1215
1316import heapq
1720import pickle
1821import warnings
1922from collections import defaultdict
23+ from collections .abc import Iterable
2024from dataclasses import dataclass
2125from math import erf , sqrt
22- from typing import Any , Iterable , Tuple
26+ from typing import Any
2327
2428import networkx as nx
2529import numpy as np
@@ -41,7 +45,7 @@ def _conv_square_to_condensed_indices(ind_row, ind_col, n):
4145 return n * ind_col - ind_col * (ind_col + 1 ) // 2 + ind_row - 1 - ind_col
4246
4347
44- Position = Tuple [float , float ]
48+ Position = tuple [float , float ]
4549
4650
4751@dataclass (frozen = True )
@@ -61,9 +65,7 @@ def __init__(self, j1, j2, affinity=1):
6165 self ._length = sqrt ((j1 .pos [0 ] - j2 .pos [0 ]) ** 2 + (j1 .pos [1 ] - j2 .pos [1 ]) ** 2 )
6266
6367 def __repr__ (self ):
64- return (
65- f"Link { self .idx } , affinity={ self .affinity :.2f} , length={ self .length :.2f} "
66- )
68+ return f"Link { self .idx } , affinity={ self .affinity :.2f} , length={ self .length :.2f} "
6769
6870 @property
6971 def confidence (self ):
@@ -155,7 +157,7 @@ def soft_identity(self):
155157 unq , idx , cnt = np .unique (data [:, 3 ], return_inverse = True , return_counts = True )
156158 avg = np .bincount (idx , weights = data [:, 2 ]) / cnt
157159 soft = softmax (avg )
158- return dict (zip (unq .astype (int ), soft ))
160+ return dict (zip (unq .astype (int ), soft , strict = False ))
159161
160162 @property
161163 def affinity (self ):
@@ -261,9 +263,7 @@ def __init__(
261263 self .max_overlap = max_overlap
262264 self ._has_identity = "identity" in self [0 ]
263265 if identity_only and not self ._has_identity :
264- warnings .warn (
265- "The network was not trained with identity; setting `identity_only` to False."
266- )
266+ warnings .warn ("The network was not trained with identity; setting `identity_only` to False." , stacklevel = 2 )
267267 self .identity_only = identity_only & self ._has_identity
268268 self .nan_policy = nan_policy
269269 self .force_fusion = force_fusion
@@ -344,15 +344,15 @@ def calibrate(self, train_data_file):
344344 pass
345345 n_bpts = len (df .columns .get_level_values ("bodyparts" ).unique ())
346346 if n_bpts == 1 :
347- warnings .warn ("There is only one keypoint; skipping calibration..." )
347+ warnings .warn ("There is only one keypoint; skipping calibration..." , stacklevel = 2 )
348348 return
349349
350350 xy = df .to_numpy ().reshape ((- 1 , n_bpts , 2 ))
351351 frac_valid = np .mean (~ np .isnan (xy ), axis = (1 , 2 ))
352352 # Only keeps skeletons that are more than 90% complete
353353 xy = xy [frac_valid >= 0.9 ]
354354 if not xy .size :
355- warnings .warn ("No complete poses were found. Skipping calibration..." )
355+ warnings .warn ("No complete poses were found. Skipping calibration..." , stacklevel = 2 )
356356 return
357357
358358 # TODO Normalize dists by longest length?
@@ -368,13 +368,9 @@ def calibrate(self, train_data_file):
368368 self .safe_edge = True
369369 except np .linalg .LinAlgError :
370370 # Covariance matrix estimation fails due to numerical singularities
371- warnings .warn (
372- "The assembler could not be robustly calibrated. Continuing without it..."
373- )
371+ warnings .warn ("The assembler could not be robustly calibrated. Continuing without it..." , stacklevel = 2 )
374372
375- def calc_assembly_mahalanobis_dist (
376- self , assembly , return_proba = False , nan_policy = "little"
377- ):
373+ def calc_assembly_mahalanobis_dist (self , assembly , return_proba = False , nan_policy = "little" ):
378374 if self ._kde is None :
379375 raise ValueError ("Assembler should be calibrated first with training data." )
380376
@@ -428,10 +424,10 @@ def _flatten_detections(data_dict):
428424 ids = [np .ones (len (arr ), dtype = int ) * - 1 for arr in confidence ]
429425 else :
430426 ids = [arr .argmax (axis = 1 ) for arr in ids ]
431- for i , (coords , conf , id_ ) in enumerate (zip (coordinates , confidence , ids )):
427+ for i , (coords , conf , id_ ) in enumerate (zip (coordinates , confidence , ids , strict = False )):
432428 if not np .any (coords ):
433429 continue
434- for xy , p , g in zip (coords , conf , id_ ):
430+ for xy , p , g in zip (coords , conf , id_ , strict = False ):
435431 joint = Joint (tuple (xy ), p .item (), i , ind , g )
436432 ind += 1
437433 yield joint
@@ -453,9 +449,7 @@ def extract_best_links(self, joints_dict, costs, trees=None):
453449 aff [np .isnan (aff )] = 0
454450
455451 if trees :
456- vecs = np .vstack (
457- [[* det_s .pos , * det_t .pos ] for det_s in dets_s for det_t in dets_t ]
458- )
452+ vecs = np .vstack ([[* det_s .pos , * det_t .pos ] for det_s in dets_s for det_t in dets_t ])
459453 dists = []
460454 for n , tree in enumerate (trees , start = 1 ):
461455 d , _ = tree .query (vecs )
@@ -464,45 +458,34 @@ def extract_best_links(self, joints_dict, costs, trees=None):
464458 aff *= w .reshape (aff .shape )
465459
466460 if self .greedy :
467- conf = np .asarray (
468- [
469- [det_s .confidence * det_t .confidence for det_t in dets_t ]
470- for det_s in dets_s
471- ]
472- )
473- rows , cols = np .where (
474- (conf >= self .pcutoff * self .pcutoff ) & (aff >= self .min_affinity )
475- )
461+ conf = np .asarray ([[det_s .confidence * det_t .confidence for det_t in dets_t ] for det_s in dets_s ])
462+ rows , cols = np .where ((conf >= self .pcutoff * self .pcutoff ) & (aff >= self .min_affinity ))
476463 candidates = sorted (
477- zip (rows , cols , aff [rows , cols ], lengths [rows , cols ]),
464+ zip (rows , cols , aff [rows , cols ], lengths [rows , cols ], strict = False ),
478465 key = lambda x : x [2 ],
479466 reverse = True ,
480467 )
481468 i_seen = set ()
482469 j_seen = set ()
483- for i , j , w , l in candidates :
470+ for i , j , w , _l in candidates :
484471 if i not in i_seen and j not in j_seen :
485472 i_seen .add (i )
486473 j_seen .add (j )
487474 links .append (Link (dets_s [i ], dets_t [j ], w ))
488475 if len (i_seen ) == self .max_n_individuals :
489476 break
490477 else : # Optimal keypoint pairing
491- inds_s = sorted (
492- range (len (dets_s )), key = lambda x : dets_s [x ].confidence , reverse = True
493- )[: self .max_n_individuals ]
494- inds_t = sorted (
495- range (len (dets_t )), key = lambda x : dets_t [x ].confidence , reverse = True
496- )[: self .max_n_individuals ]
497- keep_s = [
498- ind for ind in inds_s if dets_s [ind ].confidence >= self .pcutoff
478+ inds_s = sorted (range (len (dets_s )), key = lambda x : dets_s [x ].confidence , reverse = True )[
479+ : self .max_n_individuals
499480 ]
500- keep_t = [
501- ind for ind in inds_t if dets_t [ ind ]. confidence >= self .pcutoff
481+ inds_t = sorted ( range ( len ( dets_t )), key = lambda x : dets_t [ x ]. confidence , reverse = True ) [
482+ : self .max_n_individuals
502483 ]
484+ keep_s = [ind for ind in inds_s if dets_s [ind ].confidence >= self .pcutoff ]
485+ keep_t = [ind for ind in inds_t if dets_t [ind ].confidence >= self .pcutoff ]
503486 aff = aff [np .ix_ (keep_s , keep_t )]
504487 rows , cols = linear_sum_assignment (aff , maximize = True )
505- for row , col in zip (rows , cols ):
488+ for row , col in zip (rows , cols , strict = False ):
506489 w = aff [row , col ]
507490 if w >= self .min_affinity :
508491 links .append (Link (dets_s [keep_s [row ]], dets_t [keep_t [col ]], w ))
@@ -538,19 +521,17 @@ def push_to_stack(i):
538521 if new_ind in assembled :
539522 continue
540523 if safe_edge :
541- d_old = self .calc_assembly_mahalanobis_dist (
542- assembly , nan_policy = nan_policy
543- )
524+ d_old = self .calc_assembly_mahalanobis_dist (assembly , nan_policy = nan_policy )
544525 success = assembly .add_link (best , store_dict = True )
545526 if not success :
546527 assembly ._dict = dict ()
547528 continue
548529 d = self .calc_assembly_mahalanobis_dist (assembly , nan_policy = nan_policy )
549530 if d < d_old :
550531 push_to_stack (new_ind )
551- if tabu :
552- _ , _ , link = heapq .heappop (tabu )
553- heapq .heappush (stack , (- link .affinity , next (counter ), link ))
532+ if tabu :
533+ _ , _ , link = heapq .heappop (tabu )
534+ heapq .heappush (stack , (- link .affinity , next (counter ), link ))
554535 else :
555536 heapq .heappush (tabu , (d - d_old , next (counter ), best ))
556537 assembly .__dict__ .update (assembly ._dict )
@@ -593,9 +574,7 @@ def build_assemblies(self, links):
593574 continue
594575 assembly = Assembly (self .n_multibodyparts )
595576 assembly .add_link (link )
596- self ._fill_assembly (
597- assembly , lookup , assembled , self .safe_edge , self .nan_policy
598- )
577+ self ._fill_assembly (assembly , lookup , assembled , self .safe_edge , self .nan_policy )
599578 for assembly_link in assembly ._links :
600579 i , j = assembly_link .idx
601580 lookup [i ].pop (j )
@@ -607,10 +586,7 @@ def build_assemblies(self, links):
607586 n_extra = len (assemblies ) - self .max_n_individuals
608587 if n_extra > 0 :
609588 if self .safe_edge :
610- ds_old = [
611- self .calc_assembly_mahalanobis_dist (assembly )
612- for assembly in assemblies
613- ]
589+ ds_old = [self .calc_assembly_mahalanobis_dist (assembly ) for assembly in assemblies ]
614590 while len (assemblies ) > self .max_n_individuals :
615591 ds = []
616592 for i , j in itertools .combinations (range (len (assemblies )), 2 ):
@@ -665,7 +641,7 @@ def build_assemblies(self, links):
665641 for idx in store [j ]._idx :
666642 store [idx ] = store [i ]
667643 except KeyError :
668- # Some links may reference indices that were never added to `store`;
644+ # Some links may reference indices that were never added to `store`;
669645 # in that case we intentionally skip merging for this link
670646 pass
671647
@@ -742,10 +718,7 @@ def _assemble(self, data_dict, ind_frame):
742718 for _ , group in groups :
743719 ass = Assembly (self .n_multibodyparts )
744720 for joint in sorted (group , key = lambda x : x .confidence , reverse = True ):
745- if (
746- joint .confidence >= self .pcutoff
747- and joint .label < self .n_multibodyparts
748- ):
721+ if joint .confidence >= self .pcutoff and joint .label < self .n_multibodyparts :
749722 ass .add_joint (joint )
750723 if len (ass ):
751724 assemblies .append (ass )
@@ -774,24 +747,18 @@ def _assemble(self, data_dict, ind_frame):
774747 assembled .update (assembled_ )
775748
776749 # Remove invalid assemblies
777- discarded = set (
778- joint
779- for joint in joints
780- if joint .idx not in assembled and np .isfinite (joint .confidence )
781- )
750+ discarded = set (joint for joint in joints if joint .idx not in assembled and np .isfinite (joint .confidence ))
782751 for assembly in assemblies [::- 1 ]:
783752 if 0 < assembly .n_links < self .min_n_links or not len (assembly ):
784753 for link in assembly ._links :
785754 discarded .update ((link .j1 , link .j2 ))
786755 assemblies .remove (assembly )
787756 if 0 < self .max_overlap < 1 : # Non-maximum pose suppression
788757 if self ._kde is not None :
789- scores = [
790- - self .calc_assembly_mahalanobis_dist (ass ) for ass in assemblies
791- ]
758+ scores = [- self .calc_assembly_mahalanobis_dist (ass ) for ass in assemblies ]
792759 else :
793760 scores = [ass ._affinity for ass in assemblies ]
794- lst = list (zip (scores , assemblies ))
761+ lst = list (zip (scores , assemblies , strict = False ))
795762 assemblies = []
796763 while lst :
797764 temp = max (lst , key = lambda x : x [0 ])
@@ -857,9 +824,7 @@ def wrapped(i):
857824 n_frames = len (self .metadata ["imnames" ])
858825 with multiprocessing .Pool (n_processes ) as p :
859826 with tqdm (total = n_frames ) as pbar :
860- for i , (assemblies , unique ) in p .imap_unordered (
861- wrapped , range (n_frames ), chunksize = chunk_size
862- ):
827+ for i , (assemblies , unique ) in p .imap_unordered (wrapped , range (n_frames ), chunksize = chunk_size ):
863828 if assemblies :
864829 self .assemblies [i ] = assemblies
865830 if unique is not None :
@@ -878,9 +843,7 @@ def parse_metadata(data):
878843 params ["joint_names" ] = data ["metadata" ]["all_joints_names" ]
879844 params ["num_joints" ] = len (params ["joint_names" ])
880845 params ["paf_graph" ] = data ["metadata" ]["PAFgraph" ]
881- params ["paf" ] = data ["metadata" ].get (
882- "PAFinds" , np .arange (len (params ["joint_names" ]))
883- )
846+ params ["paf" ] = data ["metadata" ].get ("PAFinds" , np .arange (len (params ["joint_names" ])))
884847 params ["bpts" ] = params ["ibpts" ] = range (params ["num_joints" ])
885848 params ["imnames" ] = [fn for fn in list (data ) if fn != "metadata" ]
886849 return params
@@ -970,11 +933,7 @@ def calc_object_keypoint_similarity(
970933 else :
971934 oks = []
972935 xy_preds = [xy_pred ]
973- combos = (
974- pair
975- for l in range (len (symmetric_kpts ))
976- for pair in itertools .combinations (symmetric_kpts , l + 1 )
977- )
936+ combos = (pair for l in range (len (symmetric_kpts )) for pair in itertools .combinations (symmetric_kpts , l + 1 ))
978937 for pairs in combos :
979938 # Swap corresponding keypoints
980939 tmp = xy_pred .copy ()
@@ -1011,9 +970,7 @@ def match_assemblies(
1011970 num_ground_truth = len (ground_truth )
1012971
1013972 # Sort predictions by score
1014- inds_pred = np .argsort (
1015- [ins .affinity if ins .n_links else ins .confidence for ins in predictions ]
1016- )[::- 1 ]
973+ inds_pred = np .argsort ([ins .affinity if ins .n_links else ins .confidence for ins in predictions ])[::- 1 ]
1017974 predictions = np .asarray (predictions )[inds_pred ]
1018975
1019976 # indices of unmatched ground truth assemblies
@@ -1074,7 +1031,7 @@ def match_assemblies(
10741031 if ~ np .isnan (oks ):
10751032 mat [i , j ] = oks
10761033 rows , cols = linear_sum_assignment (mat , maximize = True )
1077- for row , col in zip (rows , cols ):
1034+ for row , col in zip (rows , cols , strict = False ):
10781035 matched [row ].ground_truth = ground_truth [col ]
10791036 matched [row ].oks = mat [row , col ]
10801037 _ = inds_true .remove (col )
@@ -1087,7 +1044,7 @@ def parse_ground_truth_data_file(h5_file):
10871044 try :
10881045 df .drop ("single" , axis = 1 , level = "individuals" , inplace = True )
10891046 except KeyError :
1090- # Ignore if the "single" individual column is absent
1047+ # Ignore if the "single" individual column is absent
10911048 pass
10921049 # Cast columns of dtype 'object' to float to avoid TypeError
10931050 # further down in _parse_ground_truth_data.
@@ -1120,15 +1077,13 @@ def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)):
11201077 raise ValueError (f"Invalid criterion { criterion } ." )
11211078
11221079 if len (qs ) != 2 :
1123- raise ValueError (
1124- "Two percentiles (for lower and upper bounds) should be given."
1125- )
1080+ raise ValueError ("Two percentiles (for lower and upper bounds) should be given." )
11261081
11271082 tuples = []
11281083 for frame_ind , assemblies in dict_of_assemblies .items ():
11291084 for assembly in assemblies :
11301085 tuples .append ((frame_ind , getattr (assembly , criterion )))
1131- frame_inds , vals = zip (* tuples )
1086+ frame_inds , vals = zip (* tuples , strict = False )
11321087 vals = np .asarray (vals )
11331088 lo , up = np .percentile (vals , qs , interpolation = "nearest" )
11341089 inds = np .flatnonzero ((vals < lo ) | (vals > up )).tolist ()
@@ -1226,9 +1181,7 @@ def evaluate_assembly_greedy(
12261181 oks = np .asarray ([match .oks for match in all_matched ])[sorted_pred_indices ]
12271182
12281183 # Compute prediction and recall
1229- p , r = _compute_precision_and_recall (
1230- total_gt_assemblies , oks , oks_t , recall_thresholds
1231- )
1184+ p , r = _compute_precision_and_recall (total_gt_assemblies , oks , oks_t , recall_thresholds )
12321185 precisions .append (p )
12331186 recalls .append (r )
12341187
@@ -1246,12 +1199,14 @@ def evaluate_assembly(
12461199 ass_pred_dict ,
12471200 ass_true_dict ,
12481201 oks_sigma = 0.072 ,
1249- oks_thresholds = np . linspace ( 0.5 , 0.95 , 10 ) ,
1202+ oks_thresholds = None ,
12501203 margin = 0 ,
12511204 symmetric_kpts = None ,
12521205 greedy_matching = False ,
12531206 with_tqdm : bool = True ,
12541207):
1208+ if oks_thresholds is None :
1209+ oks_thresholds = np .linspace (0.5 , 0.95 , 10 )
12551210 if greedy_matching :
12561211 return evaluate_assembly_greedy (
12571212 ass_true_dict ,
@@ -1299,9 +1254,7 @@ def evaluate_assembly(
12991254 precisions = []
13001255 recalls = []
13011256 for t in oks_thresholds :
1302- p , r = _compute_precision_and_recall (
1303- total_gt_assemblies , oks , t , recall_thresholds
1304- )
1257+ p , r = _compute_precision_and_recall (total_gt_assemblies , oks , t , recall_thresholds )
13051258 precisions .append (p )
13061259 recalls .append (r )
13071260
0 commit comments