22
33import re
44from collections .abc import Sequence
5+ from enum import Enum
56from pathlib import Path
67from tempfile import TemporaryDirectory
78from typing import TypeVar , cast
1011import numpy as np
1112import skops .io
1213from sklearn .metrics import classification_report
13- from sklearn .neural_network import MLPClassifier
14+ from sklearn .neural_network import MLPClassifier , MLPRegressor
1415from sklearn .pipeline import Pipeline
1516from sklearn .preprocessing import MultiLabelBinarizer
1617
2324LabelType = TypeVar ("LabelType" , list [str ], list [list [str ]])
2425
2526
class HeadType(str, Enum):
    """
    The kind of head that sits on top of the encoder.

    The members are also plain strings (the class mixes in ``str``), so they
    compare equal to their string values.
    """

    # Selected for MLPClassifier heads whose output activation is not "logistic",
    # and used as the default for any other head type.
    CLASSIFIER = "classifier"
    # Selected when the last step of the head is an MLPRegressor.
    PROJECTOR = "projector"
    # Selected for MLPClassifier heads whose output activation is "logistic".
    MULTILABEL = "multilabel"
31+
32+
class StaticModelPipeline:
    def __init__(self, model: StaticModel, head: Pipeline) -> None:
        """
        Create a pipeline with a StaticModel encoder.

        The type of the head is derived from the last step of ``head`` and stored
        in ``classifier_type``. ``classes_`` holds the class labels of the head, or
        None when the last step does not expose any (e.g. a projector).

        :param model: The StaticModel encoder.
        :param head: The scikit-learn pipeline that is applied on top of the encoder.
        """
        self.model = model
        self.head = head

        last_head = self.head[-1]
        self.classes_: None | np.ndarray = None
        if isinstance(last_head, MLPRegressor):
            # A regressor head has no classes, so classes_ stays None.
            self.classifier_type = HeadType.PROJECTOR
        elif isinstance(last_head, MLPClassifier):
            # A logistic output activation is treated as the marker for multilabel output.
            activation = last_head.out_activation_
            self.classifier_type = HeadType.MULTILABEL if activation == "logistic" else HeadType.CLASSIFIER
            self.classes_ = self.head.classes_
        else:
            # Default to classifier: the assumption is the user is unlikely to use multilabel here.
            self.classifier_type = HeadType.CLASSIFIER
            # Other classifiers still expose their labels; keep them available instead of
            # silently dropping them. Heads without classes_ keep None.
            self.classes_ = getattr(last_head, "classes_", None)
4550
4651 @classmethod
4752 def from_pretrained (
@@ -138,7 +143,8 @@ def predict(
138143 multiprocessing_threshold = multiprocessing_threshold ,
139144 )
140145
141- if self .multilabel :
146+ if self .classifier_type == HeadType .MULTILABEL :
147+ assert self .classes_ is not None
142148 out_labels = []
143149 proba = self .head .predict_proba (encoded )
144150 for vector in proba :
@@ -166,7 +172,10 @@ def predict_proba(
166172 :param use_multiprocessing: Whether to use multiprocessing for encoding. Defaults to True.
167173 :param multiprocessing_threshold: The threshold for the number of samples to use multiprocessing. Defaults to 10,000.
168174 :return: The predicted labels or probabilities.
175+ :raises ValueError: If the classifier type is projector.
169176 """
177+ if self .classifier_type == HeadType .PROJECTOR :
178+ raise ValueError ("You are using evaluate on a projector model. This is not supported." )
170179 encoded = self ._encode_and_coerce_to_2d (
171180 X ,
172181 show_progress_bar = show_progress_bar ,
@@ -190,7 +199,10 @@ def evaluate(
190199 :param threshold: The threshold for multilabel classification.
191200 :param output_dict: Whether to output the classification report as a dictionary.
192201 :return: A classification report.
202+ :raises ValueError: If the classifier type is projector.
193203 """
204+ if self .classifier_type == HeadType .PROJECTOR :
205+ raise ValueError ("You are using evaluate on a projector model. This is not supported." )
194206 predictions = self .predict (X , show_progress_bar = True , batch_size = batch_size , threshold = threshold )
195207 report = evaluate_single_or_multi_label (predictions = predictions , y = y , output_dict = output_dict )
196208
0 commit comments