Skip to content

Commit f5863a2

Browse files
committed
fix param_count in build_model
1 parent 967f49f commit f5863a2

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

mambular/models/sklearn_base_lss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
NormalDistribution,
3131
PoissonDistribution,
3232
StudentTDistribution,
33+
Quantile,
3334
)
3435
from lightning.pytorch.callbacks import ModelSummary
3536

@@ -210,11 +211,9 @@ def build_model(
210211
X, y, X_val, y_val, val_size=val_size, random_state=random_state
211212
)
212213

213-
num_classes = len(np.unique(y))
214-
215214
self.task_model = TaskModel(
216215
model_class=self.base_model,
217-
num_classes=num_classes,
216+
num_classes=self.family.param_count,
218217
config=self.config,
219218
cat_feature_info=self.data_module.cat_feature_info,
220219
num_feature_info=self.data_module.num_feature_info,
@@ -347,6 +346,7 @@ def fit(
347346
"negativebinom": NegativeBinomialDistribution,
348347
"inversegamma": InverseGammaDistribution,
349348
"categorical": CategoricalDistribution,
349+
"quantile": Quantile,
350350
}
351351

352352
if distributional_kwargs is None:

0 commit comments

Comments
 (0)