@@ -367,10 +367,12 @@ def _predict_regular(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
367367 # Get embedding
368368 embedding = self ._get_embeddings ([text ])[0 ]
369369
370- # Get prototype predictions
371- proto_preds = self .memory .get_nearest_prototypes (embedding , k = k )
372-
373- # Get neural predictions if available
370+ # Get prototype predictions for ALL classes (not limited by k)
371+ # This ensures complete scoring information for proper combination
372+ max_classes = len (self .id_to_label ) if self .id_to_label else k
373+ proto_preds = self .memory .get_nearest_prototypes (embedding , k = max_classes )
374+
375+ # Get neural predictions if available for ALL classes (not limited by k)
374376 if self .adaptive_head is not None :
375377 self .adaptive_head .eval () # Ensure eval mode
376378 # Add batch dimension and move to device
@@ -379,8 +381,9 @@ def _predict_regular(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
379381 # Squeeze batch dimension
380382 logits = logits .squeeze (0 )
381383 probs = F .softmax (logits , dim = 0 )
382-
383- values , indices = torch .topk (probs , min (k , len (self .id_to_label )))
384+
385+ # Get predictions for ALL classes for proper scoring combination
386+ values , indices = torch .topk (probs , len (self .id_to_label ))
384387 head_preds = [
385388 (self .id_to_label [idx .item ()], val .item ())
386389 for val , idx in zip (values , indices )
0 commit comments