I need to build a data loader that keeps iterating until a target sample size is accumulated. For example, suppose my loader looks like this:
import torch
from torch.utils.data import DataLoader, Dataset


class CustomLoader(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.all_files = torch.randn(250)

    def __len__(self):
        return len(self.all_files)

    def __getitem__(self, idx):
        rnd = self.all_files[idx]
        x = torch.randn(21, 20, 4) * rnd
        target = idx
        return x, target


if __name__ == "__main__":
    dataset = CustomLoader()
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    for b, (x, target) in enumerate(loader):
        bat, peopl, tim, feat = x.size()  # 8 x 21 x 20 x 4
        # Flatten the batch and people dimensions together
        x = x.view(bat * peopl, tim, feat)  # 168 x 20 x 4
        # Some condition; output size varies
        rand_el = torch.randint(0, 22, size=(1,)).item()
        x = x[:rand_el]
        print(x.size())  # e.g. 13 x 20 x 4
My `x` has shape `8 x 21 x 20 x 4`, where `8` is my batch size. Once I have the data, I combine the batch and `21` dimensions, giving `168 x 20 x 4`. I then apply some condition and filter out data, so my final shape might be, say, `13 x 20 x 4`. I need to keep iterating through subsequent batches, appending the filtered data, until I have accumulated an effective batch size of `32`.
What is the best way to go about this? Applying the filter inside the `Dataset` class itself produces variable-length samples, which the default collate function cannot stack. And as far as I know, a custom `collate_fn` does not let me pull in additional batches. Please advise.
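To make the goal concrete, here is a rough sketch of the accumulation I have in mind: a generator that buffers filtered rows across batches and yields fixed-size chunks. The random slice stands in for my real filter, and the generator name and target of 32 are just for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CustomLoader(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.all_files = torch.randn(250)

    def __len__(self):
        return len(self.all_files)

    def __getitem__(self, idx):
        x = torch.randn(21, 20, 4) * self.all_files[idx]
        return x, idx


def accumulating_batches(loader, target=32):
    """Yield tensors of exactly `target` filtered rows, buffering across batches."""
    buffer, n = [], 0
    for x, _ in loader:
        bat, peopl, tim, feat = x.size()
        # Flatten the batch and people dimensions together
        x = x.view(bat * peopl, tim, feat)
        # Placeholder for the real filtering condition
        keep = torch.randint(0, 22, size=(1,)).item()
        x = x[:keep]
        buffer.append(x)
        n += x.size(0)
        # Emit full chunks; carry any leftover rows into the next chunk
        while n >= target:
            pooled = torch.cat(buffer, dim=0)
            yield pooled[:target]
            buffer = [pooled[target:]]
            n = buffer[0].size(0)


loader = DataLoader(CustomLoader(), batch_size=8, shuffle=True)
for chunk in accumulating_batches(loader, target=32):
    print(chunk.size())  # always 32 x 20 x 4
```

Rows left in the buffer when the loader is exhausted are simply dropped here; they could also be yielded as a final partial chunk.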