Skip to content

Commit d9bfb5d

Browse files
committed
Adding unittest skip only to test cc functions and resolving shape check bug
Signed-off-by: Vijay Vignesh Prasad Rao <vijayvigneshp02@gmail.com>
1 parent ba2e0b3 commit d9bfb5d

2 files changed

Lines changed: 8 additions & 6 deletions

File tree

monai/metrics/meandice.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -418,11 +418,12 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tupl
418418
y_pred = torch.sigmoid(y_pred)
419419
y_pred = y_pred > 0.5
420420

421-
if self.per_component and (len(y_pred.shape) != 5 or y_pred.shape[1] != 2):
422-
raise ValueError(
423-
f"per_component requires 5D binary segmentation with 2 channels (background + foreground). "
424-
f"Got shape {y_pred.shape}, expected shape (B, 2, D, H, W)."
425-
)
421+
if self.per_component:
422+
if len(y_pred.shape) != 5 or len(y.shape) != 5 or y_pred.shape[1] != 2 or y.shape[1] != 2:
423+
raise ValueError(
424+
"per_component requires both y_pred and y to be 5D binary segmentations "
425+
f"with 2 channels. Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}."
426+
)
426427

427428
first_ch = 0 if self.include_background and not self.per_component else 1
428429
data = []

tests/metrics/test_compute_meandice.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,6 @@
272272
]
273273

274274

275-
@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
276275
class TestComputeMeanDice(unittest.TestCase):
277276

278277
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
@@ -325,12 +324,14 @@ def test_nans_class(self, params, input_data, expected_value):
325324

326325
# CC DiceMetric tests
327326
@parameterized.expand([TEST_CASE_16])
327+
@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
328328
def test_cc_dice_value(self, params, input_data, expected_value):
329329
dice_metric = DiceMetric(**params)
330330
dice_metric(**input_data)
331331
result = dice_metric.aggregate(reduction="none")
332332
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
333333

334+
@unittest.skipUnless(has_ndimage, "Requires scipy.ndimage.")
334335
def test_input_dimensions(self):
335336
with self.assertRaises(ValueError):
336337
DiceMetric(per_component=True)(torch.ones([3, 3, 144, 144]), torch.ones([3, 3, 145, 145]))

0 commit comments

Comments
 (0)