Skip to content

Commit ef4b2cd

Browse files
committed
Update classifier.py
1 parent bcfd551 commit ef4b2cd

1 file changed

Lines changed: 9 additions & 6 deletions

File tree

src/adaptive_classifier/classifier.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -367,10 +367,12 @@ def _predict_regular(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
367367
# Get embedding
368368
embedding = self._get_embeddings([text])[0]
369369

370-
# Get prototype predictions
371-
proto_preds = self.memory.get_nearest_prototypes(embedding, k=k)
372-
373-
# Get neural predictions if available
370+
# Get prototype predictions for ALL classes (not limited by k)
371+
# This ensures complete scoring information for proper combination
372+
max_classes = len(self.id_to_label) if self.id_to_label else k
373+
proto_preds = self.memory.get_nearest_prototypes(embedding, k=max_classes)
374+
375+
# Get neural predictions if available for ALL classes (not limited by k)
374376
if self.adaptive_head is not None:
375377
self.adaptive_head.eval() # Ensure eval mode
376378
# Add batch dimension and move to device
@@ -379,8 +381,9 @@ def _predict_regular(self, text: str, k: int = 5) -> List[Tuple[str, float]]:
379381
# Squeeze batch dimension
380382
logits = logits.squeeze(0)
381383
probs = F.softmax(logits, dim=0)
382-
383-
values, indices = torch.topk(probs, min(k, len(self.id_to_label)))
384+
385+
# Get predictions for ALL classes for proper scoring combination
386+
values, indices = torch.topk(probs, len(self.id_to_label))
384387
head_preds = [
385388
(self.id_to_label[idx.item()], val.item())
386389
for val, idx in zip(values, indices)

0 commit comments

Comments (0)