Source code for mridc.collections.reconstruction.data.mri_data

# encoding: utf-8
__author__ = "Dimitrios Karkalousos"

# Parts of the code have been taken from https://github.com/facebookresearch/fastMRI

import logging
import os
import random
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import h5py
import numpy as np
import torch
import yaml  # type: ignore
from defusedxml.ElementTree import fromstring
from torch.utils.data import Dataset

from mridc.collections.common.parts.utils import is_none


[docs]def et_query(root: str, qlist: Sequence[str], namespace: str = "https://www.ismrm.org/ISMRMRD") -> str: """ Query an XML element for a list of attributes. Parameters ---------- root: The root element of the XML tree. qlist: A list of strings, each of which is an attribute name. namespace: The namespace of the XML tree. Returns ------- A string containing the value of the last attribute in the list. """ s = "." prefix = "ismrmrd_namespace" ns = {prefix: namespace} for el in qlist: s += f"//{prefix}:{el}" value = root.find(s, ns) # type: ignore if value is None: return "0" return str(value.text) # type: ignore
[docs]class FastMRICombinedSliceDataset(torch.utils.data.Dataset): """A dataset that combines multiple datasets.""" def __init__( self, roots: Sequence[Path], challenges: Sequence[str], sense_roots: Optional[Sequence[Path]] = None, transforms: Optional[Sequence[Optional[Callable]]] = None, sample_rates: Optional[Sequence[Optional[float]]] = None, volume_sample_rates: Optional[Sequence[Optional[float]]] = None, use_dataset_cache: bool = False, dataset_cache_file: Union[str, Path, os.PathLike] = "dataset_cache.yaml", num_cols: Optional[Tuple[int]] = None, ): """ Parameters ---------- roots: Paths to the datasets. challenges: "singlecoil" or "multicoil" depending on which challenge to use. sense_roots: Load pre-computed (stored) sensitivity maps. transforms: Optional; A sequence of callable objects that preprocesses the raw data into appropriate form. The transform function should take 'kspace', 'target', 'attributes', 'filename', and 'slice' as inputs. 'target' may be null for test data. sample_rates: Optional; A sequence of floats between 0 and 1. This controls what fraction of the slices should be loaded. When creating subsampled datasets either set sample_rates (sample by slices) or volume_sample_rates (sample by volumes) but not both. volume_sample_rates: Optional; A sequence of floats between 0 and 1. This controls what fraction of the volumes should be loaded. When creating subsampled datasets either set sample_rates (sample by slices) or volume_sample_rates (sample by volumes) but not both. use_dataset_cache: Whether to cache dataset metadata. This is very useful for large datasets like the brain data. dataset_cache_file: Optional; A file in which to cache dataset information for faster load times. num_cols: Optional; If provided, only slices with the desired number of columns will be considered. """ if sample_rates is not None and volume_sample_rates is not None: raise ValueError( "either set sample_rates (sample by slices) or volume_sample_rates (sample by volumes) but not both" ) if transforms is None: transforms = [None] * len(roots) if sample_rates is None: sample_rates = [None] * len(roots) if volume_sample_rates is None: volume_sample_rates = [None] * len(roots) if not len(roots) == len(transforms) == len(challenges) == len(sample_rates) == len(volume_sample_rates): raise ValueError("Lengths of roots, transforms, challenges, sample_rates do not match") self.datasets = [] self.examples: List[Tuple[Path, int, Dict[str, object]]] = [] for i, _ in enumerate(roots): self.datasets.append( FastMRISliceDataset( root=roots[i], transform=transforms[i], sense_root=sense_roots[i] if sense_roots is not None else None, challenge=challenges[i], sample_rate=sample_rates[i], volume_sample_rate=volume_sample_rates[i], use_dataset_cache=use_dataset_cache, dataset_cache_file=dataset_cache_file, num_cols=num_cols, ) ) self.examples += self.datasets[-1].examples def __len__(self): return sum(len(dataset) for dataset in self.datasets) def __getitem__(self, i): for dataset in self.datasets: if i < len(dataset): return dataset[i] i = i - len(dataset)
[docs]class FastMRISliceDataset(Dataset): """A dataset that loads slices from a single dataset.""" def __init__( self, root: Union[str, Path, os.PathLike], challenge: str = "segmentation", transform: Optional[Callable] = None, sense_root: Union[str, Path, os.PathLike] = None, use_dataset_cache: bool = False, sample_rate: Optional[float] = None, volume_sample_rate: Optional[float] = None, dataset_cache_file: Union[str, Path, os.PathLike] = "dataset_cache.yaml", num_cols: Optional[Tuple[int]] = None, mask_root: Union[str, Path, os.PathLike] = None, consecutive_slices: int = 1, ): """ Parameters ---------- root: Path to the dataset. challenge: "singlecoil" or "multicoil" depending on which challenge to use. transform: Optional; A sequence of callable objects that preprocesses the raw data into appropriate form. The transform function should take 'kspace', 'target', 'attributes', 'filename', and 'slice' as inputs. 'target' may be null for test data. sense_root: Path to the coil sensitivities maps dataset. use_dataset_cache: Whether to cache dataset metadata. This is very useful for large datasets like the brain data. sample_rate: Optional; A sequence of floats between 0 and 1. This controls what fraction of the slices should be loaded. When creating subsampled datasets either set sample_rates (sample by slices) or volume_sample_rates (sample by volumes) but not both. volume_sample_rate: Optional; A sequence of floats between 0 and 1. This controls what fraction of the volumes should be loaded. When creating subsampled datasets either set sample_rates (sample by slices) or volume_sample_rates (sample by volumes) but not both. dataset_cache_file: Optional; A file in which to cache dataset information for faster load times. num_cols: Optional; If provided, only slices with the desired number of columns will be considered. mask_root: Path to stored masks. consecutive_slices: An int (>0) that determine the amount of consecutive slices of the file to be loaded at the same time. Defaults to 1, loading single slices. """ if challenge not in ("singlecoil", "multicoil", "segmentation"): raise ValueError('challenge should be either "singlecoil" or "multicoil" or "segmentation"') self.challenge = challenge if sample_rate is not None and volume_sample_rate is not None: raise ValueError( "either set sample_rate (sample by slices) or volume_sample_rate (sample by volumes) but not both" ) self.sense_root = sense_root self.mask_root = mask_root self.dataset_cache_file = Path(dataset_cache_file) self.transform = transform self.recons_key = "reconstruction_esc" if challenge == "singlecoil" else "reconstruction_rss" self.examples = [] # set default sampling mode if none given if sample_rate is None: sample_rate = 1.0 if volume_sample_rate is None: volume_sample_rate = 1.0 # load dataset cache if we have and user wants to use it if self.dataset_cache_file.exists() and use_dataset_cache: with open(self.dataset_cache_file, "rb") as f: dataset_cache = yaml.safe_load(f) else: dataset_cache = {} # check if our dataset is in the cache # if there, use that metadata, if not, then regenerate the metadata if dataset_cache.get(root) is None or not use_dataset_cache: files = list(Path(root).iterdir()) for fname in sorted(files): metadata, num_slices = self._retrieve_metadata(fname) if not is_none(num_slices) and not is_none(consecutive_slices): num_slices = num_slices - (consecutive_slices - 1) self.examples += [(fname, slice_ind, metadata) for slice_ind in range(num_slices)] if dataset_cache.get(root) is None and use_dataset_cache: dataset_cache[root] = self.examples logging.info(f"Saving dataset cache to {self.dataset_cache_file}.") with open(self.dataset_cache_file, "wb") as f: # type: ignore yaml.dump(dataset_cache, f) # type: ignore else: logging.info(f"Using dataset cache from {self.dataset_cache_file}.") self.examples = dataset_cache[root] # subsample if desired if sample_rate < 1.0: # sample by slice random.shuffle(self.examples) num_examples = round(len(self.examples) * sample_rate) self.examples = self.examples[:num_examples] elif volume_sample_rate < 1.0: # sample by volume vol_names = sorted(list({f[0].stem for f in self.examples})) random.shuffle(vol_names) num_volumes = round(len(vol_names) * volume_sample_rate) sampled_vols = vol_names[:num_volumes] self.examples = [example for example in self.examples if example[0].stem in sampled_vols] if num_cols: self.examples = [ex for ex in self.examples if ex[2]["encoding_size"][1] in num_cols] # type: ignore # Create random number generator used for consecutive slice selection and set consecutive slice amount self.consecutive_slices = consecutive_slices if self.consecutive_slices < 1: raise ValueError("consecutive_slices value is out of range, must be > 0.") @staticmethod def _retrieve_metadata(fname): """ Retrieve metadata from a given file. Parameters ---------- fname: Path to file. Returns ------- A dictionary containing the metadata. """ with h5py.File(fname, "r") as hf: if "ismrmrd_header" in hf: et_root = fromstring(hf["ismrmrd_header"][()]) enc = ["encoding", "encodedSpace", "matrixSize"] enc_size = ( int(et_query(et_root, enc + ["x"])), int(et_query(et_root, enc + ["y"])), int(et_query(et_root, enc + ["z"])), ) rec = ["encoding", "reconSpace", "matrixSize"] recon_size = ( int(et_query(et_root, rec + ["x"])), int(et_query(et_root, rec + ["y"])), int(et_query(et_root, rec + ["z"])), ) params = ["encoding", "encodingLimits", "kspace_encoding_step_1"] enc_limits_center = int(et_query(et_root, params + ["center"])) enc_limits_max = int(et_query(et_root, params + ["maximum"])) + 1 padding_left = torch.div(enc_size[1], 2, rounding_mode="trunc").item() - enc_limits_center padding_right = padding_left + enc_limits_max else: padding_left = 0 padding_right = 0 enc_size = 0 recon_size = (0, 0) num_slices = hf["kspace"].shape[0] if "kspace" in hf else hf["reconstruction"].shape[0] metadata = { "padding_left": padding_left, "padding_right": padding_right, "encoding_size": enc_size, "recon_size": recon_size, } return metadata, num_slices
[docs] def get_consecutive_slices(self, data, key, dataslice): """ Get consecutive slices from a given data. Args: data: Data to extract slices from. key: Key to extract slices from. dataslice: Slice to extract slices from. Returns: A list of consecutive slices. """ data = data[key] if self.consecutive_slices == 1: if data.shape[0] == 1: return data[0] elif data.ndim != 2: return data[dataslice] return data num_slices = data.shape[0] if self.consecutive_slices > num_slices: return np.stack(data, axis=0) start_slice = dataslice if dataslice + self.consecutive_slices <= num_slices: end_slice = dataslice + self.consecutive_slices else: end_slice = num_slices return data[start_slice:end_slice]
def __len__(self): return len(self.examples) def __getitem__(self, i: int): fname, dataslice, metadata = self.examples[i] with h5py.File(fname, "r") as hf: kspace = self.get_consecutive_slices(hf, "kspace", dataslice).astype(np.complex64) if "sensitivity_map" in hf: sensitivity_map = self.get_consecutive_slices(hf, "sensitivity_map", dataslice).astype(np.complex64) elif self.sense_root is not None and self.sense_root != "None": with h5py.File(Path(self.sense_root) / Path(str(fname).split("/")[-2]) / fname.name, "r") as sf: if "sensitivity_map" in sf or "sensitivity_map" in next(iter(sf.keys())): sensitivity_map = self.get_consecutive_slices(sf, "sensitivity_map", dataslice) else: sensitivity_map = self.get_consecutive_slices(sf, "sense", dataslice) sensitivity_map = sensitivity_map.squeeze().astype(np.complex64) else: sensitivity_map = np.array([]) if "mask" in hf: mask = np.asarray(self.get_consecutive_slices(hf, "mask", dataslice)) if mask.ndim == 3: mask = mask[dataslice] elif self.mask_root is not None and self.mask_root != "None": with h5py.File(Path(self.mask_root) / fname.name, "r") as mf: mask = np.asarray(self.get_consecutive_slices(mf, "mask", dataslice)) else: mask = None eta = ( self.get_consecutive_slices(hf, "eta", dataslice).astype(np.complex64) if "eta" in hf else np.array([]) ) if "reconstruction_sense" in hf: self.recons_key = "reconstruction_sense" target = ( self.get_consecutive_slices(hf, self.recons_key, dataslice).astype(np.float32) if self.recons_key in hf else None ) attrs = dict(hf.attrs) attrs.update(metadata) if sensitivity_map.shape != kspace.shape: if sensitivity_map.ndim == 3: sensitivity_map = np.transpose(sensitivity_map, (2, 0, 1)) elif sensitivity_map.ndim == 4: sensitivity_map = np.transpose(sensitivity_map, (0, 3, 1, 2)) else: raise ValueError( f"Sensitivity map has invalid dimensions {sensitivity_map.shape} compared to kspace {kspace.shape}" ) return ( ( kspace, sensitivity_map, mask, eta, target, attrs, fname.name, dataslice, ) if self.transform is None else self.transform( kspace, sensitivity_map, mask, eta, target, attrs, fname.name, dataslice, ) )