I am working with video datasets in PyTorch, but the memory consumption is too high for a normal machine (about 15 GB with 8 workers, growing by roughly 1 GB for each additional worker). I wonder whether I have done something wrong in the dataset or the dataloader. Here is my code:
class SequenceDataset(torch.utils.data.Dataset):
    """
    Dataset that serves one stored video trajectory file per index.

    Inputs:
    root : str, path to the dataset
    dataset : str, subset chosen from 'train', 'val', 'test' or '' (use
        ``root`` itself)

    Every entry found in the subset directory is treated as one trajectory;
    entries are ordered by sorted file name.
    """

    def __init__(self, root, dataset=''):
        super().__init__()
        self.root = root
        self.dataset = dataset
        self.datapath = os.path.join(root, dataset)

        # Build the trajectory list once, in sorted order, so that an index
        # always refers to the same file.
        entries = os.listdir(self.datapath)
        entries.sort()
        self.trajlist = [os.path.join(self.datapath, entry) for entry in entries]

    def __len__(self):
        """Return the number of trajectory files discovered at construction."""
        return len(self.trajlist)

    def __getitem__(self, index):
        """Load and return the full trajectory stored in the ``index``-th file."""
        # Per the author's note, each file holds a uint8 ndarray of shape
        # [500, 64, 64, 3], i.e. roughly 6 MB per loaded sample.
        return load_npz(self.trajlist[index])
I only use a sub-sequence of each trajectory during training, so I wrote a wrapper to extract it:
class Split(torch.utils.data.Dataset):
    """
    Wrap a sequence dataset and return a fixed-length window of each sample.

    Each sample of the wrapped dataset must be a non-empty mapping whose
    values expose ``shape[0]`` and can be sliced along their first axis.
    The same ``[start:end]`` window is taken from every value.

    Inputs:
    dataset : the wrapped dataset (indexable, supports ``len``)
    horizon : int, number of steps in the returned window
    fix_start : bool, if True the window always starts at step 0; otherwise
        the start is drawn uniformly at random for every access

    Attributes not found on the wrapper are looked up on the wrapped dataset.
    """

    def __init__(self, dataset, horizon, fix_start=False):
        super().__init__()
        self._dataset = dataset
        self.horizon = horizon
        self.fix_start = fix_start

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``_dataset``
        # itself is missing (the instance was created without __init__, as
        # pickle/copy do while probing for special methods such as
        # __setstate__), delegating would re-enter this method forever.
        if name == "_dataset":
            raise AttributeError(name)
        return getattr(self._dataset, name)

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, index):
        """
        Return a ``horizon``-step window of the ``index``-th sample.

        Raises ValueError when a random start is requested and ``horizon``
        is longer than the sequence.
        """
        sample = self._dataset[index]
        max_length = list(sample.values())[0].shape[0]

        if self.fix_start:
            start = 0
        else:
            if self.horizon > max_length:
                raise ValueError(
                    f"horizon ({self.horizon}) exceeds sequence length ({max_length})"
                )
            start = random.randint(0, max_length - self.horizon)
        end = start + self.horizon

        # Copy each window: a plain slice is a view that keeps the whole
        # source sequence alive for as long as the window is referenced.
        return {
            key: self._copy_window(value[start:end])
            for key, value in sample.items()
        }

    @staticmethod
    def _copy_window(window):
        """Return ``window`` detached from the buffer it was sliced from."""
        if hasattr(window, "clone"):  # tensor-like objects
            return window.clone()
        if hasattr(window, "copy"):  # ndarray-like objects
            return window.copy()
        return window
And the dataloader is created like this:
# Each sample is a 32-step window with a random start (fix_start=False)
# taken from one full trajectory.
dataset = Split(
    SequenceDataset(path),
    horizon=32,
    fix_start=False,
)

# Batches of 32 windows are produced by 8 worker processes; pin_memory=True
# asks the loader to place batches in page-locked host memory.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    num_workers=8,
    pin_memory=True,
)
This uses about 15 GB of memory. I only have 200 npz files in the dataset, so even if I loaded them all at once they should take only about 1.2 GB. So I am confused about what the rest of the memory is being used for.