@@ -82,10 +82,6 @@ def _init_trace_threadlocal(self):
8282 if not hasattr (self ._tracing , "value" ):
8383 self ._tracing .value = MONAIEnvVars .trace_transform () != "0"
8484
85- # Initialize group identifier (set by Compose for automatic group tracking)
86- if not hasattr (self , "_group" ):
87- self ._group : str | None = None
88-
8985 def __getstate__ (self ):
9086 """When pickling, remove the `_tracing` member from the output, if present, since it's not picklable."""
9187 _dict = dict (getattr (self , "__dict__" , {})) # this makes __dict__ always present in the unpickled object
@@ -123,7 +119,6 @@ def get_transform_info(self) -> dict:
123119 """
124120 Return a dictionary with the relevant information pertaining to an applied transform.
125121 """
126- # Ensure _group is initialized
127122 self ._init_trace_threadlocal ()
128123
129124 vals = (
@@ -132,13 +127,7 @@ def get_transform_info(self) -> dict:
132127 self .tracing ,
133128 self ._do_transform if hasattr (self , "_do_transform" ) else True ,
134129 )
135- info = dict (zip (self .transform_info_keys (), vals ))
136-
137- # Add group if set (automatically set by Compose)
138- if self ._group is not None :
139- info [TraceKeys .GROUP ] = self ._group
140-
141- return info
130+ return dict (zip (self .transform_info_keys (), vals ))
142131
143132 def push_transform (self , data , * args , ** kwargs ):
144133 """
@@ -314,24 +303,33 @@ def track_transform_meta(
314303
315304 def check_transforms_match (self , transform : Mapping ) -> None :
316305 """Check transforms are of same instance."""
317- xform_id = transform .get (TraceKeys .ID , "" )
318- if xform_id == id (self ):
319- return
320- # TraceKeys.NONE to skip the id check
321- if xform_id == TraceKeys .NONE :
306+ if self ._transforms_match (transform ):
322307 return
308+
309+ xform_id = transform .get (TraceKeys .ID , "" )
323310 xform_name = transform .get (TraceKeys .CLASS_NAME , "" )
324311 warning_msg = transform .get (TraceKeys .EXTRA_INFO , {}).get ("warn" )
325312 if warning_msg :
326313 warnings .warn (warning_msg )
327- # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
328- if torch .multiprocessing .get_start_method () in ("spawn" , None ) and xform_name == self .__class__ .__name__ :
329- return
330314 raise RuntimeError (
331315 f"Error { self .__class__ .__name__ } getting the most recently "
332316 f"applied invertible transform { xform_name } { xform_id } != { id (self )} ."
333317 )
334318
319+ def _transforms_match (self , transform : Mapping ) -> bool :
320+ """Return whether a traced transform entry matches this transform instance."""
321+ xform_id = transform .get (TraceKeys .ID , "" )
322+ if xform_id == id (self ):
323+ return True
324+ # TraceKeys.NONE to skip the id check
325+ if xform_id == TraceKeys .NONE :
326+ return True
327+ xform_name = transform .get (TraceKeys .CLASS_NAME , "" )
328+ # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
329+ if torch .multiprocessing .get_start_method () in ("spawn" , None ) and xform_name == self .__class__ .__name__ :
330+ return True
331+ return False
332+
335333 def get_most_recent_transform (self , data , key : Hashable = None , check : bool = True , pop : bool = False ):
336334 """
337335 Get most recent matching transform for the current class from the sequence of applied operations.
@@ -363,10 +361,16 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr
363361 if not all_transforms :
364362 raise ValueError (f"Item of type { type (data )} (key: { key } , pop: { pop } ) has empty 'applied_operations'" )
365363
364+ match_idx = len (all_transforms ) - 1
366365 if check :
367- self .check_transforms_match (all_transforms [- 1 ])
366+ for idx in range (len (all_transforms ) - 1 , - 1 , - 1 ):
367+ if self ._transforms_match (all_transforms [idx ]):
368+ match_idx = idx
369+ break
370+ else :
371+ self .check_transforms_match (all_transforms [- 1 ])
368372
369- return all_transforms .pop (- 1 ) if pop else all_transforms [- 1 ]
373+ return all_transforms .pop (match_idx ) if pop else all_transforms [match_idx ]
370374
371375 def pop_transform (self , data , key : Hashable = None , check : bool = True ):
372376 """
0 commit comments