This is almost exactly the same as this question:
I have two datasets A and B. A contains tensors of shape [256,4096] and B contains tensors of shape [32,4096].
Now I can use `ConcatDataset` to merge A and B, but how do I guarantee that each batch only contains elements from either A or B?
Note: I don't want to resize the elements within A and B. These are not images.
The answer on Stack Overflow mentions `batch_sampler`.
Can somebody elaborate and give a minimal example?
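For context, here is a minimal sketch of why a mixed batch fails. With plain batching, DataLoader's default collate function stacks the items of a batch with torch.stack, which needs every item to have the same shape. The datasets below are hypothetical stand-ins built with TensorDataset, with the shapes scaled down from [256, 4096] and [32, 4096] to [256, 8] and [32, 8].

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Hypothetical stand-ins for A and B: 3 items each, feature width scaled down to 8
dataset_a = TensorDataset(torch.randn(3, 256, 8))  # items of shape [256, 8]
dataset_b = TensorDataset(torch.randn(3, 32, 8))   # items of shape [32, 8]

# Plain batching over the concatenation: with batch_size=2 and no shuffling,
# the second batch holds index 2 (from A) and index 3 (from B)
loader = DataLoader(ConcatDataset((dataset_a, dataset_b)), batch_size=2)

failed = False
try:
    for (batch,) in loader:
        print(batch.shape)
except RuntimeError as err:
    # torch.stack inside the default collate function cannot stack [256, 8] with [32, 8]
    failed = True
    print("collate failed:", err)
```

The first batch (two items from A) collates fine; the second one straddles the A/B boundary and raises. Keeping every batch on one side of that boundary is exactly what a custom batch sampler is for.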
naveenkb
(Naveen)
June 3, 2021, 8:10pm
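Here is an example of a custom `batch_sampler` for your case:

```python
import random

import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler


def chunk(indices, size):
    # Split a flat list of indices into consecutive tensors of at most `size` elements
    return torch.split(torch.tensor(indices), size)


class MyBatchSampler(Sampler):
    def __init__(self, a_indices, b_indices, batch_size):
        self.a_indices = a_indices
        self.b_indices = b_indices
        self.batch_size = batch_size

    def __iter__(self):
        # Shuffle within each dataset, then cut each one into its own batches,
        # so no batch ever mixes indices from A and B
        random.shuffle(self.a_indices)
        random.shuffle(self.b_indices)
        a_batches = chunk(self.a_indices, self.batch_size)
        b_batches = chunk(self.b_indices, self.batch_size)
        all_batches = list(a_batches + b_batches)
        all_batches = [batch.tolist() for batch in all_batches]
        # Shuffle the order of the batches so A-batches and B-batches are interleaved
        random.shuffle(all_batches)
        return iter(all_batches)


# dataset_a, dataset_b and batch_size are your own objects, defined elsewhere
new_dataset = ConcatDataset((dataset_a, dataset_b))
a_len = len(dataset_a)
ab_len = a_len + len(dataset_b)
# ConcatDataset places B's items after A's, so B's indices start at len(A)
a_indices = list(range(a_len))
b_indices = list(range(a_len, ab_len))
batch_sampler = MyBatchSampler(a_indices, b_indices, batch_size)
dl = DataLoader(new_dataset, batch_sampler=batch_sampler)
```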
To verify that each batch only contains elements from either A or B:

```python
for x in batch_sampler:
    print(x)
```
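Printing is fine for eyeballing; a minimal self-contained sketch that checks the property programmatically might look like the following. It restates a compact version of the sampler so it runs on its own, and uses small random TensorDatasets as hypothetical stand-ins for A and B (feature width cut from 4096 to 8) with an arbitrary batch_size of 4.

```python
import random

import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler, TensorDataset


class MyBatchSampler(Sampler):
    # Compact restatement of the sampler from the answer above
    def __init__(self, a_indices, b_indices, batch_size):
        self.a_indices = list(a_indices)  # list() so random.shuffle can work in place
        self.b_indices = list(b_indices)
        self.batch_size = batch_size

    def __iter__(self):
        random.shuffle(self.a_indices)
        random.shuffle(self.b_indices)
        a_batches = torch.split(torch.tensor(self.a_indices), self.batch_size)
        b_batches = torch.split(torch.tensor(self.b_indices), self.batch_size)
        all_batches = [batch.tolist() for batch in a_batches + b_batches]
        random.shuffle(all_batches)
        return iter(all_batches)


# Hypothetical stand-ins: 10 items of shape [256, 8] and 6 items of shape [32, 8]
dataset_a = TensorDataset(torch.randn(10, 256, 8))
dataset_b = TensorDataset(torch.randn(6, 32, 8))
a_len = len(dataset_a)
ab_len = a_len + len(dataset_b)

batch_sampler = MyBatchSampler(range(a_len), range(a_len, ab_len), batch_size=4)
dl = DataLoader(ConcatDataset((dataset_a, dataset_b)), batch_sampler=batch_sampler)

# 1) Every index batch lies entirely on one side of the A/B boundary
for indices in batch_sampler:
    assert all(i < a_len for i in indices) or all(i >= a_len for i in indices)

# 2) Every collated batch therefore has a single row count: 256 (A) or 32 (B)
row_counts = set()
for (batch,) in dl:
    row_counts.add(batch.shape[1])
print(sorted(row_counts))  # [32, 256]
```

Because the check runs over the index batches and over the collated tensors, it would catch both a sampler bug and an off-by-one in the B index offset.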
Thank you. What should `def __len__(self)` be in `MyBatchSampler`? Is it this?

```python
def __len__(self):
    return (len(self.a_indices) + len(self.b_indices)) // self.batch_size
```
naveenkb
(Naveen)
June 4, 2021, 3:06am
Yes, you are right. Did it work?
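One caveat on that `__len__`: because A and B are chunked separately, each can leave its own partial last batch, so `__iter__` actually yields ceil(len(A) / batch_size) + ceil(len(B) / batch_size) batches. Floor division of the combined length undercounts whenever either length is not a multiple of batch_size. The sketch below compares the counts for an arbitrary example (10 items in A, 10 in B, batch_size 4); `num_batches` is a hypothetical helper name.

```python
import math

import torch


def num_batches(a_len, b_len, batch_size):
    # A and B are split independently, so each may end with a short batch
    return math.ceil(a_len / batch_size) + math.ceil(b_len / batch_size)


a_len, b_len, batch_size = 10, 10, 4

# Count the batches the same way MyBatchSampler.__iter__ produces them
a_batches = torch.split(torch.arange(a_len), batch_size)
b_batches = torch.split(torch.arange(a_len, a_len + b_len), batch_size)
actual = len(a_batches) + len(b_batches)

print(actual)                                 # 6
print(num_batches(a_len, b_len, batch_size))  # 6
print((a_len + b_len) // batch_size)          # 5
```

Inside `MyBatchSampler`, `__len__` could therefore return `num_batches(len(self.a_indices), len(self.b_indices), self.batch_size)`, which keeps `len(dl)` consistent with the number of batches the loader actually yields.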