@@ -693,28 +693,27 @@ def __call__(self, data: Mapping[Hashable, Any]) -> dict[Hashable, Any]:
693693 # If orig_key == key, the data at d[orig_key] may have been modified by
694694 # postprocessing transforms. We need to exclude any transforms that were
695695 # added after the preprocessing pipeline completed.
696+ # When orig_key == key, filter out postprocessing transforms to prevent
697+ # confusion during inversion (see issue #8396)
696698 if orig_key == key:
697- # Count how many invertible transforms are in the preprocessing pipeline
698- # This gives us the expected number of transforms that should be in
699- # applied_operations from preprocessing
700699 num_preproc_transforms = 0
701- if hasattr(self.transform, 'transforms'):
702- for t in self.transform.flatten().transforms:
703- if isinstance(t, InvertibleTransform):
704- num_preproc_transforms += 1
705- elif isinstance(self.transform, InvertibleTransform):
706- num_preproc_transforms = 1
707-
708- # Use only the first N transforms from applied_operations,
709- # where N is the number of preprocessing transforms
710- # This excludes any postprocessing transforms like Lambdad
700+ try:
701+ if hasattr(self.transform, 'transforms'):
702+ for t in self.transform.flatten().transforms:
703+ if isinstance(t, InvertibleTransform):
704+ num_preproc_transforms += 1
705+ elif isinstance(self.transform, InvertibleTransform):
706+ num_preproc_transforms = 1
707+ except AttributeError:
708+ # Fallback: use all transforms if flatten fails
709+ num_preproc_transforms = len(all_transforms)
710+
711711 if num_preproc_transforms > 0:
712712 transform_info = all_transforms[:num_preproc_transforms]
713713 else:
714714 transform_info = all_transforms
715715 else:
716716 transform_info = all_transforms
717- else:
718717 transform_info = d[InvertibleTransform.trace_key(orig_key)]
719718 meta_info = d.get(orig_meta_key, {})
720719 if nearest_interp: