Skip to content

Commit 6da2cf8

Browse files
authored
Merge pull request #83 from basf/lss_fix
fix label shape for lss regression
2 parents 3f482fa + b246d5c commit 6da2cf8

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

mambular/base_models/lightning_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def compute_loss(self, predictions, y_true):
126126
Computed loss.
127127
"""
128128
if self.lss:
129-
return self.family.compute_loss(predictions, y_true)
129+
return self.family.compute_loss(predictions, y_true.squeeze(-1))
130130
else:
131131
loss = self.loss_fct(predictions, y_true)
132132
return loss

0 commit comments

Comments
 (0)