@@ -1050,18 +1050,27 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform):
10501050 label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion,
10511051 label 2 is the peritumoral edema, which is counted only under WT subregion,
10521052 label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.
1053+
1054+ Args:
1055+ et_label: the label used for the GD-enhancing tumor (ET).
1056+ - Use 4 for BraTS 2018-2022.
1057+ - Use 3 for BraTS 2023.
1058+ Defaults to 4.
10531059 """
10541060
10551061 backend = [TransformBackends .TORCH , TransformBackends .NUMPY ]
10561062
1063+ def __init__ (self , et_label : int = 4 ) -> None :
1064+ self .et_label = et_label
1065+
10571066 def __call__ (self , img : NdarrayOrTensor ) -> NdarrayOrTensor :
10581067 # if img has channel dim, squeeze it
10591068 if img .ndim == 4 and img .shape [0 ] == 1 :
10601069 img = img .squeeze (0 )
10611070
1062- result = [(img == 1 ) | (img == 4 ), (img == 1 ) | (img == 4 ) | (img == 2 ), img == 4 ]
1063- # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT
1064- # label 4 is ET
1071+ result = [(img == 1 ) | (img == self . et_label ), (img == 1 ) | (img == self . et_label ) | (img == 2 ), img == self . et_label ]
1072+ # merge labels 1 (tumor non-enh) and self.et_label (tumor enh) and 2 (large edema) to WT
1073+ # self.et_label is ET (4 or 3)
10651074 return torch .stack (result , dim = 0 ) if isinstance (img , torch .Tensor ) else np .stack (result , axis = 0 )
10661075
10671076
0 commit comments