Skip to content

Commit b6dd9b4

Browse files
committed
fix(auto3dseg): handle precomputed crops and safe no-grad cleanup
Signed-off-by: Jun Hyeok Lee <bluehyena123@naver.com>
1 parent a537499 commit b6dd9b4

2 files changed

Lines changed: 86 additions & 54 deletions

File tree

monai/auto3dseg/analyzer.py

Lines changed: 40 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -216,50 +216,58 @@ def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS)
216216
super().__init__(stats_name, report_format)
217217
self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())
218218

219+
@torch.no_grad()
219220
def __call__(self, data):
220-
# Input Validation Addition
221-
if not isinstance(data, dict):
222-
raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
223-
if self.image_key not in data:
224-
raise KeyError(f"Key '{self.image_key}' not found in input data.")
225-
image = data[self.image_key]
226-
if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
227-
raise TypeError(
228-
f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
229-
f"but got {type(image).__name__}."
230-
)
231-
if image.ndim < 3:
232-
raise ValueError(
233-
f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
234-
)
235-
# --- End of validation ---
236221
"""
237-
Callable to execute the pre-defined functions
222+
Callable to execute the pre-defined functions.
238223
239224
Returns:
240225
A dictionary. The dict has the key in self.report_format. The value of
241226
ImageStatsKeys.INTENSITY is in a list format. Each element of the value list
242227
has stats pre-defined by SampleOperations (max, min, ....).
243228
244229
Raises:
245-
RuntimeError if the stats report generated is not consistent with the pre-
230+
KeyError: if ``self.image_key`` is not present in the input data.
231+
TypeError: if the input data is not a dictionary, or if the image value is
232+
not a numpy array, torch.Tensor, or MetaTensor.
233+
ValueError: if the image has fewer than 3 dimensions, or if pre-computed
234+
``nda_croppeds`` is not a list/tuple with one entry per image channel.
235+
RuntimeError: if the stats report generated is not consistent with the pre-
246236
defined report_format.
247237
248238
Note:
249239
The stats operation uses numpy and torch to compute max, min, and other
250240
functions. If the input has nan/inf, the stats results will be nan/inf.
251241
252242
"""
243+
if not isinstance(data, dict):
244+
raise TypeError(f"Input data must be a dict, but got {type(data).__name__}.")
245+
if self.image_key not in data:
246+
raise KeyError(f"Key '{self.image_key}' not found in input data.")
247+
image = data[self.image_key]
248+
if not isinstance(image, (np.ndarray, torch.Tensor, MetaTensor)):
249+
raise TypeError(
250+
f"Value for '{self.image_key}' must be a numpy array, torch.Tensor, or MetaTensor, "
251+
f"but got {type(image).__name__}."
252+
)
253+
if image.ndim < 3:
254+
raise ValueError(
255+
f"Image data under '{self.image_key}' must have at least 3 dimensions, but got shape {image.shape}."
256+
)
257+
253258
d = dict(data)
254259
start = time.time()
255-
restore_grad_state = torch.is_grad_enabled()
256-
torch.set_grad_enabled(False)
257-
258260
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
259-
if "nda_croppeds" not in d:
261+
if "nda_croppeds" in d:
262+
nda_croppeds = d["nda_croppeds"]
263+
if not isinstance(nda_croppeds, (list, tuple)) or len(nda_croppeds) != len(ndas):
264+
raise ValueError(
265+
"Pre-computed 'nda_croppeds' must be a list or tuple with one entry per image channel "
266+
f"(expected {len(ndas)})."
267+
)
268+
else:
260269
nda_croppeds = [get_foreground_image(nda) for nda in ndas]
261270

262-
# perform calculation
263271
report = deepcopy(self.get_report_format())
264272

265273
report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
@@ -275,16 +283,13 @@ def __call__(self, data):
275283
a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
276284
]
277285

278-
report[ImageStatsKeys.INTENSITY] = [
279-
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
280-
]
286+
report[ImageStatsKeys.INTENSITY] = [self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds]
281287

282288
if not verify_report_format(report, self.get_report_format()):
283289
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
284290

285291
d[self.stats_name] = report
286292

287-
torch.set_grad_enabled(restore_grad_state)
288293
logger.debug(f"Get image stats spent {time.time() - start}")
289294
return d
290295

@@ -321,6 +326,7 @@ def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKe
321326
super().__init__(stats_name, report_format)
322327
self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())
323328

329+
@torch.no_grad()
324330
def __call__(self, data: Mapping) -> dict:
325331
"""
326332
Callable to execute the pre-defined functions
@@ -341,9 +347,6 @@ def __call__(self, data: Mapping) -> dict:
341347

342348
d = dict(data)
343349
start = time.time()
344-
restore_grad_state = torch.is_grad_enabled()
345-
torch.set_grad_enabled(False)
346-
347350
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
348351
ndas_label = d[self.label_key] # (H,W,D)
349352

@@ -353,19 +356,15 @@ def __call__(self, data: Mapping) -> dict:
353356
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
354357
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
355358

356-
# perform calculation
357359
report = deepcopy(self.get_report_format())
358360

359-
report[ImageStatsKeys.INTENSITY] = [
360-
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
361-
]
361+
report[ImageStatsKeys.INTENSITY] = [self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds]
362362

363363
if not verify_report_format(report, self.get_report_format()):
364364
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
365365

366366
d[self.stats_name] = report
367367

368-
torch.set_grad_enabled(restore_grad_state)
369368
logger.debug(f"Get foreground image stats spent {time.time() - start}")
370369
return d
371370

@@ -418,6 +417,7 @@ def __init__(
418417
id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.IMAGE_INTST])
419418
self.update_ops_nested_label(id_seq, SampleOperations())
420419

420+
@torch.no_grad()
421421
def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTensor | dict]:
422422
"""
423423
Callable to execute the pre-defined functions.
@@ -470,19 +470,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
470470
start = time.time()
471471
image_tensor = d[self.image_key]
472472
label_tensor = d[self.label_key]
473-
# Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
474473
using_cuda = any(
475474
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
476475
)
477-
restore_grad_state = torch.is_grad_enabled()
478-
torch.set_grad_enabled(False)
479476

