Skip to content

Commit cde33ff

Browse files
committed
Add validation for et_label in ConvertToMultiChannelBasedOnBratsClasses and extend tests
Signed-off-by: Abdessamad <abdoomis6@gmail.com>
1 parent a933cf6 commit cde33ff

3 files changed

Lines changed: 45 additions & 2 deletions

File tree

monai/transforms/utility/array.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,8 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform):
10611061
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
10621062

10631063
def __init__(self, et_label: int = 4) -> None:
1064+
if et_label in (1, 2):
1065+
raise ValueError(f"et_label cannot be 1 or 2, these are reserved.Got {et_label}.")
10641066
self.et_label = et_label
10651067

10661068
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:

tests/transforms/test_convert_to_multi_channel.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from tests.test_utils import TEST_NDARRAYS, assert_allclose
2121

2222
TESTS = []
23+
TESTS_ET_LABEL_3 = []
24+
25+
# Tests for default et_label = 4
2326
for p in TEST_NDARRAYS:
2427
TESTS.extend(
2528
[
@@ -46,6 +49,22 @@
4649
]
4750
)
4851

52+
# Tests for et_label = 3
53+
for p in TEST_NDARRAYS:
54+
TESTS_ET_LABEL_3.extend(
55+
[
56+
[
57+
p([[0, 1, 2], [1, 2, 3], [0, 1, 3]]),
58+
p(
59+
[
60+
[[0, 1, 0], [1, 0, 1], [0, 1, 1]],
61+
[[0, 1, 1], [1, 1, 1], [0, 1, 1]],
62+
[[0, 0, 0], [0, 0, 1], [0, 0, 1]],
63+
]
64+
),
65+
],
66+
]
67+
)
4968

5069
class TestConvertToMultiChannel(unittest.TestCase):
5170
@parameterized.expand(TESTS)
@@ -54,6 +73,17 @@ def test_type_shape(self, data, expected_result):
5473
assert_allclose(result, expected_result)
5574
self.assertTrue(result.dtype in (bool, torch.bool))
5675

76+
@parameterized.expand(TESTS_ET_LABEL_3)
77+
def test_type_shape_et_label_3(self, data, expected_result):
78+
result = ConvertToMultiChannelBasedOnBratsClasses(et_label=3)(data)
79+
assert_allclose(result, expected_result)
80+
self.assertTrue(result.dtype in (bool, torch.bool))
81+
82+
def test_invalid_et_label(self):
83+
with self.assertRaises(ValueError):
84+
ConvertToMultiChannelBasedOnBratsClasses(et_label=1)
85+
with self.assertRaises(ValueError):
86+
ConvertToMultiChannelBasedOnBratsClasses(et_label=2)
5787

5888
if __name__ == "__main__":
59-
unittest.main()
89+
unittest.main()

tests/transforms/test_convert_to_multi_channeld.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424
np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]),
2525
]
2626

27+
TEST_CASE_ET_LABEL_3 = [
28+
{"keys": "label", "et_label": 3},
29+
{"label": np.array([[0, 1, 2], [1, 2, 3], [0, 1, 3]])},
30+
np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]),
31+
]
32+
2733

2834
class TestConvertToMultiChanneld(unittest.TestCase):
2935

@@ -32,6 +38,11 @@ def test_type_shape(self, keys, data, expected_result):
3238
result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)
3339
np.testing.assert_equal(result["label"], expected_result)
3440

41+
@parameterized.expand([TEST_CASE_ET_LABEL_3])
42+
def test_et_label_3(self, keys, data, expected_result):
43+
result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)
44+
np.testing.assert_equal(result["label"], expected_result)
45+
3546

3647
if __name__ == "__main__":
37-
unittest.main()
48+
unittest.main()

0 commit comments

Comments
 (0)