# encoding: utf-8
__author__ = "Dimitrios Karkalousos"
# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/data/dataset.py
from abc import ABC
from typing import Any, Dict, List, Optional
import numpy as np
import torch.utils.data as pt_data
from torch.utils.data import Dataset, IterableDataset
__all__ = ["ConcatDataset", "ConcatMapDataset"]
class ConcatDataset(pt_data.IterableDataset, ABC):
"""
    A dataset that accepts multiple datasets as input and samples from them with the specified
    sampling technique.

Parameters
----------
datasets: A list of datasets to sample from.
shuffle: Whether to shuffle individual datasets. Only works with non-iterable datasets. Defaults to True.
sampling_technique: Sampling technique to choose which dataset to draw a sample from. Defaults to 'random'.
Currently supports 'random' and 'round-robin'.
sampling_probabilities: Probability values for sampling. Only used when sampling_technique = 'random'.
global_rank: Worker rank, used for partitioning map style datasets. Defaults to 0.
world_size: Total number of processes, used for partitioning map style datasets. Defaults to 1.
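
    Examples
    --------
    A minimal usage sketch, assuming two map-style datasets ``ds1`` and ``ds2`` (hypothetical
    names, not part of this module):

    >>> concat = ConcatDataset([ds1, ds2], sampling_technique="random", sampling_probabilities=[0.7, 0.3])
    >>> loader = pt_data.DataLoader(concat, batch_size=4)

    Roughly 70% of the drawn samples come from ``ds1`` and 30% from ``ds2``; a dataset that is
    exhausted before the epoch ends is restarted.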
"""
def __init__(
self,
datasets: List[Any],
shuffle: bool = True,
sampling_technique: str = "random",
        sampling_probabilities: Optional[List[float]] = None,
global_rank: int = 0,
world_size: int = 1,
):
super().__init__()
self.datasets = datasets
self.iterables = [None] * len(datasets)
self.shuffle = shuffle
self.global_rank = global_rank
self.world_size = world_size
self.sampling_kwargs = {}
if sampling_technique == "random":
self.index_generator = ConcatDataset.random_generator
self.sampling_kwargs["p"] = sampling_probabilities # type: ignore
elif sampling_technique == "round-robin":
self.index_generator = ConcatDataset.round_robin_generator
else:
supported_sampling_techniques = ["random", "round-robin"]
raise ValueError(f"Currently we only support sampling techniques in {supported_sampling_techniques}.")
self.length = 0
if isinstance(datasets[0], pt_data.IterableDataset):
self.kind = "iterable"
else:
self.kind = "map"
for dataset in datasets:
isiterable = isinstance(dataset, pt_data.IterableDataset)
            if isiterable != (self.kind == "iterable"):
raise ValueError("All datasets in ConcatDataset must be of the same kind (Iterable or Map).")
if self.kind == "map":
self.length += np.floor_divide(len(dataset), world_size)
else:
self.length += len(dataset)
    def get_iterable(self, dataset):
"""Returns an iterable dataset."""
if isinstance(dataset, pt_data.IterableDataset):
            return iter(dataset)
indices = np.arange(len(dataset))
if self.shuffle:
np.random.shuffle(indices)
return iter(indices)
    def __iter__(self):
"""Returns an iterator over the dataset."""
worker_info = pt_data.get_worker_info()
if worker_info is None:
max_elements = self.length
wid = 0
wnum = 1
else:
wid = worker_info.id
wnum = worker_info.num_workers
max_elements = len(range(wid, self.length, wnum))
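        # Each DataLoader worker walks a strided slice of the epoch, so together the `wnum`
        # workers yield exactly `self.length` samples without duplicating work.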
if self.kind == "map":
for idx in range(len(self.datasets)):
start_idx = np.floor_divide(len(self.datasets[idx]), self.world_size) * self.global_rank
end_idx = start_idx + np.floor_divide(len(self.datasets[idx]), self.world_size)
if self.global_rank == self.world_size - 1:
end_idx = len(self.datasets[idx])
indices = range(start_idx + wid, end_idx, wnum)
self.datasets[idx] = pt_data.Subset(self.datasets[idx], indices)
for idx, dataset in enumerate(self.datasets):
iterable = self.get_iterable(dataset)
self.iterables[idx] = iterable # type: ignore
n = 0
ind_gen = self.index_generator(self.datasets, **self.sampling_kwargs)
while n < max_elements:
n += 1
try:
ind = next(ind_gen)
except StopIteration:
return
try:
val = next(self.iterables[ind]) # type: ignore
if self.kind == "map":
val = self.datasets[ind][val]
yield val
except StopIteration:
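                # This dataset ran out before the epoch finished: restart its iterator and
                # retry, which effectively oversamples shorter datasets.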
self.iterables[ind] = self.get_iterable(self.datasets[ind]) # type: ignore
n -= 1
    def __len__(self):
"""Returns the number of elements in the dataset."""
return self.length
    @staticmethod
def round_robin_generator(datasets, **kwargs):
"""Generates indices in a round-robin fashion."""
num = len(datasets)
while True:
yield from range(num)
    @staticmethod
def random_generator(datasets, **kwargs):
"""Generates random indices."""
p = kwargs.get("p")
if not p:
raise ValueError("Random generator expects a 'p' keyowrd argument for sampling probabilities.")
num = len(datasets)
if len(p) != num:
raise ValueError("Length of probabilities list must be equal to the number of datasets.")
while True:
yield np.random.choice(np.arange(num), p=p)
class ConcatMapDataset(Dataset):
"""
A dataset that accepts as argument multiple datasets and then samples from them based on the specified
sampling technique.
Parameters
----------
datasets: A list of datasets to sample from.
shuffle: Whether to shuffle individual datasets. Only works with non-iterable datasets. Defaults to True.
sampling_technique: Sampling technique to choose which dataset to draw a sample from. Defaults to 'random'.
Currently supports 'random' and 'round-robin'.
sampling_probabilities: Probability values for sampling. Only used when sampling_technique = 'random'.
global_rank: Worker rank, used for partitioning map style datasets. Defaults to 0.
world_size: Total number of processes, used for partitioning map style datasets. Defaults to 1.
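
    Examples
    --------
    A minimal usage sketch, assuming two map-style datasets ``ds1`` and ``ds2`` (hypothetical
    names, not part of this module):

    >>> concat = ConcatMapDataset([ds1, ds2], sampling_technique="temperature", sampling_temperature=5)
    >>> sample = concat[0]  # picks a dataset via temperature sampling, then its next shuffled sample

    With temperature sampling, selection probabilities are proportional to
    ``len(dataset) ** (1 / sampling_temperature)``, so larger temperatures flatten the
    distribution towards uniform.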
"""
def __init__(
self,
datasets: List[Any],
sampling_technique: str = "temperature",
sampling_temperature: int = 5,
        sampling_probabilities: Optional[List[float]] = None,
consumed_samples: int = 0,
):
super().__init__()
self.datasets = datasets
self.sampling_kwargs: Dict = {}
self.size = 0
self.sampling_technique = sampling_technique
self.sampling_temperature = sampling_temperature
self.sampling_probabilities = sampling_probabilities
self.consumed_samples = consumed_samples
self.np_rng = np.random.RandomState(consumed_samples)
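        # Seeding the RNG with the number of consumed samples makes both the per-dataset
        # shuffles and the dataset selection reproducible when resuming from a checkpoint.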
for dataset in datasets:
self.size += len(dataset)
        self.dataset_index = np.zeros(len(self.datasets), dtype=np.int64)
self.permuted_dataset_indices = []
for dataset in self.datasets:
permuted_indices = np.arange(len(dataset))
self.np_rng.shuffle(permuted_indices)
self.permuted_dataset_indices.append(permuted_indices)
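        # Temperature sampling draws dataset i with p_i proportional to len(dataset_i) ** (1 / T).
        # T = 1 recovers sampling proportional to dataset size; larger T flattens the distribution
        # (e.g. lengths [100, 900] with T = 5 give p ~ [0.39, 0.61] instead of [0.1, 0.9]).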
if self.sampling_technique == "temperature":
lengths = [len(dataset) for dataset in datasets]
p = np.array(lengths) / np.sum(lengths)
p = np.power(p, 1 / self.sampling_temperature)
p = p / np.sum(p)
self.p = p
elif self.sampling_technique == "random":
if not self.sampling_probabilities:
raise ValueError(
"Random generator expects a 'sampling_probabilities' - a list of probability values corresponding "
"to each dataset."
)
            if len(self.sampling_probabilities) != len(self.datasets):
                raise ValueError(
                    "Length of probabilities list must be equal to the number of datasets. "
                    f"Found {len(self.sampling_probabilities)} probabilities and {len(self.datasets)} datasets."
                )
p = np.array(self.sampling_probabilities)
self.p = p / np.sum(p)
    def __len__(self):
        """Returns the total number of elements across all datasets."""
        return self.size
    def _get_dataset_index(self, idx):
        """Returns the index of the dataset to sample from."""
        if self.sampling_technique in ["temperature", "random"]:
            return self.np_rng.choice(np.arange(len(self.datasets)), p=self.p)
        if self.sampling_technique == "round-robin":
            return idx % len(self.datasets)
        raise ValueError(f"Unsupported sampling technique: {self.sampling_technique}.")
    def __getitem__(self, idx):
        """Fetches one sample, choosing the source dataset with the configured sampling technique."""
# Get the dataset we want to sample from
dataset_index = self._get_dataset_index(idx)
# Get the index of the sample we want to fetch from the dataset
sample_idx = self.dataset_index[dataset_index]
        # If the sample index has reached the end of the dataset, wrap around to 0.
        if sample_idx >= len(self.datasets[dataset_index]):
            sample_idx = 0
            self.dataset_index[dataset_index] = 0
# Sample index -> shuffled sample index
shuffled_sample_idx = self.permuted_dataset_indices[dataset_index][sample_idx]
sample = self.datasets[dataset_index][shuffled_sample_idx]
self.dataset_index[dataset_index] += 1
return sample
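

if __name__ == "__main__":
    # Illustrative smoke test only (not part of the adapted NeMo module): combine two toy
    # TensorDatasets and draw a few samples with each class.
    import torch

    ds1 = pt_data.TensorDataset(torch.arange(10))
    ds2 = pt_data.TensorDataset(torch.arange(100, 120))

    concat_iterable = ConcatDataset([ds1, ds2], sampling_technique="round-robin")
    print([val for _, val in zip(range(4), concat_iterable)])  # alternates between ds1 and ds2

    concat_map = ConcatMapDataset([ds1, ds2], sampling_technique="temperature", sampling_temperature=5)
    print(len(concat_map), concat_map[0])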