@@ -156,11 +156,11 @@ def compute_generalized_dice(
156156
157157 # Apply ignore_index masking
158158 if ignore_index is not None :
159- if ignore_index < y .shape [1 ]:
159+ if 0 <= ignore_index < y .shape [1 ]:
160160 # For one-hot: use the ignored class channel
161161 mask = 1.0 - y [:, ignore_index : ignore_index + 1 ]
162162 else :
163- # For sentinel values, check if any channel is valid
163+ # For sentinel values (like 255 or -100) , check if any channel is valid
164164 mask = (y .sum (dim = 1 , keepdim = True ) > 0 ).float ()
165165 y_pred = y_pred * mask
166166 y = y * mask
@@ -171,7 +171,7 @@ def compute_generalized_dice(
171171 if not include_background :
172172 channels_to_use .pop (0 )
173173
174- if ignore_index is not None :
174+ if ignore_index is not None and 0 <= ignore_index < n_channels :
175175 # If background was 0 and we ignore class 2, we need the correct absolute index
176176 if ignore_index in channels_to_use :
177177 channels_to_use .remove (ignore_index )
@@ -181,35 +181,33 @@ def compute_generalized_dice(
181181
182182 # Reducing only spatial dimensions (not batch nor channels), compute the intersection and non-weighted denominator
183183 reduce_axis = list (range (2 , y_pred .dim ()))
184- y_o_full = torch .sum (y , dim = reduce_axis ) # shape: (B, C)
185184 intersection = torch .sum (y [:, channels_to_use , ...] * y_pred [:, channels_to_use , ...], dim = reduce_axis )
186185 y_o = torch .sum (y [:, channels_to_use , ...], dim = reduce_axis )
187186 y_pred_o = torch .sum (y_pred [:, channels_to_use , ...], dim = reduce_axis )
188187
189188 denominator = y_o + y_pred_o
190189
191190 # Set the class weights
191+ # Set the class weights (computed from scored channels only)
192192 weight_type = look_up_option (weight_type , Weight )
193- y_o_float = y_o_full .float ()
193+ y_o_float = y_o .float ()
194194
195195 if weight_type == Weight .SIMPLE :
196- w_full = torch .reciprocal (y_o_float )
196+ w = torch .reciprocal (y_o_float )
197197 elif weight_type == Weight .SQUARE :
198- w_full = torch .reciprocal (y_o_float * y_o_float )
198+ w = torch .reciprocal (y_o_float * y_o_float )
199199 else :
200- w_full = torch .ones_like (y_o_float )
200+ w = torch .ones_like (y_o_float )
201201
202202 # Replace infinite values for non-appearing classes by the maximum weight
203- for b_idx in range (w_full .shape [0 ]):
204- batch_w = w_full [b_idx ]
203+ for b_idx in range (w .shape [0 ]):
204+ batch_w = w [b_idx ]
205205 infs = torch .isinf (batch_w )
206206 if infs .any ():
207207 batch_w [infs ] = 0
208208 max_w = torch .max (batch_w )
209209 batch_w [infs ] = max_w if max_w > 0 else 1.0
210210
211- w = w_full [:, channels_to_use ]
212-
213211 if sum_over_classes :
214212 intersection = (intersection * w ).sum (dim = 1 , keepdim = True )
215213 denominator = (denominator * w ).sum (dim = 1 , keepdim = True )
0 commit comments