Hi,
I currently have a use case that requires a custom data loader that can: (i) change the batch_size
value dynamically during training, and (ii) process the data samples with a different operation for each batch_size
.
For (i), I successfully accomplish it by creating my custom batch_sampler
that returns (a) batch_sample_indices
with different number of indices (its len
= batch_size
). However, I still cannot solve (ii). In order to achieve (ii), I need to know (a), or at least the len
of (a) in the __getitem__()
method inside the data loader class. Can anybody give me a direction for this? Any other approach to achieve (i) & (ii) is also welcome.
Thanks in advance!
techkang
(kang sheng)
October 30, 2021, 8:27am
2
You can implement both (i) and (ii) like this:
import random
from torch.utils.data import DataLoader, Sampler, RandomSampler, Dataset
class MyBatchSampler(Sampler):
    """Batch sampler that draws a new random batch size for every batch.

    Each yielded batch is a list of ``(index, batch_size)`` tuples rather than
    plain integer indices, so the dataset's ``__getitem__`` can see which
    batch size its sample belongs to (requirement (ii) in the thread).

    Args:
        sampler: An index sampler (e.g. ``RandomSampler``) to draw indices from.
        batch_size_list: Candidate batch sizes; one is chosen uniformly at
            random for each batch.
        drop_last: If True, drop the final incomplete batch.
    """

    def __init__(self, sampler, batch_size_list, drop_last):
        if not batch_size_list:
            raise ValueError("batch_size_list must not be empty")
        self.sampler = sampler
        self.batch_size_list = batch_size_list
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        # Pick the size for the first batch up front; re-pick after each yield.
        batch_size = random.choice(self.batch_size_list)
        for idx in self.sampler:
            batch.append((idx, batch_size))
            if len(batch) == batch_size:
                yield batch
                batch = []
                batch_size = random.choice(self.batch_size_list)
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        # BUG FIX: the original referenced ``self.batch_size``, which is never
        # assigned, so ``len()`` raised AttributeError.  Because the batch size
        # is drawn at random per batch, the true count varies between epochs;
        # report an estimate from the average candidate batch size, using
        # integer arithmetic only (ceil(n / avg) == (n*k + sum - 1) // sum).
        total = len(self.sampler) * len(self.batch_size_list)  # type: ignore[arg-type]
        denom = sum(self.batch_size_list)
        if self.drop_last:
            return total // denom
        return (total + denom - 1) // denom
class MyDataSet(Dataset):
    """Toy dataset of ten items whose keys are ``(index, batch_size)`` pairs.

    The custom batch sampler hands ``__getitem__`` a two-element
    ``(index, batch_size)`` key instead of a bare integer index; the pair is
    unpacked and echoed back so the batch size is visible per sample.
    """

    def __getitem__(self, item):
        idx, bs = item
        return idx, bs

    def __len__(self):
        return 10
# Demo: wire the pieces together; each printed batch shows its own size
# alongside the (shuffled) sample indices.
my_dataset = MyDataSet()
my_sampler = MyBatchSampler(RandomSampler(my_dataset), [1, 2, 3], False)
dataloader = DataLoader(my_dataset, batch_sampler=my_sampler)
for batch in dataloader:
    print(batch)
1 Like
techkang:
import random
from torch.utils.data import DataLoader, Sampler, RandomSampler, Dataset
class MyBatchSampler(Sampler):
    """Batch sampler that draws a new random batch size for every batch.

    Each yielded batch is a list of ``(index, batch_size)`` tuples rather than
    plain integer indices, so the dataset's ``__getitem__`` can see which
    batch size its sample belongs to.

    Args:
        sampler: An index sampler (e.g. ``RandomSampler``) to draw indices from.
        batch_size_list: Candidate batch sizes; one is chosen uniformly at
            random for each batch.
        drop_last: If True, drop the final incomplete batch.
    """

    def __init__(self, sampler, batch_size_list, drop_last):
        if not batch_size_list:
            raise ValueError("batch_size_list must not be empty")
        self.sampler = sampler
        self.batch_size_list = batch_size_list
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        # Pick the size for the first batch up front; re-pick after each yield.
        batch_size = random.choice(self.batch_size_list)
        for idx in self.sampler:
            batch.append((idx, batch_size))
            if len(batch) == batch_size:
                yield batch
                batch = []
                batch_size = random.choice(self.batch_size_list)
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        # BUG FIX: the original referenced ``self.batch_size``, which is never
        # assigned, so ``len()`` raised AttributeError.  Because the batch size
        # is drawn at random per batch, the true count varies between epochs;
        # report an estimate from the average candidate batch size, using
        # integer arithmetic only (ceil(n / avg) == (n*k + sum - 1) // sum).
        total = len(self.sampler) * len(self.batch_size_list)  # type: ignore[arg-type]
        denom = sum(self.batch_size_list)
        if self.drop_last:
            return total // denom
        return (total + denom - 1) // denom
class MyDataSet(Dataset):
    """Toy dataset of ten items whose keys are ``(index, batch_size)`` pairs.

    The custom batch sampler hands ``__getitem__`` a two-element
    ``(index, batch_size)`` key instead of a bare integer index; the pair is
    unpacked and echoed back so the batch size is visible per sample.
    """

    def __getitem__(self, item):
        idx, bs = item
        return idx, bs

    def __len__(self):
        return 10
# Demo: wire the pieces together; each printed batch shows its own size
# alongside the (shuffled) sample indices.
my_dataset = MyDataSet()
my_sampler = MyBatchSampler(RandomSampler(my_dataset), [1, 2, 3], False)
dataloader = DataLoader(my_dataset, batch_sampler=my_sampler)
for batch in dataloader:
    print(batch)
Thanks a lot! Works like a charm.