How to concatenate different datasets, each with different dimensions

This is almost exactly the same as this question:

I have two datasets, A and B. A contains tensors of shape [256, 4096] and B contains tensors of shape [32, 4096].
I can use ConcatDataset to merge A and B, but how do I guarantee that each batch contains elements from only A or only B?
Note that I don't want to resize the elements within A and B; these are not images.
The answer on Stack Overflow mentions batch_sampler.
Can somebody elaborate and give a minimal example?

Here is an example of a custom batch_sampler for your case:

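import random

import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler

def chunk(indices, size):
    # Split a list of indices into consecutive chunks of at most `size` items.
    return torch.split(torch.tensor(indices), size)

class MyBatchSampler(Sampler):
    """Yields batches of indices drawn entirely from A or entirely from B."""

    def __init__(self, a_indices, b_indices, batch_size):
        self.a_indices = a_indices
        self.b_indices = b_indices
        self.batch_size = batch_size

    def __iter__(self):
        # Shuffle within each dataset, then chunk each one separately,
        # so no batch can mix A and B.
        random.shuffle(self.a_indices)
        random.shuffle(self.b_indices)
        a_batches = chunk(self.a_indices, self.batch_size)
        b_batches = chunk(self.b_indices, self.batch_size)
        all_batches = list(a_batches + b_batches)
        all_batches = [batch.tolist() for batch in all_batches]
        # Shuffle the batch order so A batches and B batches are interleaved.
        random.shuffle(all_batches)
        return iter(all_batches)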
   
# dataset_a, dataset_b and batch_size are assumed to be defined already.
new_dataset = ConcatDataset((dataset_a, dataset_b))
a_len = len(dataset_a)
ab_len = a_len + len(dataset_b)
# Inside the ConcatDataset, B's indices are offset by len(dataset_a).
a_indices = list(range(a_len))
b_indices = list(range(a_len, ab_len))

batch_sampler = MyBatchSampler(a_indices, b_indices, batch_size)

dl = DataLoader(new_dataset, batch_sampler=batch_sampler)
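The snippet above assumes dataset_a, dataset_b and batch_size already exist. Here is a self-contained sketch that runs the same sampler through a DataLoader, using small hypothetical TensorDatasets (A items of shape [8, 4] and B items of shape [2, 4], standing in for [256, 4096] and [32, 4096]). Because every batch holds items of a single shape, the default collate function can stack them without any resizing.

```python
import random

import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler, TensorDataset


def chunk(indices, size):
    # Split a list of indices into consecutive chunks of at most `size` items.
    return torch.split(torch.tensor(indices), size)


class MyBatchSampler(Sampler):
    def __init__(self, a_indices, b_indices, batch_size):
        self.a_indices = a_indices
        self.b_indices = b_indices
        self.batch_size = batch_size

    def __iter__(self):
        random.shuffle(self.a_indices)
        random.shuffle(self.b_indices)
        a_batches = chunk(self.a_indices, self.batch_size)
        b_batches = chunk(self.b_indices, self.batch_size)
        all_batches = [batch.tolist() for batch in a_batches + b_batches]
        random.shuffle(all_batches)
        return iter(all_batches)


# Hypothetical toy data: 10 items of shape [8, 4] in A, 6 items of shape [2, 4] in B.
dataset_a = TensorDataset(torch.randn(10, 8, 4))
dataset_b = TensorDataset(torch.randn(6, 2, 4))
batch_size = 4

new_dataset = ConcatDataset((dataset_a, dataset_b))
a_len = len(dataset_a)
ab_len = a_len + len(dataset_b)
batch_sampler = MyBatchSampler(
    list(range(a_len)), list(range(a_len, ab_len)), batch_size
)
dl = DataLoader(new_dataset, batch_sampler=batch_sampler)

shapes = []
for (x,) in dl:
    # Each batch is stacked from one dataset only: [n, 8, 4] or [n, 2, 4].
    shapes.append(tuple(x.shape))
    print(x.shape)
```

With 10 items in A, 6 in B and batch_size = 4, this should print five batch shapes in random order: two of [4, 8, 4], one of [2, 8, 4], one of [4, 2, 4] and one of [2, 2, 4].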

To verify that each batch contains elements from only A or only B:

for x in batch_sampler:
    print(x)
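Printing the batches is fine for a quick look; for larger datasets, a small helper can check the same property programmatically. The name check_batches is hypothetical (not a PyTorch API). It assumes, as in the code above, that A occupies indices 0 to a_len - 1 of the ConcatDataset and B occupies the rest, and it also confirms that every index is used exactly once per pass.

```python
def check_batches(batches, a_len, total_len):
    """Return True if every batch comes entirely from A or entirely from B
    and every index in range(total_len) appears exactly once."""
    seen = []
    for batch in batches:
        from_a = all(i < a_len for i in batch)
        from_b = all(a_len <= i < total_len for i in batch)
        if not (from_a or from_b):
            return False
        seen.extend(batch)
    return sorted(seen) == list(range(total_len))


# Hand-written example batches, assuming len(A) == 3 and len(B) == 2.
print(check_batches([[2, 0], [4, 3], [1]], a_len=3, total_len=5))  # True
print(check_batches([[2, 3], [0, 4], [1]], a_len=3, total_len=5))  # False: mixes A and B
```

Applied to the sampler from the answer, the call would be check_batches(list(batch_sampler), a_len, ab_len).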

Thank you. What should def __len__(self) be in MyBatchSampler?

    def __len__(self):
        return (len(self.a_indices) + len(self.b_indices)) // self.batch_size

Is that right?

Yes, you are right. Did it work?

Yes, it worked. Thanks!
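One caveat about the __len__ above: torch.split keeps the final, shorter chunk of each index list, so __iter__ actually yields ceil(len(A) / batch_size) + ceil(len(B) / batch_size) batches. That can be more than (len(A) + len(B)) // batch_size whenever a dataset's length is not a multiple of batch_size. A minimal sketch of the exact count (num_batches is a hypothetical helper name):

```python
import math


def num_batches(n_a, n_b, batch_size):
    # torch.split keeps the last partial chunk, so each dataset contributes
    # ceil(n / batch_size) batches; the two counts are then summed.
    return math.ceil(n_a / batch_size) + math.ceil(n_b / batch_size)


print(num_batches(5, 5, 4))  # 4
print((5 + 5) // 4)          # 2: the floor formula undercounts here
```

Inside MyBatchSampler, __len__ could then return num_batches(len(self.a_indices), len(self.b_indices), self.batch_size). When both dataset lengths are exact multiples of batch_size, the two formulas agree.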