Skip to content

Commit c42261b

Browse files
committed
Fix inverse matching with mixed postprocessing history
Signed-off-by: sewon.jeon <irocks0922@gmail.com>
1 parent 226deb6 commit c42261b

6 files changed

Lines changed: 69 additions & 231 deletions

File tree

monai/transforms/compose.py

Lines changed: 0 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -255,90 +255,6 @@ def __init__(
255255
self.set_random_state(seed=get_seed())
256256
self.overrides = overrides
257257

258-
# Automatically assign group ID to child transforms for inversion tracking
259-
self._set_transform_groups()
260-
261-
def _set_transform_groups(self):
262-
"""
263-
Automatically set group IDs on child transforms for inversion tracking.
264-
265-
This allows Invertd to identify which transforms belong to this
266-
``Compose`` instance, including wrapped transforms (for example,
267-
array transforms inside dictionary transforms).
268-
269-
Args:
270-
None.
271-
272-
Returns:
273-
None.
274-
"""
275-
from monai.transforms.inverse import TraceableTransform
276-
277-
group_id = str(id(self))
278-
visited = set() # Track visited objects to avoid infinite recursion
279-
280-
def set_group_recursive(obj, gid, allow_compose: bool = False):
281-
"""
282-
Recursively set a group ID on a transform and its wrapped transforms.
283-
284-
Args:
285-
obj: Transform instance to process.
286-
gid: Group identifier to assign.
287-
allow_compose: Whether to set group on ``Compose`` instances.
288-
``Compose`` internals are not traversed to preserve nested
289-
pipeline boundaries.
290-
291-
Returns:
292-
None.
293-
"""
294-
if obj is None or isinstance(obj, (bool, int, float, str, bytes)):
295-
return
296-
297-
# Avoid infinite recursion
298-
obj_id = id(obj)
299-
if obj_id in visited:
300-
return
301-
visited.add(obj_id)
302-
303-
if isinstance(obj, Compose):
304-
if allow_compose:
305-
obj._group = gid
306-
return
307-
308-
if isinstance(obj, TraceableTransform):
309-
obj._group = gid
310-
311-
if isinstance(obj, Mapping):
312-
for attr in obj.values():
313-
set_group_recursive(attr, gid)
314-
return
315-
316-
if isinstance(obj, (list, tuple, set)):
317-
for attr in obj:
318-
set_group_recursive(attr, gid)
319-
return
320-
321-
attrs: list[Any] = []
322-
if hasattr(obj, "__dict__"):
323-
attrs.extend(vars(obj).values())
324-
325-
slots = getattr(type(obj), "__slots__", ())
326-
if isinstance(slots, str):
327-
slots = (slots,)
328-
for slot in slots:
329-
if slot.startswith("__"):
330-
continue
331-
try:
332-
attrs.append(getattr(obj, slot))
333-
except AttributeError:
334-
continue
335-
336-
for attr in attrs:
337-
set_group_recursive(attr, gid)
338-
339-
for transform in self.transforms:
340-
set_group_recursive(transform, group_id, allow_compose=True)
341-
342258
@LazyTransform.lazy.setter # type: ignore
343259
def lazy(self, val: bool):
344260
self._lazy = val

monai/transforms/inverse.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,6 @@ def _init_trace_threadlocal(self):
8282
if not hasattr(self._tracing, "value"):
8383
self._tracing.value = MONAIEnvVars.trace_transform() != "0"
8484

85-
# Initialize group identifier (set by Compose for automatic group tracking)
86-
if not hasattr(self, "_group"):
87-
self._group: str | None = None
88-
8985
def __getstate__(self):
9086
"""When pickling, remove the `_tracing` member from the output, if present, since it's not picklable."""
9187
_dict = dict(getattr(self, "__dict__", {})) # this makes __dict__ always present in the unpickled object
@@ -123,7 +119,6 @@ def get_transform_info(self) -> dict:
123119
"""
124120
Return a dictionary with the relevant information pertaining to an applied transform.
125121
"""
126-
# Ensure _group is initialized
127122
self._init_trace_threadlocal()
128123

129124
vals = (
@@ -132,13 +127,7 @@ def get_transform_info(self) -> dict:
132127
self.tracing,
133128
self._do_transform if hasattr(self, "_do_transform") else True,
134129
)
135-
info = dict(zip(self.transform_info_keys(), vals))
136-
137-
# Add group if set (automatically set by Compose)
138-
if self._group is not None:
139-
info[TraceKeys.GROUP] = self._group
140-
141-
return info
130+
return dict(zip(self.transform_info_keys(), vals))
142131

143132
def push_transform(self, data, *args, **kwargs):
144133
"""
@@ -314,24 +303,33 @@ def track_transform_meta(
314303

315304
def check_transforms_match(self, transform: Mapping) -> None:
316305
"""Check transforms are of same instance."""
317-
xform_id = transform.get(TraceKeys.ID, "")
318-
if xform_id == id(self):
319-
return
320-
# TraceKeys.NONE to skip the id check
321-
if xform_id == TraceKeys.NONE:
306+
if self._transforms_match(transform):
322307
return
308+
309+
xform_id = transform.get(TraceKeys.ID, "")
323310
xform_name = transform.get(TraceKeys.CLASS_NAME, "")
324311
warning_msg = transform.get(TraceKeys.EXTRA_INFO, {}).get("warn")
325312
if warning_msg:
326313
warnings.warn(warning_msg)
327-
# basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
328-
if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__:
329-
return
330314
raise RuntimeError(
331315
f"Error {self.__class__.__name__} getting the most recently "
332316
f"applied invertible transform {xform_name} {xform_id} != {id(self)}."
333317
)
334318

319+
def _transforms_match(self, transform: Mapping) -> bool:
320+
"""Return whether a traced transform entry matches this transform instance."""
321+
xform_id = transform.get(TraceKeys.ID, "")
322+
if xform_id == id(self):
323+
return True
324+
# TraceKeys.NONE to skip the id check
325+
if xform_id == TraceKeys.NONE:
326+
return True
327+
xform_name = transform.get(TraceKeys.CLASS_NAME, "")
328+
# basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID)
329+
if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__:
330+
return True
331+
return False
332+
335333
def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False):
336334
"""
337335
Get most recent matching transform for the current class from the sequence of applied operations.
@@ -363,10 +361,16 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr
363361
if not all_transforms:
364362
raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'")
365363

364+
match_idx = len(all_transforms) - 1
366365
if check:
367-
self.check_transforms_match(all_transforms[-1])
366+
for idx in range(len(all_transforms) - 1, -1, -1):
367+
if self._transforms_match(all_transforms[idx]):
368+
match_idx = idx
369+
break
370+
else:
371+
self.check_transforms_match(all_transforms[-1])
368372

369-
return all_transforms.pop(-1) if pop else all_transforms[-1]
373+
return all_transforms.pop(match_idx) if pop else all_transforms[match_idx]
370374

371375
def pop_transform(self, data, key: Hashable = None, check: bool = True):
372376
"""

monai/transforms/post/dictionary.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from monai.transforms.transform import MapTransform
4949
from monai.transforms.utility.array import ToTensor
5050
from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode
51-
from monai.utils import PostFix, TraceKeys, convert_to_tensor, ensure_tuple, ensure_tuple_rep
51+
from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep
5252
from monai.utils.type_conversion import convert_to_dst_type
5353

5454
__all__ = [
@@ -859,27 +859,6 @@ def __init__(
859859
self.post_func = ensure_tuple_rep(post_func, len(self.keys))
860860
self._totensor = ToTensor()
861861

862-
def _filter_transforms_by_group(self, all_transforms: list[dict]) -> list[dict]:
863-
"""Filter applied operations to only include transforms from the target pipeline.
864-
865-
Uses automatic group tracking where ``Compose`` assigns its ID to child transforms.
866-
867-
Args:
868-
all_transforms: Full list of applied transform metadata dictionaries.
869-
870-
Returns:
871-
Subset whose ``TraceKeys.GROUP`` matches ``str(id(self.transform))``, or the original
872-
list when no match is found for backward compatibility.
873-
"""
874-
# Get the group ID of the transform (Compose instance)
875-
target_group = str(id(self.transform))
876-
877-
# Filter transforms that match the target group
878-
filtered = [xform for xform in all_transforms if xform.get(TraceKeys.GROUP) == target_group]
879-
880-
# If no transforms match (backward compatibility), return all transforms
881-
return filtered if filtered else all_transforms
882-
883862
def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
884863
d = dict(data)
885864
for (
@@ -915,13 +894,10 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
915894

916895
orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}"
917896
if orig_key in d and isinstance(d[orig_key], MetaTensor):
918-
all_transforms = d[orig_key].applied_operations
897+
transform_info = d[orig_key].applied_operations
919898
meta_info = d[orig_key].meta
920-
921-
# Automatically filter by Compose instance group ID
922-
transform_info = self._filter_transforms_by_group(all_transforms)
923899
else:
924-
transform_info = self._filter_transforms_by_group(d[InvertibleTransform.trace_key(orig_key)])
900+
transform_info = d[InvertibleTransform.trace_key(orig_key)]
925901
meta_info = d.get(orig_meta_key, {})
926902
if nearest_interp:
927903
transform_info = convert_applied_interp_mode(

monai/utils/enums.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,6 @@ class TraceKeys(StrEnum):
334334
TRACING: str = "tracing"
335335
STATUSES: str = "statuses"
336336
LAZY: str = "lazy"
337-
GROUP: str = "group"
338337

339338

340339
class TraceStatusKeys(StrEnum):

tests/transforms/compose/test_compose.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -268,43 +268,6 @@ def test_data_loader_2(self):
268268
self.assertAlmostEqual(out_1.cpu().item(), 0.28602141572)
269269
set_determinism(None)
270270

271-
def test_set_transform_groups_on_wrapped_transform_attributes(self):
272-
class _IdentityInvertible(mt.InvertibleTransform):
273-
def __call__(self, data):
274-
return data
275-
276-
def inverse(self, data):
277-
return data
278-
279-
class _WrapperWithTransform:
280-
def __init__(self):
281-
self.transform = _IdentityInvertible()
282-
283-
def __call__(self, data):
284-
return self.transform(data)
285-
286-
class _WrapperWithTransforms:
287-
def __init__(self):
288-
self.transforms = [_IdentityInvertible(), {"inner": _IdentityInvertible()}]
289-
290-
def __call__(self, data):
291-
for transform in self.transforms:
292-
if isinstance(transform, dict):
293-
for nested_transform in transform.values():
294-
data = nested_transform(data)
295-
else:
296-
data = transform(data)
297-
return data
298-
299-
wrapped_transform = _WrapperWithTransform()
300-
wrapped_transforms = _WrapperWithTransforms()
301-
composed = mt.Compose([wrapped_transform, wrapped_transforms])
302-
expected_group = str(id(composed))
303-
304-
self.assertEqual(getattr(wrapped_transform.transform, "_group", None), expected_group)
305-
self.assertEqual(getattr(wrapped_transforms.transforms[0], "_group", None), expected_group)
306-
self.assertEqual(getattr(wrapped_transforms.transforms[1]["inner"], "_group", None), expected_group)
307-
308271
def test_flatten_and_len(self):
309272
x = mt.EnsureChannelFirst(channel_dim="no_channel")
310273
t1 = mt.Compose([x, x, x, x, mt.Compose([mt.Compose([x, x]), x, x])])

0 commit comments

Comments
 (0)