480-
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
481-
label_tensor, (MetaTensor, torch.Tensor)
482-
):
477+
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(label_tensor, (MetaTensor, torch.Tensor)):
483478
if label_tensor.device != image_tensor.device:
484479
if using_cuda:
485-
# Move both tensors to CUDA when mixing devices
486480
cuda_device = image_tensor.device if image_tensor.device.type == "cuda" else label_tensor.device
487481
image_tensor = cast(MetaTensor, image_tensor.to(cuda_device))
488482
label_tensor = cast(MetaTensor, label_tensor.to(cuda_device))
@@ -504,7 +498,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
504498

505499
unique_label = unique_label.astype(np.int16).tolist()
506500

507-
label_substats = [] # each element is one label
501+
label_substats = []
508502
pixel_sum = 0
509503
pixel_arr = []
510504
for index in unique_label:
@@ -513,17 +507,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
513507
mask_index = ndas_label == index
514508

515509
nda_masks = [nda[mask_index] for nda in ndas]
516-
label_dict[LabelStatsKeys.IMAGE_INTST] = [
517-
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks
518-
]
510+
label_dict[LabelStatsKeys.IMAGE_INTST] = [self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks]
519511

520512
pixel_count = sum(mask_index)
521513
pixel_arr.append(pixel_count)
522514
pixel_sum += pixel_count
523-
if self.do_ccp: # apply connected component
515+
if self.do_ccp:
524516
if using_cuda:
525-
# The back end of get_label_ccp is CuPy
526-
# which is unable to automatically release CUDA GPU memory held by PyTorch
527517
del nda_masks
528518
torch.cuda.empty_cache()
529519
shape_list, ncomponents = get_label_ccp(mask_index)
@@ -538,17 +528,14 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
538528

539529
report = deepcopy(self.get_report_format())
540530
report[LabelStatsKeys.LABEL_UID] = unique_label
541-
report[LabelStatsKeys.IMAGE_INTST] = [
542-
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds
543-
]
531+
report[LabelStatsKeys.IMAGE_INTST] = [self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds]
544532
report[LabelStatsKeys.LABEL] = label_substats
545533

546534
if not verify_report_format(report, self.get_report_format()):
547535
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
548536

549537
d[self.stats_name] = report # type: ignore[assignment]
550538

551-
torch.set_grad_enabled(restore_grad_state)
552539
logger.debug(f"Get label stats spent {time.time() - start}")
553540
return d # type: ignore[return-value]
554541

tests/apps/test_auto3dseg.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
SqueezeDimd,
5454
ToDeviced,
5555
)
56-
from monai.utils.enums import DataStatsKeys, LabelStatsKeys
56+
from monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys
5757
from tests.test_utils import skip_if_no_cuda
5858

5959
device = "cpu"
@@ -322,6 +322,18 @@ def test_image_stats_case_analyzer(self):
322322
report_format = analyzer.get_report_format()
323323
assert verify_report_format(d["image_stats"], report_format)
324324

325+
def test_image_stats_uses_precomputed_nda_croppeds(self):
326+
analyzer = ImageStats(image_key="image")
327+
image = torch.arange(64.0, dtype=torch.float32).reshape(1, 4, 4, 4)
328+
nda_croppeds = [torch.ones((2, 2, 2), dtype=torch.float32)]
329+
330+
result = analyzer({"image": image, "nda_croppeds": nda_croppeds})
331+
report = result["image_stats"]
332+
333+
assert verify_report_format(report, analyzer.get_report_format())
334+
assert report[ImageStatsKeys.CROPPED_SHAPE] == [[2, 2, 2]]
335+
self.assertAlmostEqual(report[ImageStatsKeys.INTENSITY][0]["mean"], 1.0)
336+
325337
def test_foreground_image_stats_cases_analyzer(self):
326338
analyzer = FgImageStats(image_key="image", label_key="label")
327339
transform_list = [
@@ -412,6 +424,39 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
412424
self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75)
413425
self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)
414426

427+
def test_case_analyzers_restore_grad_state_on_exception(self):
428+
cases = [
429+
(
430+
"image_stats",
431+
ImageStats(image_key="image"),
432+
{"image": torch.randn(1, 4, 4, 4), "nda_croppeds": [None]},
433+
AttributeError,
434+
),
435+
(
436+
"fg_image_stats",
437+
FgImageStats(image_key="image", label_key="label"),
438+
{"image": torch.randn(1, 4, 4, 4), "label": torch.ones(3, 4, 4)},
439+
ValueError,
440+
),
441+
(
442+
"label_stats",
443+
LabelStats(image_key="image", label_key="label"),
444+
{"image": MetaTensor(torch.randn(1, 4, 4, 4)), "label": MetaTensor(torch.ones(3, 4, 4))},
445+
ValueError,
446+
),
447+
]
448+
449+
original_grad_state = torch.is_grad_enabled()
450+
try:
451+
for name, analyzer, data, error in cases:
452+
with self.subTest(analyzer=name):
453+
torch.set_grad_enabled(True)
454+
with self.assertRaises(error):
455+
analyzer(data)
456+
self.assertTrue(torch.is_grad_enabled())
457+
finally:
458+
torch.set_grad_enabled(original_grad_state)
459+
415460
def test_filename_case_analyzer(self):
416461
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
417462
analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)

0 commit comments

Comments (0)