Skip to content

Commit 91b69ed

Browse files
Update monai/transforms/post/dictionary.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: sewon jeon <irocks0922@gmail.com>
1 parent 8e34cc0 commit 91b69ed

1 file changed

Lines changed: 13 additions & 14 deletions

File tree

monai/transforms/post/dictionary.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -693,28 +693,27 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
693693
# If orig_key == key, the data at d[orig_key] may have been modified by
694694
# postprocessing transforms. We need to exclude any transforms that were
695695
# added after the preprocessing pipeline completed.
696+
# When orig_key == key, filter out postprocessing transforms to prevent
697+
# confusion during inversion (see issue #8396)
696698
if orig_key == key:
697-
# Count how many invertible transforms are in the preprocessing pipeline
698-
# This gives us the expected number of transforms that should be in
699-
# applied_operations from preprocessing
700699
num_preproc_transforms = 0
701-
if hasattr(self.transform, 'transforms'):
702-
for t in self.transform.flatten().transforms:
703-
if isinstance(t, InvertibleTransform):
704-
num_preproc_transforms += 1
705-
elif isinstance(self.transform, InvertibleTransform):
706-
num_preproc_transforms = 1
707-
708-
# Use only the first N transforms from applied_operations,
709-
# where N is the number of preprocessing transforms
710-
# This excludes any postprocessing transforms like Lambdad
700+
try:
701+
if hasattr(self.transform, 'transforms'):
702+
for t in self.transform.flatten().transforms:
703+
if isinstance(t, InvertibleTransform):
704+
num_preproc_transforms += 1
705+
elif isinstance(self.transform, InvertibleTransform):
706+
num_preproc_transforms = 1
707+
except AttributeError:
708+
# Fallback: use all transforms if flatten fails
709+
num_preproc_transforms = len(all_transforms)
710+
711711
if num_preproc_transforms > 0:
712712
transform_info = all_transforms[:num_preproc_transforms]
713713
else:
714714
transform_info = all_transforms
715715
else:
716716
transform_info = all_transforms
717-
else:
718717
transform_info = d[InvertibleTransform.trace_key(orig_key)]
719718
meta_info = d.get(orig_meta_key, {})
720719
if nearest_interp:

0 commit comments

Comments
 (0)