55import numpy as np
66import pandas as pd
77import pytest
8+ from hypothesis import given , settings
9+ from hypothesis import strategies as st
10+ from hypothesis .extra .numpy import arrays
811
9- from dlclive .core .inferenceutils import Assembler , Assembly , Joint , Link
12+ from dlclive .core .inferenceutils import Assembler , Assembly , Joint , Link , _conv_square_to_condensed_indices
13+
14+ HYPOTHESIS_SETTINGS = settings (max_examples = 300 , deadline = None )
1015
1116
1217def _bag_from_frame (frame : dict ) -> dict [int , list ]:
@@ -17,6 +22,29 @@ def _bag_from_frame(frame: dict) -> dict[int, list]:
1722 return bag
1823
1924
25+ # _conv_square_to_condensed_indices
26+ @HYPOTHESIS_SETTINGS
27+ @given (
28+ n = st .integers (min_value = 2 , max_value = 50 ),
29+ i = st .integers (min_value = 0 , max_value = 49 ),
30+ j = st .integers (min_value = 0 , max_value = 49 ),
31+ )
32+ def test_condensed_index_properties (n , i , j ):
33+ i = i % n
34+ j = j % n
35+
36+ if i == j :
37+ with pytest .raises (ValueError ):
38+ _conv_square_to_condensed_indices (i , j , n )
39+ return
40+
41+ k1 = _conv_square_to_condensed_indices (i , j , n )
42+ k2 = _conv_square_to_condensed_indices (j , i , n )
43+
44+ assert k1 == k2
45+ assert 0 <= k1 < (n * (n - 1 )) // 2
46+
47+
2048# --------------------------------------------------------------------------------------
2149# Basic metadata and __getitem__
2250# --------------------------------------------------------------------------------------
@@ -57,8 +85,6 @@ def test_empty_classmethod(assembler_graph_and_pafs):
5785# --------------------------------------------------------------------------------------
5886# _flatten_detections
5987# --------------------------------------------------------------------------------------
60-
61-
6288def test_flatten_detections_no_identity (simple_two_label_scene ):
6389 frame = simple_two_label_scene
6490 joints = list (Assembler ._flatten_detections (frame ))
@@ -82,11 +108,52 @@ def test_flatten_detections_with_identity(scene_copy):
82108 assert groups .count (1 ) == 2
83109
84110
111+ @st .composite
112+ def coords_and_conf (draw , max_n = 5 ):
113+ n = draw (st .integers (1 , max_n ))
114+ coords = draw (
115+ arrays (
116+ dtype = np .float64 ,
117+ shape = (n , 2 ),
118+ elements = st .floats (min_value = 0.1 , max_value = 1000 , allow_nan = False , allow_infinity = False ),
119+ )
120+ )
121+ conf = draw (
122+ arrays (
123+ dtype = np .float64 ,
124+ shape = (n ,),
125+ elements = st .floats (min_value = 0.0 , max_value = 1.0 , allow_nan = False , allow_infinity = False ),
126+ )
127+ )
128+ return coords , conf
129+
130+
131+ @HYPOTHESIS_SETTINGS
132+ @given (
133+ c0 = coords_and_conf (),
134+ c1 = coords_and_conf (),
135+ )
136+ def test_flatten_detections_counts (c0 , c1 ):
137+ coords0 , conf0 = c0
138+ coords1 , conf1 = c1
139+
140+ frame = {
141+ "coordinates" : [[coords0 , coords1 ]],
142+ "confidence" : [conf0 , conf1 ],
143+ "costs" : {},
144+ }
145+
146+ joints = list (Assembler ._flatten_detections (frame ))
147+
148+ # Should yield exactly one Joint per detection
149+ assert len (joints ) == (len (coords0 ) + len (coords1 ))
150+ assert sum (j .label == 0 for j in joints ) == len (coords0 )
151+ assert sum (j .label == 1 for j in joints ) == len (coords1 )
152+
153+
85154# --------------------------------------------------------------------------------------
86155# extract_best_links
87156# --------------------------------------------------------------------------------------
88-
89-
90157def test_extract_best_links_optimal_assignment (assembler_data_single_frame , make_assembler ):
91158 sframe_data = assembler_data_single_frame
92159 asm = make_assembler (
@@ -134,6 +201,74 @@ def test_extract_best_links_greedy_with_thresholds(assembler_data_single_frame,
134201 )
135202
136203
204+ @HYPOTHESIS_SETTINGS
205+ @given (
206+ n = st .integers (min_value = 1 , max_value = 4 ),
207+ pcutoff = st .floats (min_value = 0.0 , max_value = 1.0 , allow_nan = False , allow_infinity = False ),
208+ min_aff = st .floats (min_value = 0.0 , max_value = 1.0 , allow_nan = False , allow_infinity = False ),
209+ conf0 = st .lists (st .floats (0.0 , 1.0 , allow_nan = False , allow_infinity = False ), min_size = 1 , max_size = 4 ),
210+ conf1 = st .lists (st .floats (0.0 , 1.0 , allow_nan = False , allow_infinity = False ), min_size = 1 , max_size = 4 ),
211+ )
212+ def test_extract_best_links_greedy_invariants_with_threshold_gates (n , pcutoff , min_aff , conf0 , conf1 ):
213+ # Normalize confidences to exactly n items
214+ conf0 = (conf0 + [0.0 ] * n )[:n ]
215+ conf1 = (conf1 + [0.0 ] * n )[:n ]
216+ conf0 = np .array (conf0 , dtype = float )
217+ conf1 = np .array (conf1 , dtype = float )
218+
219+ # Random-ish affinity matrix (still stable), in [0,1]
220+ rng = np .random .default_rng (0 ) # deterministic noise
221+ aff = rng .random ((n , n )) # uniform [0,1)
222+ # Ensure at least one "good" candidate sometimes; otherwise test is vacuously true.
223+ # We'll only assert gated properties on returned links anyway.
224+ # But for better coverage, bias the diagonal upward a bit:
225+ np .fill_diagonal (aff , np .maximum (np .diag (aff ), 0.8 ))
226+ dist = np .ones ((n , n ), dtype = float )
227+
228+ graph = [(0 , 1 )]
229+ paf_inds = [0 ]
230+ data = {
231+ "metadata" : {"all_joints_names" : ["b0" , "b1" ], "PAFgraph" : graph , "PAFinds" : paf_inds },
232+ "0" : {},
233+ }
234+
235+ asm = Assembler (
236+ data ,
237+ max_n_individuals = n ,
238+ n_multibodyparts = 2 ,
239+ greedy = True ,
240+ pcutoff = pcutoff ,
241+ min_affinity = min_aff ,
242+ min_n_links = 1 ,
243+ method = "m1" ,
244+ )
245+
246+ dets0 = [Joint ((float (i ), 0.0 ), float (conf0 [i ]), label = 0 , idx = i ) for i in range (n )]
247+ dets1 = [Joint ((float (i ), 1.0 ), float (conf1 [i ]), label = 1 , idx = 100 + i ) for i in range (n )]
248+ joints_dict = {0 : dets0 , 1 : dets1 }
249+ costs = {0 : {"distance" : dist , "m1" : aff }}
250+
251+ links = asm .extract_best_links (joints_dict , costs , trees = None )
252+
253+ assert len (links ) <= n
254+
255+ used_src = set ()
256+ used_tgt = set ()
257+
258+ for link in links :
259+ # Invariant 1: affinity gate
260+ assert link .affinity >= min_aff
261+
262+ # Invariant 2: pcutoff gate (confidence product)
263+ assert link .j1 .confidence * link .j2 .confidence >= pcutoff * pcutoff
264+
265+ # Invariant 3: disjointness in greedy selection
266+ assert link .j1 .idx not in used_src
267+ assert link .j2 .idx not in used_tgt
268+ used_src .add (link .j1 .idx )
269+ used_tgt .add (link .j2 .idx )
270+
271+
137272# --------------------------------------------------------------------------------------
138273# build_assemblies
139274# --------------------------------------------------------------------------------------
0 commit comments