Applying a separate collate_fn to multiple interleaved datasets

I have two datasets, say:
MyFirstDataset = MyFirstDatasetClass(data1)
MySecondDataset = MySecondDatasetClass(data2)
I have written a sampler that interleaves sampling from the two datasets. It takes the concatenated dataset as input
concat_dataset = ConcatDataset([MyFirstDataset, MySecondDataset])
and gives me the interleaved schedule. The problem is that I am loading both datasets through a single DataLoader:

dataloader = DataLoader(dataset=concat_dataset, sampler=BatchSchedulerSampler(dataset=concat_dataset, batch_size=batch_size, batch_schedule=args.batch_schedule), batch_size=batch_size)
I want to apply a separate collate_fn to each of the two datasets, but the DataLoader only accepts one collate_fn (correct me if I am wrong). How can I apply separate collate functions in the DataLoader when the dataset is made up of two or more datasets?

Return some kind of identifier from each dataset indicating whether a sample comes from dataset #1 or #2.
Then, in the collate function, group the samples by that identifier and apply the processing each group needs.
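A minimal sketch of that idea, under a few assumptions: `TaggedDataset`, `collate_first`, `collate_second` and `multi_collate` are hypothetical names, and the stacking/padding logic stands in for whatever your `MyFirstDatasetClass` and `MySecondDatasetClass` samples actually need. Each sample is wrapped as `(source_id, sample)`, and the single `collate_fn` passed to the DataLoader groups the batch by id and dispatches each group to its own collate function. You could equally return the id directly from each dataset's `__getitem__` instead of using a wrapper.

```python
from collections import defaultdict

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class TaggedDataset(Dataset):
    """Hypothetical wrapper: returns (source_id, sample) so collate can tell datasets apart."""

    def __init__(self, dataset, source_id):
        self.dataset = dataset
        self.source_id = source_id

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.source_id, self.dataset[idx]


def collate_first(samples):
    # Assumed: dataset #1 yields fixed-size tensors, so they can be stacked.
    return torch.stack(samples)


def collate_second(samples):
    # Assumed: dataset #2 yields variable-length 1-D tensors, so pad to the longest.
    return pad_sequence(samples, batch_first=True)


# Maps each source id to the collate function for that dataset.
COLLATE_FNS = {0: collate_first, 1: collate_second}


def multi_collate(batch):
    # Group samples by source id, then collate each group with its own function.
    groups = defaultdict(list)
    for source_id, sample in batch:
        groups[source_id].append(sample)
    return {sid: COLLATE_FNS[sid](samples) for sid, samples in groups.items()}


# Toy stand-ins for MyFirstDataset / MySecondDataset.
first = [torch.randn(3) for _ in range(4)]
second = [torch.arange(n) for n in range(1, 5)]

concat = ConcatDataset([TaggedDataset(first, 0), TaggedDataset(second, 1)])
# With your own sampler, pass sampler=BatchSchedulerSampler(...) instead of shuffle=True.
loader = DataLoader(concat, batch_size=4, shuffle=True, collate_fn=multi_collate)

for batch in loader:
    for sid, tensor in batch.items():
        print(sid, tuple(tensor.shape))
```

Returning a dict keyed by source id is one design choice among several; it means the training loop has to handle each key separately. If your sampler guarantees that every batch comes from a single dataset, the dict will only ever have one key, and `multi_collate` could instead just check the id of the first sample and call the matching function on the whole batch.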