Skip to content

Commit d019067

Browse files
committed
Initialize subclass attributes on classifier load
Adds initialization of subclass-specific attributes (such as default_threshold, min_predictions, max_predictions, and label_thresholds) in AdaptiveClassifier.load to ensure proper state after loading. Updates tests to handle ONNX and non-ONNX models correctly when setting eval mode and loading classifiers.
1 parent cf801f2 commit d019067

2 files changed

Lines changed: 23 additions & 10 deletions

File tree

src/adaptive_classifier/classifier.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,17 @@ def _from_pretrained(
832832
classifier.strategic_optimizer = None
833833
classifier.strategic_evaluator = None
834834

835+
# Initialize subclass-specific attributes (e.g., for MultiLabelAdaptiveClassifier)
836+
# These will be overwritten if the subclass has its own initialization logic
837+
if not hasattr(classifier, 'default_threshold'):
838+
classifier.default_threshold = 0.5
839+
if not hasattr(classifier, 'min_predictions'):
840+
classifier.min_predictions = 1
841+
if not hasattr(classifier, 'max_predictions'):
842+
classifier.max_predictions = None
843+
if not hasattr(classifier, 'label_thresholds'):
844+
classifier.label_thresholds = {}
845+
835846
if classifier.config.enable_strategic_mode:
836847
classifier._initialize_strategic_components()
837848
else:

tests/test_classifier.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,16 @@ def test_save_load(base_classifier, sample_data):
6060
torch.manual_seed(42)
6161
np.random.seed(42)
6262
random.seed(42)
63-
63+
6464
texts, labels = sample_data
6565
base_classifier.add_examples(texts, labels)
66-
66+
6767
with tempfile.TemporaryDirectory() as tmpdir:
6868
save_path = Path(tmpdir) / "test_classifier"
69-
70-
# Ensure model is in eval mode before saving
71-
base_classifier.model.eval()
69+
70+
# Ensure model is in eval mode before saving (if not ONNX)
71+
if not base_classifier.use_onnx and hasattr(base_classifier.model, 'eval'):
72+
base_classifier.model.eval()
7273
if base_classifier.adaptive_head is not None:
7374
base_classifier.adaptive_head.eval()
7475

@@ -81,13 +82,14 @@ def test_save_load(base_classifier, sample_data):
8182
assert (save_path / "examples.json").exists()
8283
assert (save_path / "README.md").exists()
8384

84-
# Load with same device
85-
loaded_classifier = AdaptiveClassifier.load(save_path, device=base_classifier.device)
85+
# Load with same device (disable ONNX for deterministic comparison)
86+
loaded_classifier = AdaptiveClassifier.load(save_path, device=base_classifier.device, use_onnx=False)
8687
assert loaded_classifier is not None
8788
assert loaded_classifier.label_to_id == base_classifier.label_to_id
88-
89-
# Ensure loaded model is also in eval mode
90-
loaded_classifier.model.eval()
89+
90+
# Ensure loaded model is also in eval mode (if not ONNX)
91+
if not loaded_classifier.use_onnx and hasattr(loaded_classifier.model, 'eval'):
92+
loaded_classifier.model.eval()
9193
if loaded_classifier.adaptive_head is not None:
9294
loaded_classifier.adaptive_head.eval()
9395

0 commit comments

Comments
 (0)