2020from monai .apps .reconstruction .transforms .array import EquispacedKspaceMask , RandomKspaceMask
2121from monai .config import DtypeLike , KeysCollection
2222from monai .config .type_definitions import NdarrayOrTensor
23+ from monai .data .meta_tensor import MetaTensor
2324from monai .transforms import InvertibleTransform
2425from monai .transforms .croppad .array import SpatialCrop
2526from monai .transforms .intensity .array import NormalizeIntensity
@@ -33,15 +34,36 @@ class ExtractDataKeyFromMetaKeyd(MapTransform):
3334 Moves keys from meta to data. It is useful when a dataset of paired samples
3435 is loaded and certain keys should be moved from meta to data.
3536
37+ This transform supports two modes:
38+
39+ 1. When ``meta_key`` references a metadata dictionary in the data (e.g., when
40+ ``image_only=False`` was used with ``LoadImaged``), the requested keys are
41+ extracted directly from that dictionary.
42+
43+ 2. When ``meta_key`` references a ``MetaTensor`` in the data (e.g., when
44+ ``image_only=True`` was used with ``LoadImaged``), the requested keys are
45+ extracted from its ``.meta`` attribute.
46+
3647 Args:
3748 keys: keys to be transferred from meta to data
38- meta_key: the meta key where all the meta-data is stored
49+ meta_key: the key in the data dictionary where the metadata source is
50+ stored. This can be either a metadata dictionary or a ``MetaTensor``.
3951 allow_missing_keys: don't raise exception if key is missing
4052
4153 Example:
4254 When the fastMRI dataset is loaded, "kspace" is stored in the data dictionary,
4355 but the ground-truth image with the key "reconstruction_rss" is stored in the meta data.
4456 In this case, ExtractDataKeyFromMetaKeyd moves "reconstruction_rss" to data.
57+
58+ When ``LoadImaged`` is used with ``image_only=True`` (the default), the loaded
59+ data is a ``MetaTensor`` with metadata accessible via ``.meta``. In this case,
60+ set ``meta_key`` to the key of the ``MetaTensor`` itself::
61+
62+ li = LoadImaged(keys="image") # image_only=True by default
63+ dat = li({"image": "image.nii"})
64+ e = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image")
65+ dat = e(dat)
66+ assert dat["image"].meta["filename_or_obj"] == dat["filename_or_obj"]
4567 """
4668
4769 def __init__ (self , keys : KeysCollection , meta_key : str , allow_missing_keys : bool = False ) -> None :
@@ -58,9 +80,18 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T
5880 the new data dictionary
5981 """
6082 d = dict (data )
83+ meta_obj = d [self .meta_key ]
84+
85+ # If meta_key references a MetaTensor, extract from its .meta attribute;
86+ # otherwise treat it as a metadata dictionary directly.
87+ if isinstance (meta_obj , MetaTensor ):
88+ meta_dict : dict = meta_obj .meta
89+ else :
90+ meta_dict = dict (meta_obj )
91+
6192 for key in self .keys :
62- if key in d [ self . meta_key ] :
63- d [key ] = d [ self . meta_key ] [key ] # type: ignore
93+ if key in meta_dict :
94+ d [key ] = meta_dict [key ] # type: ignore
6495 elif not self .allow_missing_keys :
6596 raise KeyError (
6697 f"Key `{ key } ` of transform `{ self .__class__ .__name__ } ` was missing in the meta data"
0 commit comments