def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute the AUC margin loss for scores ``input`` against 0/1 ``target``.

    Two variants are supported and are intentionally different:

    * ``v1`` uses *global* expectations (normalized by the total number of
      samples), following the original LibAUC implementation.
    * ``v2`` uses *class-conditional* expectations (normalized by the number
      of samples in each class), implemented via non-zero averaging.

    Args:
        input: Predicted scores.
        target: Binary labels (1 = positive, 0 = negative), same shape.

    Returns:
        Scalar loss tensor.
    """
    pos_mask = (target == 1).float()
    neg_mask = (target == 0).float()

    sq_dist_pos = (input - self.a) ** 2
    sq_dist_neg = (input - self.b) ** 2

    if self.version == "v1":
        # Positive-class prior: use the configured imbalance ratio when
        # available, otherwise estimate it from this batch.
        if self.imratio is not None:
            p = float(self.imratio)
        else:
            p = float(pos_mask.mean().item())
        q = 1.0 - p

        term_pos = self._global_mean(sq_dist_pos, pos_mask)
        term_neg = self._global_mean(sq_dist_neg, neg_mask)
        cross = self._global_mean(
            p * input * neg_mask - q * input * pos_mask, pos_mask + neg_mask
        )

        return (
            q * term_pos
            + p * term_neg
            + 2 * self.alpha * (p * q * self.margin + cross)
            - p * q * self.alpha ** 2
        )

    # v2: class-conditional expectations via non-zero averaging.
    term_pos = self._class_mean(sq_dist_pos, pos_mask)
    term_neg = self._class_mean(sq_dist_neg, neg_mask)
    score_pos = self._class_mean(input, pos_mask)
    score_neg = self._class_mean(input, neg_mask)

    return (
        term_pos
        + term_neg
        + 2 * self.alpha * (self.margin + score_neg - score_pos)
        - self.alpha ** 2
    )
144159
145- def _safe_mean (self , tensor : torch .Tensor , mask : torch .Tensor ) -> torch .Tensor :
146- """Compute mean safely over masked elements."""
160+ def _global_mean (self , tensor : torch .Tensor , mask : torch .Tensor ) -> torch .Tensor :
161+ """
162+ Compute the global mean of a masked tensor.
163+
164+ This computes the mean over all elements, where values outside the mask
165+ are zeroed out. The result is normalized by the total number of elements,
166+ not by the number of masked elements.
167+
168+ This corresponds to a global expectation:
169+ E[mask * tensor]
170+
171+ Args:
172+ tensor: Input tensor.
173+ mask: Binary mask tensor of the same shape as ``tensor``.
174+
175+ Returns:
176+ Scalar tensor representing the global mean.
177+ """
178+ return (tensor * mask ).mean ()
179+
180+ def _class_mean (self , tensor : torch .Tensor , mask : torch .Tensor ) -> torch .Tensor :
181+ """
182+ Compute the class-conditional mean of a masked tensor.
183+
184+ This computes the mean over only the masked (non-zero) elements, i.e.,
185+ the result is normalized by the number of masked elements.
186+
187+ This corresponds to a class-conditional expectation:
188+ E[tensor | mask]
189+
190+ Args:
191+ tensor: Input tensor.
192+ mask: Binary mask tensor of the same shape as ``tensor``.
193+
194+ Returns:
195+ Scalar tensor representing the class-conditional mean.
196+ Returns 0 if no elements are selected by the mask.
197+ """
147198 denom = mask .sum ()
148- if denom == 0 :
149- return torch .tensor ( 0.0 , device = tensor .device , dtype = tensor .dtype )
199+ if denom . item () == 0 :
200+ return torch .zeros ((), dtype = tensor .dtype , device = tensor .device )
150201 return (tensor * mask ).sum () / denom
0 commit comments