Skip to content

Commit d82c371

Browse files
authored
feat: refactor training (#316)
* feat: rewrite training, add similarity based trainer * do better check * add class weight * remove print * fix typing * fix: issue with saving * fix comments * fix test * fix: tests for windows * fix: test atol
1 parent b3012ee commit d82c371

11 files changed

Lines changed: 840 additions & 335 deletions

File tree

model2vec/inference/model.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import re
44
from collections.abc import Sequence
5+
from enum import Enum
56
from pathlib import Path
67
from tempfile import TemporaryDirectory
78
from typing import TypeVar, cast
@@ -10,7 +11,7 @@
1011
import numpy as np
1112
import skops.io
1213
from sklearn.metrics import classification_report
13-
from sklearn.neural_network import MLPClassifier
14+
from sklearn.neural_network import MLPClassifier, MLPRegressor
1415
from sklearn.pipeline import Pipeline
1516
from sklearn.preprocessing import MultiLabelBinarizer
1617

@@ -23,25 +24,29 @@
2324
LabelType = TypeVar("LabelType", list[str], list[list[str]])
2425

2526

27+
class HeadType(str, Enum):
    """
    The kind of head attached to a StaticModelPipeline.

    Members are also plain strings, so they compare equal to their values
    (e.g. ``HeadType.PROJECTOR == "projector"``).
    """

    # Used for an MLPClassifier head without a logistic output activation, and as the
    # default for any other head type.
    CLASSIFIER = "classifier"
    # Used when the last step of the head is an MLPRegressor.
    PROJECTOR = "projector"
    # Used for an MLPClassifier head whose output activation is "logistic".
    MULTILABEL = "multilabel"
31+
32+
2633
class StaticModelPipeline:
2734
def __init__(self, model: StaticModel, head: Pipeline) -> None:
    """
    Create a pipeline with a StaticModel encoder.

    The type of the last step of the head decides how the pipeline behaves:
    an MLPRegressor is a projector, an MLPClassifier is a single-label or
    multilabel classifier, and anything else is treated as a classifier.

    :param model: The StaticModel used as the encoder.
    :param head: The scikit-learn pipeline that is applied to the encoded input.
    """
    self.model = model
    self.head = head

    last_head = self.head[-1]
    self.classes_: None | np.ndarray = None
    if isinstance(last_head, MLPRegressor):
        # A projector has no classes, so classes_ stays None.
        self.classifier_type = HeadType.PROJECTOR
    elif isinstance(last_head, MLPClassifier):
        # A logistic output activation marks the MLP as a multilabel classifier.
        activation = last_head.out_activation_
        self.classifier_type = HeadType.MULTILABEL if activation == "logistic" else HeadType.CLASSIFIER
        self.classes_ = self.head.classes_
    else:
        # Default to classifier: the assumption is the user is unlikely to use multilabel here.
        self.classifier_type = HeadType.CLASSIFIER
        # Keep exposing the head's classes when it has them, as the classes_ property
        # this attribute replaces did; fall back to None when the head has none.
        self.classes_ = getattr(self.head, "classes_", None)
4550

4651
@classmethod
4752
def from_pretrained(
@@ -138,7 +143,8 @@ def predict(
138143
multiprocessing_threshold=multiprocessing_threshold,
139144
)
140145

141-
if self.multilabel:
146+
if self.classifier_type == HeadType.MULTILABEL:
147+
assert self.classes_ is not None
142148
out_labels = []
143149
proba = self.head.predict_proba(encoded)
144150
for vector in proba:
@@ -166,7 +172,10 @@ def predict_proba(
166172
:param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
167173
:param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
168174
:return: The predicted labels or probabilities.
175+
:raises ValueError: If the classifier type is projector.
169176
"""
177+
if self.classifier_type == HeadType.PROJECTOR:
    # Name the method that was actually called; the message previously said "evaluate".
    raise ValueError("You are using predict_proba on a projector model. This is not supported.")
170179
encoded = self._encode_and_coerce_to_2d(
171180
X,
172181
show_progress_bar=show_progress_bar,
@@ -190,7 +199,10 @@ def evaluate(
190199
:param threshold: The threshold for multilabel classification.
191200
:param output_dict: Whether to output the classification report as a dictionary.
192201
:return: A classification report.
202+
:raises ValueError: If the classifier type is projector.
193203
"""
204+
if self.classifier_type == HeadType.PROJECTOR:
205+
raise ValueError("You are using evaluate on a projector model. This is not supported.")
194206
predictions = self.predict(X, show_progress_bar=True, batch_size=batch_size, threshold=threshold)
195207
report = evaluate_single_or_multi_label(predictions=predictions, y=y, output_dict=output_dict)
196208

model2vec/train/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
13
from model2vec.utils import get_package_extras, importable
24

35
_REQUIRED_EXTRA = "train"
@@ -6,5 +8,10 @@
68
importable(extra_dependency, _REQUIRED_EXTRA)
79

810
from model2vec.train.classifier import StaticModelForClassification
11+
from model2vec.train.similarity import StaticModelForSimilarity
12+
from model2vec.train.utils import TipFilter
13+
14+
__all__ = ["StaticModelForClassification", "StaticModelForSimilarity"]
15+
916

10-
__all__ = ["StaticModelForClassification"]
17+
logging.getLogger("lightning.pytorch.utilities.rank_zero").addFilter(TipFilter())

0 commit comments

Comments
 (0)