torch.stack() and torch.cat() taking too long

Hey guys,

The dataset is a collection of around 1500 three-channel (RGB) images with a resolution of 640x480, and I’m stacking them up for standardization.

I’m using torchvision.datasets.ImageFolder to load the dataset, mostly because I don’t have a .csv file that maps the images to their labels.

dataset = ImageFolder(root=path, transform=transform)
stacked_dataset = torch.stack([im for im, _ in dataset], dim=3)

For the transform I’m only using T.ToTensor().

I waited for more than 15 minutes and it still hadn’t finished.
I also tried using torch.cat() with unsqueeze, like so:

 stacked_dataset = torch.cat([img_t.unsqueeze(3) for img_t, _ in dataset], 3)

but that gave the same result.

Reducing the resolution with T.Resize() worked, but I would prefer to keep the original resolution.

I could see in Activity Monitor that it was using almost 12GB of memory (MBP M1Pro).
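For what it's worth, a rough back-of-the-envelope estimate lands close to that number. This is only a sketch under two assumptions: T.ToTensor() yields float32 tensors (4 bytes per value), and the list of per-image tensors and the stacked result are both held in memory at the same time. The helper name `stacked_size_bytes` is made up for illustration.

```python
def stacked_size_bytes(n_images, channels, height, width, bytes_per_value=4):
    """Bytes needed to hold n_images tensors of shape (channels, height, width)."""
    return n_images * channels * height * width * bytes_per_value

# Roughly 1500 RGB images at 640x480 (width x height), float32 values.
one_copy = stacked_size_bytes(1500, 3, 480, 640)

print(f"{one_copy / 1e9:.2f} GB")      # one copy of the data: 5.53 GB
print(f"{2 * one_copy / 1e9:.2f} GB")  # list of tensors + stacked result: 11.06 GB
```

So two copies of the data alone would account for about 11 GB, before counting anything else the process holds.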

I would recommend profiling the code to check which part actually takes most of the time. My guess is that the data loading is slower than the torch.stack or torch.cat operation itself.
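One simple way to check this is to time the two phases separately. The sketch below is not a full profiler, just wall-clock timing with `time.perf_counter`. The helper `profile_load_and_stack` and its `stack_fn` parameter are hypothetical names introduced here, and it assumes the dataset yields (image, label) pairs the way ImageFolder does.

```python
import time

def profile_load_and_stack(dataset, stack_fn):
    """Time sample loading and the stacking call separately.

    dataset:  any iterable of (image, label) pairs, e.g. an ImageFolder.
    stack_fn: callable that takes the list of images and returns the result.
    """
    t0 = time.perf_counter()
    # Iterating the dataset reads each file, decodes it, and applies the transform.
    images = [im for im, _ in dataset]
    t1 = time.perf_counter()
    # Only the stacking / concatenation itself is timed here.
    stacked = stack_fn(images)
    t2 = time.perf_counter()
    return stacked, {"load_s": t1 - t0, "stack_s": t2 - t1}

# Intended use with the snippet from the question (not executed here):
#   stacked, timings = profile_load_and_stack(
#       dataset, lambda ims: torch.stack(ims, dim=3))
#   print(timings)
```

If `load_s` dominates, the time is going into reading and decoding the images rather than into torch.stack or torch.cat.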