# encoding: utf-8
__author__ = "Dimitrios Karkalousos, Chaoping Zhang"
import logging
import os
import random
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
import h5py
import numpy as np
import yaml # type: ignore
from torch.utils.data import Dataset
from mridc.collections.common.parts.utils import is_none
[docs]class qMRISliceDataset(Dataset):
"""A dataset that loads slices from a single dataset."""
def __init__(
self,
root: Union[str, Path, os.PathLike],
transform: Optional[Callable] = None,
sense_root: Union[str, Path, os.PathLike] = None,
sequence: str = None,
use_dataset_cache: bool = False,
sample_rate: Optional[float] = None,
volume_sample_rate: Optional[float] = None,
dataset_cache_file: Union[str, Path, os.PathLike] = "dataset_cache.yaml",
num_cols: Optional[Tuple[int]] = None,
mask_root: Union[str, Path, os.PathLike] = None,
consecutive_slices: int = 1,
data_saved_per_slice: bool = False,
init_coil_dim: int = 0,
fixed_precomputed_acceleration: Optional[int] = None,
kspace_scaling_factor: float = 10000,
):
"""
Parameters
----------
root: Path to the dataset.
transform: Optional; A sequence of callable objects that preprocesses the raw data into appropriate form.
The transform function should take 'kspace', 'target', 'attributes', 'filename', and 'slice' as inputs.
'target' may be null for test data.
sense_root: Path to the coil sensitivities maps dataset.
sequence: Sequence of the dataset.
use_dataset_cache: Whether to cache dataset metadata. This is very useful for large datasets like the brain
data.
sample_rate: Optional; A sequence of floats between 0 and 1. This controls what fraction of the slices
should be loaded. When creating subsampled datasets either set sample_rates (sample by slices) or
volume_sample_rates (sample by volumes) but not both.
volume_sample_rate: Optional; A sequence of floats between 0 and 1. This controls what fraction of the
volumes should be loaded. When creating subsampled datasets either set sample_rates (sample by slices)
or volume_sample_rates (sample by volumes) but not both.
dataset_cache_file: Optional; A file in which to cache dataset information for faster load times.
num_cols: Optional; If provided, only slices with the desired number of columns will be considered.
mask_root: Path to stored masks.
consecutive_slices: An int (>0) that determine the amount of consecutive slices of the file to be loaded at
the same time. Defaults to 1, loading single slices.
data_saved_per_slice: Whether the data is saved per slice or per volume.
init_coil_dim: The initial coil dimension of the data.
fixed_precomputed_acceleration: Optional; A list of integers that determine the fixed acceleration of the
data. If provided, the data will be loaded with the fixed acceleration.
kspace_scaling_factor: A float that determines the scaling factor of the k-space data.
"""
if sequence not in ("MEGRE", "FUTURE_SEQUENCES"):
raise ValueError(f'Sequence should be either "MEGRE" or "FUTURE_SEQUENCES". Found {sequence}.')
if sample_rate is not None and volume_sample_rate is not None:
raise ValueError(
"either set sample_rate (sample by slices) or volume_sample_rate (sample by volumes) but not both"
)
self.sense_root = sense_root
self.mask_root = mask_root
self.dataset_cache_file = Path(dataset_cache_file)
self.data_saved_per_slice = data_saved_per_slice
self.init_coil_dim = init_coil_dim
self.fixed_precomputed_acceleration = fixed_precomputed_acceleration
self.kspace_scaling_factor = kspace_scaling_factor
self.transform = transform
self.recons_key = "reconstruction"
self.examples = []
# set default sampling mode if none given
if sample_rate is None:
sample_rate = 1.0
if volume_sample_rate is None:
volume_sample_rate = 1.0
# load dataset cache if we have and user wants to use it
if self.dataset_cache_file.exists() and use_dataset_cache:
with open(self.dataset_cache_file, "rb") as f:
dataset_cache = yaml.safe_load(f)
else:
dataset_cache = {}
# check if our dataset is in the cache
# if there, use that metadata, if not, then regenerate the metadata
if dataset_cache.get(root) is None or not use_dataset_cache:
files = list(Path(root).iterdir())
for fname in sorted(files):
metadata, num_slices = self._retrieve_metadata(fname, data_saved_per_slice=self.data_saved_per_slice)
if not is_none(num_slices) and not is_none(consecutive_slices):
num_slices = num_slices - (consecutive_slices - 1)
self.examples += [(fname, slice_ind, metadata) for slice_ind in range(num_slices)]
if dataset_cache.get(root) is None and use_dataset_cache:
dataset_cache[root] = self.examples
logging.info(f"Saving dataset cache to {self.dataset_cache_file}.")
with open(self.dataset_cache_file, "wb") as f: # type: ignore
yaml.dump(dataset_cache, f) # type: ignore
else:
logging.info(f"Using dataset cache from {self.dataset_cache_file}.")
self.examples = dataset_cache[root]
# subsample if desired
if sample_rate < 1.0: # sample by slice
random.shuffle(self.examples)
num_examples = round(len(self.examples) * sample_rate)
self.examples = self.examples[:num_examples]
elif volume_sample_rate < 1.0: # sample by volume
vol_names = sorted(list({f[0].stem for f in self.examples}))
random.shuffle(vol_names)
num_volumes = round(len(vol_names) * volume_sample_rate)
sampled_vols = vol_names[:num_volumes]
self.examples = [example for example in self.examples if example[0].stem in sampled_vols]
if num_cols:
self.examples = [ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols] # type: ignore
# Create random number generator used for consecutive slice selection and set consecutive slice amount
self.consecutive_slices = consecutive_slices
if self.consecutive_slices < 1:
raise ValueError("consecutive_slices value is out of range, must be > 0.")
@staticmethod
def _retrieve_metadata(fname, data_saved_per_slice=False):
"""
Retrieve metadata from a given file.
Parameters
----------
fname: Path to file.
data_saved_per_slice: Whether the data is saved per slice or per volume.
Returns
-------
A dictionary containing the metadata.
"""
with h5py.File(fname, "r") as hf:
padding_left = 0
padding_right = 0
enc_size = 0
recon_size = (0, 0)
if "kspace" in hf:
shape = hf["kspace"].shape
elif "ksp" in hf:
shape = hf["ksp"].shape
elif "reconstruction" in hf:
shape = hf["reconstruction"].shape
else:
raise ValueError(f"{fname} does not contain kspace or reconstruction data.")
num_slices = 1 if data_saved_per_slice else shape[0]
metadata = {
"padding_left": padding_left,
"padding_right": padding_right,
"encoding_size": enc_size,
"recon_size": recon_size,
}
return metadata, num_slices
[docs] def get_consecutive_slices(self, data, key, dataslice):
"""
Get consecutive slices from a given data.
Args:
data: Data to extract slices from.
key: Key to extract slices from.
dataslice: Slice to extract slices from.
Returns:
A list of consecutive slices.
"""
data = data[key]
if self.data_saved_per_slice:
data = np.expand_dims(data, axis=0)
if self.consecutive_slices == 1:
if data.shape[0] == 1:
return data[0]
elif data.ndim != 2:
return data[dataslice]
return data
num_slices = data.shape[0]
if self.consecutive_slices > num_slices:
return np.stack(data, axis=0)
start_slice = dataslice
if dataslice + self.consecutive_slices <= num_slices:
end_slice = dataslice + self.consecutive_slices
else:
end_slice = num_slices
return data[start_slice:end_slice]
[docs] def check_stored_qdata(self, data, key, dataslice):
"""
Check if quantitative data are stored in the dataset.
Parameters
----------
data: Data to extract slices from.
key: Key to extract slices from.
dataslice: Slice to extract.
"""
qdata = []
count = 0
for k in data.keys():
if key in k:
acc = k.split("_")[-1].split("x")[0]
if acc not in ["brain", "head"]:
x = self.get_consecutive_slices(data, key + str(acc) + "x", dataslice)
if x.ndim == 3:
x = x[dataslice]
if (
self.fixed_precomputed_acceleration is not None
and int(acc) == self.fixed_precomputed_acceleration
or self.fixed_precomputed_acceleration is None
):
qdata.append(x)
else:
count += 1
if self.fixed_precomputed_acceleration is not None:
qdata = [qdata[0] * count]
return qdata
def __len__(self):
return len(self.examples)
def __getitem__(self, i: int):
fname, dataslice, metadata = self.examples[i]
with h5py.File(fname, "r") as hf:
if "kspace" in hf:
kspace = self.get_consecutive_slices(hf, "kspace", dataslice).astype(np.complex64)
elif "ksp" in hf:
kspace = self.get_consecutive_slices(hf, "ksp", dataslice).astype(np.complex64)
else:
raise ValueError("No kspace data found in file. Only 'kspace' or 'ksp' keys are supported.")
if self.init_coil_dim in [3, 4, -1]:
kspace = np.transpose(kspace, (0, 3, 1, 2)) # [nr_TEs, nr_channels, nr_rows, nr_cols]
kspace = kspace / self.kspace_scaling_factor
if "sensitivity_map" in hf:
sensitivity_map = self.get_consecutive_slices(hf, "sensitivity_map", dataslice).astype(np.complex64)
elif "sense" in hf:
sensitivity_map = self.get_consecutive_slices(hf, "sense", dataslice).astype(np.complex64)
elif self.sense_root is not None and self.sense_root != "None":
with h5py.File(Path(self.sense_root) / Path(str(fname).split("/")[-2]) / fname.name, "r") as sf:
if "sensitivity_map" in sf or "sensitivity_map" in next(iter(sf.keys())):
sensitivity_map = self.get_consecutive_slices(sf, "sensitivity_map", dataslice)
else:
sensitivity_map = self.get_consecutive_slices(sf, "sense", dataslice)
sensitivity_map = sensitivity_map.squeeze().astype(np.complex64)
else:
sensitivity_map = np.array([])
if self.init_coil_dim in [3, 4, -1]:
sensitivity_map = np.transpose(sensitivity_map, (2, 0, 1)) # [nr_channels, nr_rows, nr_cols]
if "mask" in hf:
mask = np.asarray(self.get_consecutive_slices(hf, "mask", dataslice))
if mask.ndim == 3:
mask = mask[dataslice]
elif any("mask_" in _ for _ in hf.keys()):
mask = self.check_stored_qdata(hf, "mask_", dataslice)
elif self.mask_root is not None and self.mask_root != "None":
with h5py.File(Path(self.mask_root) / fname.name, "r") as mf:
mask = np.asarray(self.get_consecutive_slices(mf, "mask", dataslice))
else:
mask = np.empty([])
if "mask_brain" in hf:
mask_brain = np.asarray(self.get_consecutive_slices(hf, "mask_brain", dataslice))
else:
mask_brain = np.empty([])
if "mask_head" in hf.keys():
mask_head = np.asarray(self.get_consecutive_slices(hf, "mask_head", dataslice))
else:
mask_head = np.empty([])
mask = [mask, mask_brain, mask_head]
if any("B0_map_init_" in _ for _ in hf.keys()):
B0_map = self.check_stored_qdata(hf, "B0_map_init_", dataslice)
if all("B0_map_target" not in _ for _ in hf.keys()):
raise ValueError("While B0 map initializations are found, no B0 map target found in file.")
B0_map_target = self.get_consecutive_slices(hf, "B0_map_target", dataslice)
B0_map.append(B0_map_target)
else:
B0_map = np.empty([])
if any("S0_map_init_" in _ for _ in hf.keys()):
S0_map = self.check_stored_qdata(hf, "S0_map_init_", dataslice)
if all("S0_map_target" not in _ for _ in hf.keys()):
raise ValueError("While S0 map initializations are found, no S0 map target found in file.")
S0_map_target = self.get_consecutive_slices(hf, "S0_map_target", dataslice)
S0_map.append(S0_map_target)
else:
S0_map = np.empty([])
if any("R2star_map_init_" in _ for _ in hf.keys()):
R2star_map = self.check_stored_qdata(hf, "R2star_map_init_", dataslice)
if all("R2star_map_target" not in _ for _ in hf.keys()):
raise ValueError("While R2star map initializations are found, no R2star map target found in file.")
R2star_map_target = self.get_consecutive_slices(hf, "R2star_map_target", dataslice)
R2star_map.append(R2star_map_target)
else:
R2star_map = np.empty([])
if any("phi_map_init_" in _ for _ in hf.keys()):
phi_map = self.check_stored_qdata(hf, "phi_map_init_", dataslice)
if all("phi_map_target" not in _ for _ in hf.keys()):
raise ValueError("While phi map initializations are found, no phi map target found in file.")
phi_map_target = self.get_consecutive_slices(hf, "phi_map_target", dataslice)
phi_map.append(phi_map_target)
else:
phi_map = np.empty([])
qmaps = [B0_map, S0_map, R2star_map, phi_map]
eta = (
self.get_consecutive_slices(hf, "eta", dataslice).astype(np.complex64) if "eta" in hf else np.array([])
)
if "reconstruction_sense" in hf:
self.recons_key = "reconstruction_sense"
target = self.get_consecutive_slices(hf, self.recons_key, dataslice) if self.recons_key in hf else None
attrs = dict(hf.attrs)
attrs.update(metadata)
if self.data_saved_per_slice:
# arbitrary slice number for logging purposes
fname = fname.name # type: ignore
dataslice = int(fname.split("_")[-1]) # type: ignore
fname = "_".join(fname.split("_")[:-1]) # type: ignore
return (
(
kspace,
sensitivity_map,
qmaps,
mask,
eta,
target,
attrs,
fname,
dataslice,
)
if self.transform is None
else self.transform(
kspace,
sensitivity_map,
qmaps,
mask,
eta,
target,
attrs,
fname,
dataslice,
)
)