@@ -252,39 +252,35 @@ def __call__(self, data):
252252 """
253253 d = dict (data )
254254 start = time .time ()
255- restore_grad_state = torch .is_grad_enabled ()
256- torch .set_grad_enabled (False )
257-
258- ndas = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])]
259- if "nda_croppeds" not in d :
260- nda_croppeds = [get_foreground_image (nda ) for nda in ndas ]
261-
262- # perform calculation
263- report = deepcopy (self .get_report_format ())
264-
265- report [ImageStatsKeys .SHAPE ] = [list (nda .shape ) for nda in ndas ]
266- report [ImageStatsKeys .CHANNELS ] = len (ndas )
267- report [ImageStatsKeys .CROPPED_SHAPE ] = [list (nda_c .shape ) for nda_c in nda_croppeds ]
268- report [ImageStatsKeys .SPACING ] = (
269- affine_to_spacing (data [self .image_key ].affine ).tolist ()
270- if isinstance (data [self .image_key ], MetaTensor )
271- else [1.0 ] * min (3 , data [self .image_key ].ndim )
272- )
255+ with torch .no_grad ():
256+ ndas = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])]
257+ nda_croppeds = d ["nda_croppeds" ] if "nda_croppeds" in d else [get_foreground_image (nda ) for nda in ndas ]
258+
259+ # perform calculation
260+ report = deepcopy (self .get_report_format ())
261+
262+ report [ImageStatsKeys .SHAPE ] = [list (nda .shape ) for nda in ndas ]
263+ report [ImageStatsKeys .CHANNELS ] = len (ndas )
264+ report [ImageStatsKeys .CROPPED_SHAPE ] = [list (nda_c .shape ) for nda_c in nda_croppeds ]
265+ report [ImageStatsKeys .SPACING ] = (
266+ affine_to_spacing (data [self .image_key ].affine ).tolist ()
267+ if isinstance (data [self .image_key ], MetaTensor )
268+ else [1.0 ] * min (3 , data [self .image_key ].ndim )
269+ )
273270
274- report [ImageStatsKeys .SIZEMM ] = [
275- a * b for a , b in zip (report [ImageStatsKeys .SHAPE ][0 ], report [ImageStatsKeys .SPACING ])
276- ]
271+ report [ImageStatsKeys .SIZEMM ] = [
272+ a * b for a , b in zip (report [ImageStatsKeys .SHAPE ][0 ], report [ImageStatsKeys .SPACING ])
273+ ]
277274
278- report [ImageStatsKeys .INTENSITY ] = [
279- self .ops [ImageStatsKeys .INTENSITY ].evaluate (nda_c ) for nda_c in nda_croppeds
280- ]
275+ report [ImageStatsKeys .INTENSITY ] = [
276+ self .ops [ImageStatsKeys .INTENSITY ].evaluate (nda_c ) for nda_c in nda_croppeds
277+ ]
281278
282- if not verify_report_format (report , self .get_report_format ()):
283- raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
279+ if not verify_report_format (report , self .get_report_format ()):
280+ raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
284281
285- d [self .stats_name ] = report
282+ d [self .stats_name ] = report
286283
287- torch .set_grad_enabled (restore_grad_state )
288284 logger .debug (f"Get image stats spent { time .time () - start } " )
289285 return d
290286
@@ -341,31 +337,28 @@ def __call__(self, data: Mapping) -> dict:
341337
342338 d = dict (data )
343339 start = time .time ()
344- restore_grad_state = torch .is_grad_enabled ()
345- torch .set_grad_enabled (False )
340+ with torch .no_grad ():
341+ ndas = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])]
342+ ndas_label = d [self .label_key ] # (H,W,D)
346343
347- ndas = [ d [ self . image_key ][ i ] for i in range ( d [ self . image_key ]. shape [0 ])]
348- ndas_label = d [ self . label_key ] # (H,W,D )
344+ if ndas_label . shape != ndas [0 ]. shape :
345+ raise ValueError ( f"Label shape { ndas_label . shape } is different from image shape { ndas [ 0 ]. shape } " )
349346
350- if ndas_label . shape != ndas [ 0 ]. shape :
351- raise ValueError ( f"Label shape { ndas_label . shape } is different from image shape { ndas [ 0 ]. shape } " )
347+ nda_foregrounds = [ get_foreground_label ( nda , ndas_label ) for nda in ndas ]
348+ nda_foregrounds = [ nda if nda . numel () > 0 else MetaTensor ([ 0.0 ]) for nda in nda_foregrounds ]
352349
353- nda_foregrounds = [get_foreground_label (nda , ndas_label ) for nda in ndas ]
354- nda_foregrounds = [nda if nda .numel () > 0 else MetaTensor ([0.0 ]) for nda in nda_foregrounds ]
355-
356- # perform calculation
357- report = deepcopy (self .get_report_format ())
350+ # perform calculation
351+ report = deepcopy (self .get_report_format ())
358352
359- report [ImageStatsKeys .INTENSITY ] = [
360- self .ops [ImageStatsKeys .INTENSITY ].evaluate (nda_f ) for nda_f in nda_foregrounds
361- ]
353+ report [ImageStatsKeys .INTENSITY ] = [
354+ self .ops [ImageStatsKeys .INTENSITY ].evaluate (nda_f ) for nda_f in nda_foregrounds
355+ ]
362356
363- if not verify_report_format (report , self .get_report_format ()):
364- raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
357+ if not verify_report_format (report , self .get_report_format ()):
358+ raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
365359
366- d [self .stats_name ] = report
360+ d [self .stats_name ] = report
367361
368- torch .set_grad_enabled (restore_grad_state )
369362 logger .debug (f"Get foreground image stats spent { time .time () - start } " )
370363 return d
371364
@@ -470,78 +463,77 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
470463 start = time .time ()
471464 image_tensor = d [self .image_key ]
472465 label_tensor = d [self .label_key ]
473- using_cuda = any (
474- isinstance (t , (torch .Tensor , MetaTensor )) and t .device .type == "cuda" for t in (image_tensor , label_tensor )
475- )
476- restore_grad_state = torch .is_grad_enabled ()
477- torch .set_grad_enabled (False )
478-
479- if isinstance (image_tensor , (MetaTensor , torch .Tensor )) and isinstance (
480- label_tensor , (MetaTensor , torch .Tensor )
481- ):
482- if label_tensor .device != image_tensor .device :
483- label_tensor = label_tensor .to (image_tensor .device ) # type: ignore
466+ with torch .no_grad ():
467+ using_cuda = any (
468+ isinstance (t , (torch .Tensor , MetaTensor )) and t .device .type == "cuda"
469+ for t in (image_tensor , label_tensor )
470+ )
484471
485- ndas : list [MetaTensor ] = [image_tensor [i ] for i in range (image_tensor .shape [0 ])] # type: ignore
486- ndas_label : MetaTensor = label_tensor .astype (torch .int16 ) # (H,W,D)
472+ if isinstance (image_tensor , (MetaTensor , torch .Tensor )) and isinstance (
473+ label_tensor , (MetaTensor , torch .Tensor )
474+ ):
475+ if label_tensor .device != image_tensor .device :
476+ label_tensor = label_tensor .to (image_tensor .device ) # type: ignore
487477
488- if ndas_label . shape != ndas [ 0 ] .shape :
489- raise ValueError ( f"Label shape { ndas_label . shape } is different from image shape { ndas [ 0 ]. shape } " )
478+ ndas : list [ MetaTensor ] = [ image_tensor [ i ] for i in range ( image_tensor .shape [ 0 ])] # type: ignore
479+ ndas_label : MetaTensor = label_tensor . astype ( torch . int16 ) # (H,W,D )
490480
491- nda_foregrounds : list [ torch . Tensor ] = [ get_foreground_label ( nda , ndas_label ) for nda in ndas ]
492- nda_foregrounds = [ nda if nda . numel () > 0 else MetaTensor ([ 0.0 ]) for nda in nda_foregrounds ]
481+ if ndas_label . shape != ndas [ 0 ]. shape :
482+ raise ValueError ( f"Label shape { ndas_label . shape } is different from image shape { ndas [ 0 ]. shape } " )
493483
494- unique_label = unique (ndas_label )
495- if isinstance (ndas_label , (MetaTensor , torch .Tensor )):
496- unique_label = unique_label .data .cpu ().numpy () # type: ignore[assignment]
484+ nda_foregrounds : list [torch .Tensor ] = [get_foreground_label (nda , ndas_label ) for nda in ndas ]
485+ nda_foregrounds = [nda if nda .numel () > 0 else MetaTensor ([0.0 ]) for nda in nda_foregrounds ]
497486
498- unique_label = unique_label .astype (np .int16 ).tolist ()
487+ unique_label = unique (ndas_label )
488+ if isinstance (ndas_label , (MetaTensor , torch .Tensor )):
489+ unique_label = unique_label .data .cpu ().numpy () # type: ignore[assignment]
499490
500- label_substats = [] # each element is one label
501- pixel_sum = 0
502- pixel_arr = []
503- for index in unique_label :
504- start_label = time .time ()
505- label_dict : dict [str , Any ] = {}
506- mask_index = ndas_label == index
491+ unique_label = unique_label .astype (np .int16 ).tolist ()
507492
508- nda_masks = [nda [mask_index ] for nda in ndas ]
509- label_dict [LabelStatsKeys .IMAGE_INTST ] = [
510- self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_m ) for nda_m in nda_masks
511- ]
493+ label_substats = [] # each element is one label
494+ pixel_sum = 0
495+ pixel_arr = []
496+ for index in unique_label :
497+ start_label = time .time ()
498+ label_dict : dict [str , Any ] = {}
499+ mask_index = ndas_label == index
512500
513- pixel_count = sum (mask_index )
514- pixel_arr .append (pixel_count )
515- pixel_sum += pixel_count
516- if self .do_ccp : # apply connected component
517- if using_cuda :
518- # The back end of get_label_ccp is CuPy
519- # which is unable to automatically release CUDA GPU memory held by PyTorch
520- del nda_masks
521- torch .cuda .empty_cache ()
522- shape_list , ncomponents = get_label_ccp (mask_index )
523- label_dict [LabelStatsKeys .LABEL_SHAPE ] = shape_list
524- label_dict [LabelStatsKeys .LABEL_NCOMP ] = ncomponents
525-
526- label_substats .append (label_dict )
527- logger .debug (f" label { index } stats takes { time .time () - start_label } " )
528-
529- for i , _ in enumerate (unique_label ):
530- label_substats [i ].update ({LabelStatsKeys .PIXEL_PCT : float (pixel_arr [i ] / pixel_sum )})
501+ nda_masks = [nda [mask_index ] for nda in ndas ]
502+ label_dict [LabelStatsKeys .IMAGE_INTST ] = [
503+ self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_m ) for nda_m in nda_masks
504+ ]
531505
532- report = deepcopy (self .get_report_format ())
533- report [LabelStatsKeys .LABEL_UID ] = unique_label
534- report [LabelStatsKeys .IMAGE_INTST ] = [
535- self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_f ) for nda_f in nda_foregrounds
536- ]
537- report [LabelStatsKeys .LABEL ] = label_substats
506+ pixel_count = sum (mask_index )
507+ pixel_arr .append (pixel_count )
508+ pixel_sum += pixel_count
509+ if self .do_ccp : # apply connected component
510+ if using_cuda :
511+ # The back end of get_label_ccp is CuPy
512+ # which is unable to automatically release CUDA GPU memory held by PyTorch
513+ del nda_masks
514+ torch .cuda .empty_cache ()
515+ shape_list , ncomponents = get_label_ccp (mask_index )
516+ label_dict [LabelStatsKeys .LABEL_SHAPE ] = shape_list
517+ label_dict [LabelStatsKeys .LABEL_NCOMP ] = ncomponents
518+
519+ label_substats .append (label_dict )
520+ logger .debug (f" label { index } stats takes { time .time () - start_label } " )
521+
522+ for i , _ in enumerate (unique_label ):
523+ label_substats [i ].update ({LabelStatsKeys .PIXEL_PCT : float (pixel_arr [i ] / pixel_sum )})
524+
525+ report = deepcopy (self .get_report_format ())
526+ report [LabelStatsKeys .LABEL_UID ] = unique_label
527+ report [LabelStatsKeys .IMAGE_INTST ] = [
528+ self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_f ) for nda_f in nda_foregrounds
529+ ]
530+ report [LabelStatsKeys .LABEL ] = label_substats
538531
539- if not verify_report_format (report , self .get_report_format ()):
540- raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
532+ if not verify_report_format (report , self .get_report_format ()):
533+ raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
541534
542- d [self .stats_name ] = report # type: ignore[assignment]
535+ d [self .stats_name ] = report # type: ignore[assignment]
543536
544- torch .set_grad_enabled (restore_grad_state )
545537 logger .debug (f"Get label stats spent { time .time () - start } " )
546538 return d # type: ignore[return-value]
547539
0 commit comments