Skip to content

Commit b246d5c

Browse files
committed
fix label shape for lss regression
1 parent 3f482fa commit b246d5c

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

mambular/base_models/lightning_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def compute_loss(self, predictions, y_true):
126126
Computed loss.
127127
"""
128128
if self.lss:
129-
return self.family.compute_loss(predictions, y_true)
129+
return self.family.compute_loss(predictions, y_true.squeeze(-1))
130130
else:
131131
loss = self.loss_fct(predictions, y_true)
132132
return loss

0 commit comments

Comments
 (0)