How to enforce that a DataLoader wrapping multiple datasets returns all items of a batch from the same dataset

I have the following MUXDataset, which holds several instances of dataset objects:

import random
from typing import List

from torch.utils.data import Dataset

# MultiplexDatasetParams and DatasetOutput are project-specific types

class MUXDataset(Dataset):
    """Defines a dataset class that provides a way to read scenes and also
    visualization tools."""

    def __init__(self, mux_dataset_params: MultiplexDatasetParams) -> None:
        self._params = mux_dataset_params
        self.sampling_vec = (
            mux_dataset_params.sampling if isinstance(mux_dataset_params.sampling, list) else self.init_sampling()
        )
        assert len(self.sampling_vec) == len(self._params.datasets), (
            'The number of sampling weights does not match the number of datasets'
        )

        if self._params.debug_flag:
            self.mux_debug_visualization()
            self.mux_save_datasets_info()

    def init_sampling(self) -> List[float]:
        if self._params.sampling == 'uniform':
            num_datasets = len(self._params.datasets)
            sampling = [1 / num_datasets for _ in self._params.datasets]
        elif self._params.sampling == 'proportional':
            # note: this weights each dataset by (1 - its share of the total size),
            # so smaller datasets are sampled more often; the last line renormalizes
            size_vec = [len(_d) for _d in self._params.datasets]
            sampling = [1 - _s / sum(size_vec) for _s in size_vec]
            sampling = [_s / sum(sampling) for _s in sampling]
        else:
            raise ValueError(f'{self._params.sampling} is not supported')
        return sampling

    def __len__(self) -> int:
        return len(self._params.datasets)

    def __getitem__(self, idx) -> DatasetOutput:
        # idx is ignored: each call first picks a dataset according to the
        # sampling vector, then picks a scene within it by its scene weights
        _ds = self._params.datasets[
            random.choices(range(len(self._params.datasets)), weights=self.sampling_vec, k=1)[0]
        ]
        return _ds[random.choices(range(len(_ds)), weights=_ds.scenes_w, k=1)[0]]

But when I use a DataLoader, a single batch contains instances from several datasets. Is there a way to force it to draw the whole batch from the same dataset?
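
For reference, I build the loader roughly like this (the batch_size here is arbitrary):

from torch.utils.data import DataLoader

loader = DataLoader(MUXDataset(mux_dataset_params), batch_size=8)
batch = next(iter(loader))  # items in this batch can come from different datasets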

You could use a BatchSampler and load the entire batch in your __getitem__ method, as described here.
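
Below is a minimal sketch of that idea, assuming PyTorch's torch.utils.data API. The name SingleSourceBatchSampler, its parameters, and the toy datasets are invented for illustration, and it is a custom batch sampler rather than the built-in BatchSampler, since the per-batch dataset choice needs custom logic. It picks one dataset per batch (weighted, e.g., by your sampling_vec) and then samples every index of that batch from the chosen dataset, so the DataLoader can never mix datasets within a batch:

import random
from typing import Iterator, List

import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler, TensorDataset

class SingleSourceBatchSampler(Sampler):
    """Yields lists of indices into a ConcatDataset such that every batch
    is drawn from exactly one of the underlying datasets."""

    def __init__(self, dataset_sizes: List[int], weights: List[float],
                 batch_size: int, num_batches: int) -> None:
        self.dataset_sizes = dataset_sizes
        self.weights = weights
        self.batch_size = batch_size
        self.num_batches = num_batches
        # offset of each dataset inside the concatenated index space
        self.offsets = [0]
        for size in dataset_sizes[:-1]:
            self.offsets.append(self.offsets[-1] + size)

    def __iter__(self) -> Iterator[List[int]]:
        for _ in range(self.num_batches):
            # pick one dataset for the whole batch, weighted by `weights` ...
            ds = random.choices(range(len(self.dataset_sizes)), weights=self.weights, k=1)[0]
            # ... then sample every index of the batch from that dataset
            local = random.choices(range(self.dataset_sizes[ds]), k=self.batch_size)
            yield [self.offsets[ds] + i for i in local]

    def __len__(self) -> int:
        return self.num_batches

# toy usage: three datasets whose items carry their dataset id
datasets = [TensorDataset(torch.full((n, 1), float(i))) for i, n in enumerate([100, 50, 25])]
loader = DataLoader(
    ConcatDataset(datasets),
    batch_sampler=SingleSourceBatchSampler(
        dataset_sizes=[len(d) for d in datasets],
        weights=[1 / len(datasets)] * len(datasets),  # or your sampling_vec
        batch_size=8,
        num_batches=100,
    ),
)
for (batch,) in loader:
    assert batch.unique().numel() == 1  # every item comes from one dataset

Alternatively, closer to the letter of the linked answer, you can keep MUXDataset and have its __getitem__ return a whole batch itself (pick one dataset, then sample batch_size scenes from it), constructing the DataLoader with batch_size=None so automatic batching is disabled.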