I have the following MUXDataset, which holds several instances of dataset objects:
class MUXDataset(Dataset):
    """
    Multiplexes several datasets behind a single ``Dataset`` interface.

    Every ``__getitem__`` call first draws one of the wrapped datasets according to
    ``sampling_vec`` and then draws one item from it according to that dataset's
    ``scenes_w`` weights. Each call draws independently, so consecutive calls (for
    example the items collected into one batch) may come from different datasets.
    """

    def __init__(self, mux_dataset_params: MultiplexDatasetParams) -> None:
        """
        Store the parameters and resolve the per-dataset sampling weights.

        Args:
            mux_dataset_params: Parameters object; this class reads its ``datasets``,
                ``sampling`` and ``debug_flag`` attributes. ``sampling`` is either an
                explicit list of weights (one per dataset) or a mode name accepted by
                ``init_sampling``.

        Raises:
            AssertionError: If the number of weights differs from the number of datasets.
            ValueError: If ``sampling`` is a mode name that ``init_sampling`` rejects.
        """
        super().__init__()
        self._params = mux_dataset_params
        # An explicit list is used as-is; anything else is treated as a mode name.
        self.sampling_vec = (
            mux_dataset_params.sampling
            if isinstance(mux_dataset_params.sampling, list)
            else self.init_sampling()
        )
        assert len(self.sampling_vec) == len(self._params.datasets), (
            'The datasets sampling weights are not equal to the amount of datasets'
        )
        if self._params.debug_flag:
            # NOTE(review): these helpers are not defined in this class body -- confirm
            # they are provided elsewhere before enabling ``debug_flag``.
            self.mux_debug_visualization()
            self.mux_save_datasets_info()

    def init_sampling(self) -> List[float]:
        """
        Build the per-dataset sampling weights from the ``sampling`` mode name.

        Modes:
            ``'uniform'``: every dataset gets weight ``1 / num_datasets``.
            ``'proportional'``: dataset ``i`` gets ``1 - size_i / total_size``,
                normalised so the weights sum to 1.

        NOTE(review): despite its name, ``'proportional'`` gives *smaller* datasets a
        larger weight -- confirm this is the intended behaviour.

        Returns:
            One weight per dataset, in the order of ``self._params.datasets``.

        Raises:
            ValueError: If the mode name is neither ``'uniform'`` nor ``'proportional'``.
        """
        datasets = self._params.datasets
        num_datasets = len(datasets)

        if self._params.sampling == 'uniform':
            return [1 / num_datasets for _ in datasets]

        if self._params.sampling == 'proportional':
            size_vec = [len(_d) for _d in datasets]
            total_size = sum(size_vec)
            if total_size == 0:
                # No sizes to compare (all datasets empty): use uniform weights
                # instead of dividing by zero.
                return [1 / num_datasets for _ in datasets]

            raw = [1 - _s / total_size for _s in size_vec]
            norm = sum(raw)
            if norm == 0:
                # The raw weights sum to ``num_datasets - 1``, which is zero for a
                # single dataset; use uniform weights instead of dividing by zero.
                return [1 / num_datasets for _ in datasets]
            return [_w / norm for _w in raw]

        raise ValueError(f'{self._params.sampling} is not supported')

    def __len__(self) -> int:
        """
        Return the number of wrapped datasets.

        NOTE(review): this is the dataset count, not the total number of samples.
        """
        return len(self._params.datasets)

    def __getitem__(self, idx) -> DatasetOutput:
        """
        Return one randomly drawn item; ``idx`` is accepted but not used.

        The dataset is drawn with weights ``self.sampling_vec`` and the item index
        within it with weights ``_ds.scenes_w``.
        """
        # endless iterator that will sample the datasets based on the sampling vector
        _ds = random.choices(self._params.datasets, weights=self.sampling_vec, k=1)[0]
        item_idx = random.choices(range(len(_ds)), weights=_ds.scenes_w, k=1)[0]
        return _ds[item_idx]
But when I use a DataLoader, it gives me instances from several datasets within the same batch. Is there a way to force it to use a single dataset for each batch?