This is almost exactly the same as this question:
I have two datasets A and B. A contains tensors of shape [256,4096] and B contains tensors of shape [32,4096].
Now I can use ConcatDataset to merge A and B, but how do I guarantee that each batch contains elements from only A or only B?
Note, I don’t want to resize elements within A and B. These are not images.
The answer on Stack Overflow mentions batch_sampler.
Can somebody elaborate and give a minimal example?
Here is an example of a custom batch_sampler for your case:
import torch
from torch.utils.data import Sampler

def chunk(indices, size):
    return torch.split(torch.tensor(indices), size)

class MyBatchSampler(Sampler):
    def __init__(self, a_indices, b_indices, batch_size):
        self.a_indices = a_indices
        self.b_indices = b_indices
        self.batch_size = batch_size

    def __iter__(self):
        # Chunk A and B separately so that no batch mixes the two datasets
        a_batches = chunk(self.a_indices, self.batch_size)
        b_batches = chunk(self.b_indices, self.batch_size)
        all_batches = list(a_batches + b_batches)
        all_batches = [batch.tolist() for batch in all_batches]
        return iter(all_batches)
from torch.utils.data import ConcatDataset, DataLoader

new_dataset = ConcatDataset((dataset_a, dataset_b))
a_len = len(dataset_a)
ab_len = a_len + len(dataset_b)
# In the ConcatDataset, indices [0, a_len) belong to A and [a_len, ab_len) to B
a_indices = list(range(a_len))
b_indices = list(range(a_len, ab_len))
batch_sampler = MyBatchSampler(a_indices, b_indices, batch_size)
dl = DataLoader(new_dataset, batch_sampler=batch_sampler)
To verify that each batch contains elements from only A or only B:
for x in batch_sampler:
    print(x)
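If you would rather check this programmatically than read the printed indices, here is a minimal sketch. It assumes A's indices come first in the ConcatDataset, as in the setup above. Note that batch_sources is a hypothetical helper name, not part of PyTorch, and the toy batches are made up for illustration.

```python
def batch_sources(batches, a_len):
    """Label each batch 'A', 'B' or 'mixed' from its ConcatDataset indices.

    Assumes indices below a_len belong to A and all others belong to B.
    """
    labels = []
    for batch in batches:
        in_a = [i < a_len for i in batch]
        if all(in_a):
            labels.append("A")
        elif not any(in_a):
            labels.append("B")
        else:
            labels.append("mixed")
    return labels


# Toy example: A has 5 items (indices 0-4), B has 3 items (indices 5-7).
batches = [[0, 1], [2, 3], [4], [5, 6], [7]]
labels = batch_sources(batches, a_len=5)
assert "mixed" not in labels
```

With the real sampler you could pass list(batch_sampler) as batches and a_len as computed above.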
Thank you. What should def __len__(self) be in MyBatchSampler?
def __len__(self):
    return (len(self.a_indices) + len(self.b_indices)) // self.batch_size
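Yes, that is right as long as both dataset lengths are multiples of the batch size; otherwise each dataset's final, shorter batch adds one to the count. Did it work?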
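A small sketch of how the batch count could be made to match what the sampler actually yields. torch.split keeps a shorter final chunk for each index list, so the count is a ceiling per dataset rather than a floor over the combined length. The class below is a hypothetical torch-free stand-in for MyBatchSampler (the name TwoSourceBatchSampler and the list slicing are assumptions for illustration), written so the arithmetic can be checked without any tensors.

```python
import math


class TwoSourceBatchSampler:
    """Hypothetical pure-Python stand-in for MyBatchSampler."""

    def __init__(self, a_indices, b_indices, batch_size):
        self.a_indices = list(a_indices)
        self.b_indices = list(b_indices)
        self.batch_size = batch_size

    def _chunks(self, indices):
        # Consecutive slices; the last one may be shorter, as with torch.split.
        return [indices[i:i + self.batch_size]
                for i in range(0, len(indices), self.batch_size)]

    def __iter__(self):
        return iter(self._chunks(self.a_indices) + self._chunks(self.b_indices))

    def __len__(self):
        # One ceiling per dataset, because each dataset is chunked on its own.
        return (math.ceil(len(self.a_indices) / self.batch_size)
                + math.ceil(len(self.b_indices) / self.batch_size))


# 10 indices for A, 5 for B, batch size 4.
sampler = TwoSourceBatchSampler(range(10), range(10, 15), batch_size=4)
num_batches = len(sampler)
floor_estimate = (10 + 5) // 4
assert num_batches == len(list(sampler))
```

In this toy case the floor formula undercounts the batches; the two agree only when both lengths divide evenly by the batch size.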