I have a custom dataset of 256x256px video frames. To limit my memory usage, I am using the filepath as a pointer. However, I would like to split these frames into 16x16 non-overlapping patches to feed into a Transformer, and as such am trying to compute the frame number as well as the patch number within the frame beforehand from the full frame * patch number.
Code:
class Dataset(torch.utils.data.Dataset):
def __init__(self, directory='../data/*', get_subdirs=True, size=(16,16), max_ctx_length=4096):
self.data = glob.glob(directory)
if get_subdirs:
data_temp = []
for i in self.data:
fata = glob.glob(i+"/*")
fata.sort(key=lambda r: int(''.join(x for x in r if (x.isdigit()))))
data_temp.extend(fata)
self.data = data_temp
self.max_ctx_length = max_ctx_length
self.size = size
def __len__(self):
return len(self.data)*self.size[0]-self.max_ctx_length-1
def __getitem__(self, key):
frame_start = int(np.ceil(key / self.size[0]))
frame_end = int(np.ceil((key+self.max_ctx_length+1) / self.size[0]))
frames = self.data[frame_start:frame_end]
patch_start = int(np.mod(key, self.size[0]))
patch_end = int(np.mod(key+self.max_ctx_length+1, self.size[0]))
print(patch_start, patch_end)
patches = []
for im in range(len(frames)):
image = (Tvio.read_image(frames[im], mode=Tvio.ImageReadMode.RGB).float() / 255).unsqueeze(0)
ppf = F.unfold(image, self.size, stride=self.size).transpose(1,2).split(1, 1)
if im == 0:
if not(im==len(frames)-1):
ppf = ppf[patch_start:]
else:
ppf = ppf[patch_start:patch_end]
elif im == len(frames)-1:
ppf = ppf[:patch_end]
patches.extend(ppf)
print(len(patches))
data_x = patches[:-2]
print(len(data_x))
data_y = patches[1:-1]
return torch.cat(data_x, dim=1).squeeze(0), torch.cat(data_y, dim=1).squeeze(0)
if __name__ == "__main__":
dataset = Dataset()
print(dataset.__getitem__(4)[0].shape)
Expected Output:
A tensor of shape [4096, 768].
Actual Output:
torch.Size([65279, 768])