On-the-fly random patching of large images

I am working with large remote sensing imagery. These images have 11 bands and a size of 10980 × 10980 pixels. I am trying to patch these images on the fly into (11, 512, 512) tiles to feed them to the data loader. However, my current dataset class runs into a memory error, because every image it loads is split into 441 patches (21 × 21 tiles of 512 × 512 pixels), so even a single batch holds hundreds of large patches. Secondly, I suspect this is not the best approach, as ideally I want to draw a random selection of patches from all the images in the dataset.

import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import glob
import os
import gdal
import numpy as np
class CustomImageDataset(Dataset):
    """Dataset that splits each multi-band raster in a folder into square patches.

    Every item is *all* non-overlapping ``patch_size`` x ``patch_size`` tiles of
    one image, stacked into a tensor of shape
    ``(n_patches, bands, patch_size, patch_size)`` in row-major tile order.
    Pixels along the right and bottom edges that do not fill a whole patch are
    dropped.

    NOTE: for a 10980 x 10980 image and 512-pixel patches this is 441 patches
    per item (several GB as float32 for 11 bands), so batching several items
    can easily exhaust memory.
    """

    def __init__(self, folder_path, patch_size=512):
        """
        Args:
            folder_path (string): path to image folder
            patch_size (int): side length of each square patch, in pixels.
        """
        self.image_dir = folder_path
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is stable across runs and machines.
        self.images = sorted(os.listdir(folder_path))
        self.patch_size = patch_size

    @staticmethod
    def _to_patches(image, patch_size):
        """Split a (C, H, W) numpy array into a (N, C, patch_size, patch_size) tensor."""
        import torch

        # from_numpy keeps GDAL's (C, H, W) layout. torchvision's ToTensor
        # assumes HWC input and would transpose the spatial axes.
        tensor = torch.from_numpy(image)
        tensor = tensor.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
        # (C, nH, nW, p, p) -> (C, nH * nW, p, p) -> (nH * nW, C, p, p)
        tensor = tensor.contiguous().view(tensor.size(0), -1, patch_size, patch_size)
        return tensor.permute(1, 0, 2, 3)

    def __getitem__(self, index):
        single_img_path = os.path.join(self.image_dir, self.images[index])
        raster = gdal.Open(single_img_path)
        if raster is None:  # gdal.Open returns None on failure unless exceptions are enabled
            raise OSError(f"GDAL could not open {single_img_path!r}")
        image = raster.ReadAsArray().astype(np.float32)  # gdal reads in CHW
        return self._to_patches(image, self.patch_size)

    def __len__(self):
        return len(self.images)

 dataset = MyDataset()
 loader = DataLoader(dataset, batch_size=2)
 x = next(iter(loader))

Based on your current code, it seems you are unfolding the large input into smaller patches and are thus returning multiple patches from the __getitem__ method. I'm not sure whether you want to return a single patch only, but if so, you could (randomly) sample the coordinates of a window and just slice the loaded image to create a single [11, 512, 512] patch.

Thank you for your response, @ptrblck. In the end I went with something like this:

import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import glob
import os
import gdal
import numpy as np
class CustomImageDataset(Dataset):
    """Dataset that draws random square patches from multi-band rasters in a folder.

    Each item is ``num_patches`` windows of ``patch_size`` x ``patch_size``
    pixels, taken at uniformly random positions from one image. The result is
    a float32 tensor of shape ``(num_patches, bands, patch_size, patch_size)``.
    Only the sampled windows are read from disk, not the whole raster, so the
    memory used per item scales with the patches rather than the source image.
    """

    def __init__(self, folder_path, num_patches=64, patch_size=256):
        """
        Args:
            folder_path (string): path to image folder
            num_patches (int): number of random patches returned per image.
            patch_size (int): side length of each square patch, in pixels.
        """
        self.image_dir = folder_path
        # os.listdir order is arbitrary; sort for a stable index -> file mapping.
        self.images = sorted(os.listdir(folder_path))
        self.num_patches = num_patches
        self.patch_size = patch_size

    @staticmethod
    def _random_offsets(width, height, num_patches, patch_size):
        """Return ``num_patches`` random ``(row, col)`` top-left patch corners.

        Every window ``[row:row + patch_size, col:col + patch_size]`` lies fully
        inside a ``height`` x ``width`` image. Every valid corner can be drawn,
        including the last one on each axis.

        Raises:
            ValueError: if the image is smaller than one patch on either axis.
        """
        from random import randrange

        if width < patch_size or height < patch_size:
            raise ValueError(
                f"image of size {width}x{height} is smaller than patch_size={patch_size}"
            )
        # +1 because randrange's upper bound is exclusive.
        row_choices = height - patch_size + 1
        col_choices = width - patch_size + 1
        return [(randrange(row_choices), randrange(col_choices)) for _ in range(num_patches)]

    def __getitem__(self, index):
        import torch

        single_img_path = os.path.join(self.image_dir, self.images[index])
        raster = gdal.Open(single_img_path)
        if raster is None:  # gdal.Open returns None on failure unless exceptions are enabled
            raise OSError(f"GDAL could not open {single_img_path!r}")
        size = self.patch_size
        offsets = self._random_offsets(
            raster.RasterXSize, raster.RasterYSize, self.num_patches, size
        )
        # Explicit float32: np.zeros/np.empty default to float64, which doubles
        # memory and yields double tensors (most PyTorch models use float32).
        patches = np.empty((len(offsets), raster.RasterCount, size, size), dtype=np.float32)
        for i, (row, col) in enumerate(offsets):
            # GDAL windows are (xoff=column, yoff=row, xsize, ysize) and the
            # returned array is (bands, rows, cols), matching patches[i].
            patches[i] = raster.ReadAsArray(col, row, size, size)
        return torch.from_numpy(patches)

    def __len__(self):
        return len(self.images)