Skip to content

Commit 089f02d

Browse files
AbdoMispre-commit-ci[bot]ericspod
authored
Add support for configurable GD-enhancing tumor label in ConvertToMul… (#8779)
**Description** This PR introduces an `et_label` parameter to `ConvertToMultiChannelBasedOnBratsClasses` to easily support the BraTS 2023 dataset conventions without breaking backwards compatibility. **Motivation and Context** Historically (BraTS 18-21), the Enhancing Tumor (ET) label was 4. However, starting with the BraTS 2023 Adult Glioma (GLI) dataset, the ET label was shifted to 3. Because the original transform hardcoded the ET label as 4, it fails to recognize the ET channel in newer datasets and generates empty masks. By parameterizing `et_label` and defaulting it to 4, existing pipelines remain 100% unaffected, while BraTS 23 users can now simply pass `et_label=3` to utilize the standard sub-region extraction (TC, WT, ET). **Types of changes** - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) --------- Signed-off-by: Abdessamad <abdoomis6@gmail.com> Signed-off-by: Abdessamad Misdak <abdoomis6@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent ee5b4d4 commit 089f02d

4 files changed

Lines changed: 74 additions & 8 deletions

File tree

monai/transforms/utility/array.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,19 +1049,34 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform):
10491049
which include TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor):
10501050
label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion,
10511051
label 2 is the peritumoral edema, which is counted only under WT subregion,
1052-
label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.
1052+
the specified `et_label` (default 4) is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.
1053+
1054+
Args:
1055+
et_label: the label used for the GD-enhancing tumor (ET).
1056+
- Use 4 for BraTS 2018-2022.
1057+
- Use 3 for BraTS 2023.
1058+
Defaults to 4.
10531059
"""
10541060

10551061
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
10561062

1063+
def __init__(self, et_label: int = 4) -> None:
1064+
if et_label in (1, 2):
1065+
raise ValueError(f"et_label cannot be 1 or 2, as these are reserved. Got {et_label}.")
1066+
self.et_label = et_label
1067+
10571068
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
10581069
# if img has channel dim, squeeze it
10591070
if img.ndim == 4 and img.shape[0] == 1:
10601071
img = img.squeeze(0)
10611072

1062-
result = [(img == 1) | (img == 4), (img == 1) | (img == 4) | (img == 2), img == 4]
1063-
# merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT
1064-
# label 4 is ET
1073+
result = [
1074+
(img == 1) | (img == self.et_label),
1075+
(img == 1) | (img == self.et_label) | (img == 2),
1076+
img == self.et_label,
1077+
]
1078+
# merge labels 1 (tumor non-enh) and self.et_label (tumor enh) and 2 (large edema) to WT
1079+
# self.et_label is ET (4 or 3)
10651080
return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)
10661081

10671082

monai/transforms/utility/dictionary.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1297,19 +1297,27 @@ def __call__(self, data: Mapping[Hashable, Any]):
12971297
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
12981298
"""
12991299
Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`.
1300-
Convert labels to multi channels based on brats18 classes:
1300+
Convert labels to multi channels based on brats classes:
13011301
label 1 is the necrotic and non-enhancing tumor core
13021302
label 2 is the peritumoral edema
1303-
label 4 is the GD-enhancing tumor
1303+
the specified `et_label` (default 4) is the GD-enhancing tumor
13041304
The possible classes are TC (Tumor core), WT (Whole tumor)
13051305
and ET (Enhancing tumor).
1306+
1307+
Args:
1308+
keys: keys of the corresponding items to be transformed.
1309+
et_label: the label used for the GD-enhancing tumor (ET).
1310+
- Use 4 for BraTS 2018-2022.
1311+
- Use 3 for BraTS 2023.
1312+
Defaults to 4.
1313+
allow_missing_keys: don't raise exception if key is missing.
13061314
"""
13071315

13081316
backend = ConvertToMultiChannelBasedOnBratsClasses.backend
13091317

1310-
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
1318+
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, et_label: int = 4):
13111319
super().__init__(keys, allow_missing_keys)
1312-
self.converter = ConvertToMultiChannelBasedOnBratsClasses()
1320+
self.converter = ConvertToMultiChannelBasedOnBratsClasses(et_label=et_label)
13131321

13141322
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
13151323
d = dict(data)

tests/transforms/test_convert_to_multi_channel.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from tests.test_utils import TEST_NDARRAYS, assert_allclose
2121

2222
TESTS = []
23+
TESTS_ET_LABEL_3 = []
24+
25+
# Tests for default et_label = 4
2326
for p in TEST_NDARRAYS:
2427
TESTS.extend(
2528
[
@@ -46,6 +49,23 @@
4649
]
4750
)
4851

52+
# Tests for et_label = 3
53+
for p in TEST_NDARRAYS:
54+
TESTS_ET_LABEL_3.extend(
55+
[
56+
[
57+
p([[0, 1, 2], [1, 2, 3], [0, 1, 3]]),
58+
p(
59+
[
60+
[[0, 1, 0], [1, 0, 1], [0, 1, 1]],
61+
[[0, 1, 1], [1, 1, 1], [0, 1, 1]],
62+
[[0, 0, 0], [0, 0, 1], [0, 0, 1]],
63+
]
64+
),
65+
]
66+
]
67+
)
68+
4969

5070
class TestConvertToMultiChannel(unittest.TestCase):
5171
@parameterized.expand(TESTS)
@@ -54,6 +74,18 @@ def test_type_shape(self, data, expected_result):
5474
assert_allclose(result, expected_result)
5575
self.assertTrue(result.dtype in (bool, torch.bool))
5676

77+
@parameterized.expand(TESTS_ET_LABEL_3)
78+
def test_type_shape_et_label_3(self, data, expected_result):
79+
result = ConvertToMultiChannelBasedOnBratsClasses(et_label=3)(data)
80+
assert_allclose(result, expected_result)
81+
self.assertTrue(result.dtype in (bool, torch.bool))
82+
83+
def test_invalid_et_label(self):
84+
with self.assertRaises(ValueError):
85+
ConvertToMultiChannelBasedOnBratsClasses(et_label=1)
86+
with self.assertRaises(ValueError):
87+
ConvertToMultiChannelBasedOnBratsClasses(et_label=2)
88+
5789

5890
if __name__ == "__main__":
5991
unittest.main()

tests/transforms/test_convert_to_multi_channeld.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424
np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]),
2525
]
2626

27+
TEST_CASE_ET_LABEL_3 = [
28+
{"keys": "label", "et_label": 3},
29+
{"label": np.array([[0, 1, 2], [1, 2, 3], [0, 1, 3]])},
30+
np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]),
31+
]
32+
2733

2834
class TestConvertToMultiChanneld(unittest.TestCase):
2935

@@ -32,6 +38,11 @@ def test_type_shape(self, keys, data, expected_result):
3238
result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)
3339
np.testing.assert_equal(result["label"], expected_result)
3440

41+
@parameterized.expand([TEST_CASE_ET_LABEL_3])
42+
def test_et_label_3(self, keys, data, expected_result):
43+
result = ConvertToMultiChannelBasedOnBratsClassesd(**keys)(data)
44+
np.testing.assert_equal(result["label"], expected_result)
45+
3546

3647
if __name__ == "__main__":
3748
unittest.main()

0 commit comments

Comments
 (0)