Skip to content

Commit d9a2bc4

Browse files
committed
fix(auto3dseg): handle precomputed crops and safe no-grad cleanup
Signed-off-by: Jun Hyeok Lee <bluehyena123@naver.com>
1 parent 853f702 commit d9a2bc4

2 files changed

Lines changed: 144 additions & 107 deletions

File tree

monai/auto3dseg/analyzer.py

Lines changed: 98 additions & 106 deletions
Original file line number | Diff line number | Diff line change
@@ -252,39 +252,35 @@ def __call__(self, data):
252252
"""
253253
d = dict(data)
254254
start = time.time()
255-
restore_grad_state = torch.is_grad_enabled()
256-
torch.set_grad_enabled(False)
257-
258-
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
259-
if "nda_croppeds" not in d:
260-
nda_croppeds = [get_foreground_image(nda) for nda in ndas]
261-
262-
# perform calculation
263-
report = deepcopy(self.get_report_format())
264-
265-
report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
266-
report[ImageStatsKeys.CHANNELS] = len(ndas)
267-
report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
268-
report[ImageStatsKeys.SPACING] = (
269-
affine_to_spacing(data[self.image_key].affine).tolist()
270-
if isinstance(data[self.image_key], MetaTensor)
271-
else [1.0] * min(3, data[self.image_key].ndim)
272-
)
255+
with torch.no_grad():
256+
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
257+
nda_croppeds = d["nda_croppeds"] if "nda_croppeds" in d else [get_foreground_image(nda) for nda in ndas]
258+
259+
# perform calculation
260+
report = deepcopy(self.get_report_format())
261+
262+
report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
263+
report[ImageStatsKeys.CHANNELS] = len(ndas)
264+
report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
265+
report[ImageStatsKeys.SPACING] = (
266+
affine_to_spacing(data[self.image_key].affine).tolist()
267+
if isinstance(data[self.image_key], MetaTensor)
268+
else [1.0] * min(3, data[self.image_key].ndim)
269+
)
273270

274-
report[ImageStatsKeys.SIZEMM] = [
275-
a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
276-
]
271+
report[ImageStatsKeys.SIZEMM] = [
272+
a * b for a, b in zip(report[ImageStatsKeys.SHAPE][0], report[ImageStatsKeys.SPACING])
273+
]
277274

278-
report[ImageStatsKeys.INTENSITY] = [
279-
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
280-
]
275+
report[ImageStatsKeys.INTENSITY] = [
276+
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
277+
]
281278

282-
if not verify_report_format(report, self.get_report_format()):
283-
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
279+
if not verify_report_format(report, self.get_report_format()):
280+
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
284281

285-
d[self.stats_name] = report
282+
d[self.stats_name] = report
286283

287-
torch.set_grad_enabled(restore_grad_state)
288284
logger.debug(f"Get image stats spent {time.time() - start}")
289285
return d
290286

@@ -341,31 +337,28 @@ def __call__(self, data: Mapping) -> dict:
341337

342338
d = dict(data)
343339
start = time.time()
344-
restore_grad_state = torch.is_grad_enabled()
345-
torch.set_grad_enabled(False)
340+
with torch.no_grad():
341+
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
342+
ndas_label = d[self.label_key] # (H,W,D)
346343

347-
ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
348-
ndas_label = d[self.label_key] # (H,W,D)
344+
if ndas_label.shape != ndas[0].shape:
345+
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
349346

350-
if ndas_label.shape != ndas[0].shape:
351-
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
347+
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
348+
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
352349

353-
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
354-
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
355-
356-
# perform calculation
357-
report = deepcopy(self.get_report_format())
350+
# perform calculation
351+
report = deepcopy(self.get_report_format())
358352

359-
report[ImageStatsKeys.INTENSITY] = [
360-
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
361-
]
353+
report[ImageStatsKeys.INTENSITY] = [
354+
self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
355+
]
362356

363-
if not verify_report_format(report, self.get_report_format()):
364-
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
357+
if not verify_report_format(report, self.get_report_format()):
358+
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
365359

366-
d[self.stats_name] = report
360+
d[self.stats_name] = report
367361

368-
torch.set_grad_enabled(restore_grad_state)
369362
logger.debug(f"Get foreground image stats spent {time.time() - start}")
370363
return d
371364

@@ -470,78 +463,77 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
470463
start = time.time()
471464
image_tensor = d[self.image_key]
472465
label_tensor = d[self.label_key]
473-
using_cuda = any(
474-
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda" for t in (image_tensor, label_tensor)
475-
)
476-
restore_grad_state = torch.is_grad_enabled()
477-
torch.set_grad_enabled(False)
478-
479-
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
480-
label_tensor, (MetaTensor, torch.Tensor)
481-
):
482-
if label_tensor.device != image_tensor.device:
483-
label_tensor = label_tensor.to(image_tensor.device) # type: ignore
466+
with torch.no_grad():
467+
using_cuda = any(
468+
isinstance(t, (torch.Tensor, MetaTensor)) and t.device.type == "cuda"
469+
for t in (image_tensor, label_tensor)
470+
)
484471

485-
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
486-
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
472+
if isinstance(image_tensor, (MetaTensor, torch.Tensor)) and isinstance(
473+
label_tensor, (MetaTensor, torch.Tensor)
474+
):
475+
if label_tensor.device != image_tensor.device:
476+
label_tensor = label_tensor.to(image_tensor.device) # type: ignore
487477

488-
if ndas_label.shape != ndas[0].shape:
489-
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
478+
ndas: list[MetaTensor] = [image_tensor[i] for i in range(image_tensor.shape[0])] # type: ignore
479+
ndas_label: MetaTensor = label_tensor.astype(torch.int16) # (H,W,D)
490480

491-
nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
492-
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
481+
if ndas_label.shape != ndas[0].shape:
482+
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
493483

494-
unique_label = unique(ndas_label)
495-
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
496-
unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment]
484+
nda_foregrounds: list[torch.Tensor] = [get_foreground_label(nda, ndas_label) for nda in ndas]
485+
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
497486

498-
unique_label = unique_label.astype(np.int16).tolist()
487+
unique_label = unique(ndas_label)
488+
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
489+
unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment]
499490

500-
label_substats = [] # each element is one label
501-
pixel_sum = 0
502-
pixel_arr = []
503-
for index in unique_label:
504-
start_label = time.time()
505-
label_dict: dict[str, Any] = {}
506-
mask_index = ndas_label == index
491+
unique_label = unique_label.astype(np.int16).tolist()
507492

508-
nda_masks = [nda[mask_index] for nda in ndas]
509-
label_dict[LabelStatsKeys.IMAGE_INTST] = [
510-
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks
511-
]
493+
label_substats = [] # each element is one label
494+
pixel_sum = 0
495+
pixel_arr = []
496+
for index in unique_label:
497+
start_label = time.time()
498+
label_dict: dict[str, Any] = {}
499+
mask_index = ndas_label == index
512500

513-
pixel_count = sum(mask_index)
514-
pixel_arr.append(pixel_count)
515-
pixel_sum += pixel_count
516-
if self.do_ccp: # apply connected component
517-
if using_cuda:
518-
# The back end of get_label_ccp is CuPy
519-
# which is unable to automatically release CUDA GPU memory held by PyTorch
520-
del nda_masks
521-
torch.cuda.empty_cache()
522-
shape_list, ncomponents = get_label_ccp(mask_index)
523-
label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list
524-
label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents
525-
526-
label_substats.append(label_dict)
527-
logger.debug(f" label {index} stats takes {time.time() - start_label}")
528-
529-
for i, _ in enumerate(unique_label):
530-
label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)})
501+
nda_masks = [nda[mask_index] for nda in ndas]
502+
label_dict[LabelStatsKeys.IMAGE_INTST] = [
503+
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks
504+
]
531505

532-
report = deepcopy(self.get_report_format())
533-
report[LabelStatsKeys.LABEL_UID] = unique_label
534-
report[LabelStatsKeys.IMAGE_INTST] = [
535-
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds
536-
]
537-
report[LabelStatsKeys.LABEL] = label_substats
506+
pixel_count = sum(mask_index)
507+
pixel_arr.append(pixel_count)
508+
pixel_sum += pixel_count
509+
if self.do_ccp: # apply connected component
510+
if using_cuda:
511+
# The back end of get_label_ccp is CuPy
512+
# which is unable to automatically release CUDA GPU memory held by PyTorch
513+
del nda_masks
514+
torch.cuda.empty_cache()
515+
shape_list, ncomponents = get_label_ccp(mask_index)
516+
label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list
517+
label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents
518+
519+
label_substats.append(label_dict)
520+
logger.debug(f" label {index} stats takes {time.time() - start_label}")
521+
522+
for i, _ in enumerate(unique_label):
523+
label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)})
524+
525+
report = deepcopy(self.get_report_format())
526+
report[LabelStatsKeys.LABEL_UID] = unique_label
527+
report[LabelStatsKeys.IMAGE_INTST] = [
528+
self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds
529+
]
530+
report[LabelStatsKeys.LABEL] = label_substats
538531

539-
if not verify_report_format(report, self.get_report_format()):
540-
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
532+
if not verify_report_format(report, self.get_report_format()):
533+
raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
541534

542-
d[self.stats_name] = report # type: ignore[assignment]
535+
d[self.stats_name] = report # type: ignore[assignment]
543536

544-
torch.set_grad_enabled(restore_grad_state)
545537
logger.debug(f"Get label stats spent {time.time() - start}")
546538
return d # type: ignore[return-value]
547539

tests/apps/test_auto3dseg.py

Lines changed: 46 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@
5353
SqueezeDimd,
5454
ToDeviced,
5555
)
56-
from monai.utils.enums import DataStatsKeys, LabelStatsKeys
56+
from monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys
5757
from tests.test_utils import skip_if_no_cuda
5858

5959
device = "cpu"
@@ -322,6 +322,18 @@ def test_image_stats_case_analyzer(self):
322322
report_format = analyzer.get_report_format()
323323
assert verify_report_format(d["image_stats"], report_format)
324324

325+
def test_image_stats_uses_precomputed_nda_croppeds(self):
326+
analyzer = ImageStats(image_key="image")
327+
image = torch.arange(64.0, dtype=torch.float32).reshape(1, 4, 4, 4)
328+
nda_croppeds = [torch.ones((2, 2, 2), dtype=torch.float32)]
329+
330+
result = analyzer({"image": image, "nda_croppeds": nda_croppeds})
331+
report = result["image_stats"]
332+
333+
assert verify_report_format(report, analyzer.get_report_format())
334+
assert report[ImageStatsKeys.CROPPED_SHAPE] == [[2, 2, 2]]
335+
self.assertAlmostEqual(report[ImageStatsKeys.INTENSITY][0]["mean"], 1.0)
336+
325337
def test_foreground_image_stats_cases_analyzer(self):
326338
analyzer = FgImageStats(image_key="image", label_key="label")
327339
transform_list = [
@@ -411,6 +423,39 @@ def test_label_stats_mixed_device_analyzer(self, input_params):
411423
self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75)
412424
self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75)
413425

426+
def test_case_analyzers_restore_grad_state_on_exception(self):
427+
cases = [
428+
(
429+
"image_stats",
430+
ImageStats(image_key="image"),
431+
{"image": torch.randn(1, 4, 4, 4), "nda_croppeds": [None]},
432+
AttributeError,
433+
),
434+
(
435+
"fg_image_stats",
436+
FgImageStats(image_key="image", label_key="label"),
437+
{"image": torch.randn(1, 4, 4, 4), "label": torch.ones(3, 4, 4)},
438+
ValueError,
439+
),
440+
(
441+
"label_stats",
442+
LabelStats(image_key="image", label_key="label"),
443+
{"image": MetaTensor(torch.randn(1, 4, 4, 4)), "label": MetaTensor(torch.ones(3, 4, 4))},
444+
ValueError,
445+
),
446+
]
447+
448+
original_grad_state = torch.is_grad_enabled()
449+
try:
450+
for name, analyzer, data, error in cases:
451+
with self.subTest(analyzer=name):
452+
torch.set_grad_enabled(True)
453+
with self.assertRaises(error):
454+
analyzer(data)
455+
self.assertTrue(torch.is_grad_enabled())
456+
finally:
457+
torch.set_grad_enabled(original_grad_state)
458+
414459
def test_filename_case_analyzer(self):
415460
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
416461
analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)

0 commit comments

Comments
 (0)