If we want to combine two imbalanced datasets and still get balanced samples, I think we could wrap them in a ConcatDataset and pass a WeightedRandomSampler to the DataLoader through its sampler argument:
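import torch
dataset1 = custom_dataset1()
dataset2 = custom_dataset2()
concat_dataset = torch.utils.data.ConcatDataset([dataset1, dataset2])
# weighted_sampler: a WeightedRandomSampler with one weight per sample of concat_dataset
# sampler must be passed by keyword; a positional argument after batch_size= is a SyntaxError
dataloader = torch.utils.data.DataLoader(concat_dataset, batch_size=bs, sampler=weighted_sampler)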
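The part that is easy to get wrong is building the weights. Below is a minimal sketch, assuming "balanced" means each source dataset should be drawn about equally often. The two TensorDatasets are hypothetical stand-ins for custom_dataset1() / custom_dataset2() (900 vs. 100 samples), and each carries a source id as its second tensor only so the result can be checked. Every sample gets weight 1 / len(its dataset), listed in the same order ConcatDataset uses, so both datasets end up with the same total weight.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)  # only to make the sampling repeatable between runs

# Hypothetical stand-ins for custom_dataset1() / custom_dataset2():
# 900 samples tagged with source id 0, 100 samples tagged with source id 1.
dataset1 = TensorDataset(torch.randn(900, 3), torch.zeros(900, dtype=torch.long))
dataset2 = TensorDataset(torch.randn(100, 3), torch.ones(100, dtype=torch.long))
concat_dataset = ConcatDataset([dataset1, dataset2])

# One weight per sample, in the concatenated order: 1 / len(dataset) for each sample,
# so each dataset contributes a total weight of 1.0 regardless of its size.
weights = torch.cat([
    torch.full((len(d),), 1.0 / len(d), dtype=torch.double)
    for d in concat_dataset.datasets
])

# replacement=True lets the smaller dataset be drawn more often than it has samples.
weighted_sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

# shuffle stays False (the default); the sampler already decides the order.
dataloader = DataLoader(concat_dataset, batch_size=100, sampler=weighted_sampler)

# Count how often each source dataset shows up over one pass of the sampler.
counts = torch.zeros(2, dtype=torch.long)
for _, source in dataloader:
    counts += torch.bincount(source, minlength=2)

# Expected to be close to a 50/50 split; exact numbers depend on the random draw.
print(counts.tolist())
```

Since sampling is with replacement, samples from the smaller dataset will repeat within an "epoch", and num_samples sets how many draws an epoch has. If the goal is to balance class labels rather than source datasets, the same pattern should work with each sample weighted by 1 / (count of its class) instead.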