@@ -216,50 +216,58 @@ def __init__(self, image_key: str, stats_name: str = DataStatsKeys.IMAGE_STATS)
216216 super ().__init__ (stats_name , report_format )
217217 self .update_ops (ImageStatsKeys .INTENSITY , SampleOperations ())
218218
219+ @torch .no_grad ()
219220 def __call__ (self , data ):
220- # Input Validation Addition
221- if not isinstance (data , dict ):
222- raise TypeError (f"Input data must be a dict, but got { type (data ).__name__ } ." )
223- if self .image_key not in data :
224- raise KeyError (f"Key '{ self .image_key } ' not found in input data." )
225- image = data [self .image_key ]
226- if not isinstance (image , (np .ndarray , torch .Tensor , MetaTensor )):
227- raise TypeError (
228- f"Value for '{ self .image_key } ' must be a numpy array, torch.Tensor, or MetaTensor, "
229- f"but got { type (image ).__name__ } ."
230- )
231- if image .ndim < 3 :
232- raise ValueError (
233- f"Image data under '{ self .image_key } ' must have at least 3 dimensions, but got shape { image .shape } ."
234- )
235- # --- End of validation ---
236221 """
237- Callable to execute the pre-defined functions
222+ Callable to execute the pre-defined functions.
238223
239224 Returns:
240225 A dictionary. The dict has the key in self.report_format. The value of
241226 ImageStatsKeys.INTENSITY is in a list format. Each element of the value list
242227 has stats pre-defined by SampleOperations (max, min, ....).
243228
244229 Raises:
245- RuntimeError if the stats report generated is not consistent with the pre-
230+ KeyError: if ``self.image_key`` is not present in the input data.
231+ TypeError: if the input data is not a dictionary, or if the image value is
232+ not a numpy array, torch.Tensor, or MetaTensor.
233+ ValueError: if the image has fewer than 3 dimensions, or if pre-computed
234+ ``nda_croppeds`` is not a list/tuple with one entry per image channel.
235+ RuntimeError: if the stats report generated is not consistent with the pre-
246236 defined report_format.
247237
248238 Note:
249239 The stats operation uses numpy and torch to compute max, min, and other
250240 functions. If the input has nan/inf, the stats results will be nan/inf.
251241
252242 """
243+ if not isinstance (data , dict ):
244+ raise TypeError (f"Input data must be a dict, but got { type (data ).__name__ } ." )
245+ if self .image_key not in data :
246+ raise KeyError (f"Key '{ self .image_key } ' not found in input data." )
247+ image = data [self .image_key ]
248+ if not isinstance (image , (np .ndarray , torch .Tensor , MetaTensor )):
249+ raise TypeError (
250+ f"Value for '{ self .image_key } ' must be a numpy array, torch.Tensor, or MetaTensor, "
251+ f"but got { type (image ).__name__ } ."
252+ )
253+ if image .ndim < 3 :
254+ raise ValueError (
255+ f"Image data under '{ self .image_key } ' must have at least 3 dimensions, but got shape { image .shape } ."
256+ )
257+
253258 d = dict (data )
254259 start = time .time ()
255- restore_grad_state = torch .is_grad_enabled ()
256- torch .set_grad_enabled (False )
257-
258260 ndas = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])]
259- if "nda_croppeds" not in d :
261+ if "nda_croppeds" in d :
262+ nda_croppeds = d ["nda_croppeds" ]
263+ if not isinstance (nda_croppeds , (list , tuple )) or len (nda_croppeds ) != len (ndas ):
264+ raise ValueError (
265+ "Pre-computed 'nda_croppeds' must be a list or tuple with one entry per image channel "
266+ f"(expected { len (ndas )} )."
267+ )
268+ else :
260269 nda_croppeds = [get_foreground_image (nda ) for nda in ndas ]
261270
262- # perform calculation
263271 report = deepcopy (self .get_report_format ())
264272
265273 report [ImageStatsKeys .SHAPE ] = [list (nda .shape ) for nda in ndas ]
@@ -275,16 +283,13 @@ def __call__(self, data):
275283 a * b for a , b in zip (report [ImageStatsKeys .SHAPE ][0 ], report [ImageStatsKeys .SPACING ])
276284 ]
277285
278- report [ImageStatsKeys .INTENSITY ] = [
279- self .ops [ImageStatsKeys .INTENSITY ].evaluate (nda_c ) for nda_c in nda_croppeds
280- ]
286+ report [ImageStatsKeys .INTENSITY ] = [self .ops [ImageStatsKeys .INTENSITY ].evaluate (nda_c ) for nda_c in nda_croppeds ]
281287
282288 if not verify_report_format (report , self .get_report_format ()):
283289 raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
284290
285291 d [self .stats_name ] = report
286292
287- torch .set_grad_enabled (restore_grad_state )
288293 logger .debug (f"Get image stats spent { time .time () - start } " )
289294 return d
290295
@@ -321,6 +326,7 @@ def __init__(self, image_key: str, label_key: str, stats_name: str = DataStatsKe
321326 super ().__init__ (stats_name , report_format )
322327 self .update_ops (ImageStatsKeys .INTENSITY , SampleOperations ())
323328
329+ @torch .no_grad ()
324330 def __call__ (self , data : Mapping ) -> dict :
325331 """
326332 Callable to execute the pre-defined functions
@@ -341,9 +347,6 @@ def __call__(self, data: Mapping) -> dict:
341347
342348 d = dict (data )
343349 start = time .time ()
344- restore_grad_state = torch .is_grad_enabled ()
345- torch .set_grad_enabled (False )
346-
347350 ndas = [d [self .image_key ][i ] for i in range (d [self .image_key ].shape [0 ])]
348351 ndas_label = d [self .label_key ] # (H,W,D)
349352
@@ -353,19 +356,15 @@ def __call__(self, data: Mapping) -> dict:
353356 nda_foregrounds = [get_foreground_label (nda , ndas_label ) for nda in ndas ]
354357 nda_foregrounds = [nda if nda .numel () > 0 else MetaTensor ([0.0 ]) for nda in nda_foregrounds ]
355358
356- # perform calculation
357359 report = deepcopy (self .get_report_format ())
358360
359- report [ImageStatsKeys .INTENSITY ] = [
360- self .ops [ImageStatsKeys .INTENSITY ].evaluate (nda_f ) for nda_f in nda_foregrounds
361- ]
361+ report [ImageStatsKeys .INTENSITY ] = [self .ops [ImageStatsKeys .INTENSITY ].evaluate (nda_f ) for nda_f in nda_foregrounds ]
362362
363363 if not verify_report_format (report , self .get_report_format ()):
364364 raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
365365
366366 d [self .stats_name ] = report
367367
368- torch .set_grad_enabled (restore_grad_state )
369368 logger .debug (f"Get foreground image stats spent { time .time () - start } " )
370369 return d
371370
@@ -418,6 +417,7 @@ def __init__(
418417 id_seq = ID_SEP_KEY .join ([LabelStatsKeys .LABEL , "0" , LabelStatsKeys .IMAGE_INTST ])
419418 self .update_ops_nested_label (id_seq , SampleOperations ())
420419
420+ @torch .no_grad ()
421421 def __call__ (self , data : Mapping [Hashable , MetaTensor ]) -> dict [Hashable , MetaTensor | dict ]:
422422 """
423423 Callable to execute the pre-defined functions.
@@ -470,19 +470,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
470470 start = time .time ()
471471 image_tensor = d [self .image_key ]
472472 label_tensor = d [self .label_key ]
473- # Check if either tensor is on CUDA to determine if we should move both to CUDA for processing
474473 using_cuda = any (
475474 isinstance (t , (torch .Tensor , MetaTensor )) and t .device .type == "cuda" for t in (image_tensor , label_tensor )
476475 )
477- restore_grad_state = torch .is_grad_enabled ()
478- torch .set_grad_enabled (False )
479476
480- if isinstance (image_tensor , (MetaTensor , torch .Tensor )) and isinstance (
481- label_tensor , (MetaTensor , torch .Tensor )
482- ):
477+ if isinstance (image_tensor , (MetaTensor , torch .Tensor )) and isinstance (label_tensor , (MetaTensor , torch .Tensor )):
483478 if label_tensor .device != image_tensor .device :
484479 if using_cuda :
485- # Move both tensors to CUDA when mixing devices
486480 cuda_device = image_tensor .device if image_tensor .device .type == "cuda" else label_tensor .device
487481 image_tensor = cast (MetaTensor , image_tensor .to (cuda_device ))
488482 label_tensor = cast (MetaTensor , label_tensor .to (cuda_device ))
@@ -504,7 +498,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
504498
505499 unique_label = unique_label .astype (np .int16 ).tolist ()
506500
507- label_substats = [] # each element is one label
501+ label_substats = []
508502 pixel_sum = 0
509503 pixel_arr = []
510504 for index in unique_label :
@@ -513,17 +507,13 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
513507 mask_index = ndas_label == index
514508
515509 nda_masks = [nda [mask_index ] for nda in ndas ]
516- label_dict [LabelStatsKeys .IMAGE_INTST ] = [
517- self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_m ) for nda_m in nda_masks
518- ]
510+ label_dict [LabelStatsKeys .IMAGE_INTST ] = [self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_m ) for nda_m in nda_masks ]
519511
520512 pixel_count = sum (mask_index )
521513 pixel_arr .append (pixel_count )
522514 pixel_sum += pixel_count
523- if self .do_ccp : # apply connected component
515+ if self .do_ccp :
524516 if using_cuda :
525- # The back end of get_label_ccp is CuPy
526- # which is unable to automatically release CUDA GPU memory held by PyTorch
527517 del nda_masks
528518 torch .cuda .empty_cache ()
529519 shape_list , ncomponents = get_label_ccp (mask_index )
@@ -538,17 +528,14 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
538528
539529 report = deepcopy (self .get_report_format ())
540530 report [LabelStatsKeys .LABEL_UID ] = unique_label
541- report [LabelStatsKeys .IMAGE_INTST ] = [
542- self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_f ) for nda_f in nda_foregrounds
543- ]
531+ report [LabelStatsKeys .IMAGE_INTST ] = [self .ops [LabelStatsKeys .IMAGE_INTST ].evaluate (nda_f ) for nda_f in nda_foregrounds ]
544532 report [LabelStatsKeys .LABEL ] = label_substats
545533
546534 if not verify_report_format (report , self .get_report_format ()):
547535 raise RuntimeError (f"report generated by { self .__class__ } differs from the report format." )
548536
549537 d [self .stats_name ] = report # type: ignore[assignment]
550538
551- torch .set_grad_enabled (restore_grad_state )
552539 logger .debug (f"Get label stats spent { time .time () - start } " )
553540 return d # type: ignore[return-value]
554541
0 commit comments