On-the-fly random patching of large images

I am working with large remote sensing imagery. These images have 11 bands and a size of 10980 × 10980 pixels. I am trying to patch these images ‘on the fly’ into (11, 512, 512) tiles to feed into a DataLoader. However, my current dataset class runs into a memory error: loading a single batch creates 441 patches per image (10980 // 512 = 21 tiles per side, and 21 × 21 = 441). Secondly, I know this is probably not the best approach, since ideally I want to draw a random selection of patches from all the images in the dataset.

```{python}
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import glob
import os
import gdal
import numpy as np
class CustomImageDataset(Dataset):
    """Dataset that tiles each large multi-band raster into square patches.

    Each item corresponds to one file in ``folder_path``. Indexing it returns
    every non-overlapping ``patch_size`` x ``patch_size`` tile of that image as
    a float32 tensor of shape ``(num_tiles, bands, patch_size, patch_size)``,
    with tiles in row-major order. Pixels past the last whole tile are dropped.

    Memory note: the whole raster is read and all tiles are returned at once.
    For an 11 x 10980 x 10980 image that is 21 * 21 = 441 tiles, about 5 GB of
    float32 data per item on top of the ~5.3 GB decoded float32 image, so a
    batch of several items can easily exhaust RAM. Sampling a few random
    windows per item avoids this.
    """

    def __init__(self, folder_path, patch_size=512):
        '''
        Args:
            folder_path (string): path to image folder; every entry in it is
                opened with GDAL
            patch_size (int): side length, in pixels, of each square tile
        '''
        self.image_dir = folder_path
        # Sorted so the index -> file mapping is stable across runs/platforms.
        self.images = sorted(os.listdir(folder_path))
        self.patch_size = patch_size

    @staticmethod
    def _tile(image, patch_size):
        """Split a ``(C, H, W)`` (or single-band ``(H, W)``) array into tiles.

        Returns a tensor of shape ``(n_rows * n_cols, C, patch_size, patch_size)``
        where ``n_rows = H // patch_size`` and ``n_cols = W // patch_size``;
        tile ``k`` covers rows ``(k // n_cols) * patch_size`` onward and columns
        ``(k % n_cols) * patch_size`` onward.

        Raises:
            ValueError: if ``patch_size`` is not positive or the image is
                smaller than one tile in either spatial axis.
        """
        import torch  # local: the surrounding snippet does not import torch

        if image.ndim == 2:
            # Single-band rasters come back 2-D; give them a channel axis.
            image = image[np.newaxis]
        if patch_size < 1:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        channels, height, width = image.shape
        if height < patch_size or width < patch_size:
            raise ValueError(
                f"image of size {height}x{width} is smaller than "
                f"patch_size={patch_size}"
            )

        # from_numpy shares memory and keeps the CHW axis order. ToTensor would
        # treat the array as HWC, transpose it, and make an extra full copy.
        tensor = torch.from_numpy(image)
        # (C, H, W) -> (C, n_rows, n_cols, patch_size, patch_size)
        tiles = tensor.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
        n_rows, n_cols = tiles.shape[1], tiles.shape[2]
        # -> (n_rows, n_cols, C, p, p) -> (n_rows * n_cols, C, p, p); reshape makes
        # the single copy needed to lay the tiles out contiguously.
        return tiles.permute(1, 2, 0, 3, 4).reshape(
            n_rows * n_cols, channels, patch_size, patch_size
        )

    def __getitem__(self, index):
        single_img_path = os.path.join(self.image_dir, self.images[index])
        # GDAL returns a multi-band raster as (bands, rows, cols), i.e. CHW.
        image = gdal.Open(single_img_path).ReadAsArray().astype(np.float32)
        return self._tile(image, self.patch_size)

    def __len__(self):
        return len(self.images)

 dataset = MyDataset()
 loader = DataLoader(dataset, batch_size=2)
 x = next(iter(loader))
 print(x.shape)

```

Based on your current code, it seems you are unfolding the large input into smaller patches and are thus returning multiple patches from the `__getitem__` method. I’m not sure if you want to return a single patch only, but if so, you could (randomly) sample the coordinates of a window and just slice the loaded image to create a single [11, 512, 512] patch.

Thank you for your response, @ptrblck. In the end, I went with something like this:

```{python}
import glob
import os
import random

import gdal
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
class CustomImageDataset(Dataset):
    """Dataset that samples random square patches from large multi-band rasters.

    Each item corresponds to one file in ``folder_path``. Indexing an item
    reads ``num_patches`` randomly placed ``patch_size`` x ``patch_size``
    windows from that file and returns them as a float32 tensor of shape
    ``(num_patches, bands, patch_size, patch_size)``. Only the sampled windows
    are read from disk, so the full raster is never decoded into memory.
    """

    def __init__(self, folder_path, patch_size=256, num_patches=64):
        '''
        Args:
            folder_path (string): path to image folder; every entry in it is
                opened with GDAL
            patch_size (int): side length, in pixels, of each square patch
            num_patches (int): number of patches sampled per ``__getitem__``

        Raises:
            ValueError: if ``patch_size`` is not positive.
        '''
        if patch_size < 1:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        self.image_dir = folder_path
        # Sorted so the index -> file mapping is stable across runs/platforms.
        self.images = sorted(os.listdir(folder_path))
        self.patch_size = patch_size
        self.num_patches = num_patches

    @staticmethod
    def _random_origins(height, width, patch_size, num_patches, rng=None):
        """Return ``num_patches`` random ``(row, col)`` top-left corners.

        Each corner leaves room for a full ``patch_size`` window inside a
        ``height`` x ``width`` image. The last valid offsets
        (``height - patch_size`` and ``width - patch_size``) can be drawn.

        Args:
            rng: object with a ``randrange`` method; defaults to the ``random``
                module, which PyTorch reseeds in each DataLoader worker.

        Raises:
            ValueError: if the image is smaller than one patch in either axis.
        """
        if height < patch_size or width < patch_size:
            raise ValueError(
                f"image of size {height}x{width} is smaller than "
                f"patch_size={patch_size}"
            )
        rng = random if rng is None else rng
        return [
            (rng.randrange(height - patch_size + 1),
             rng.randrange(width - patch_size + 1))
            for _ in range(num_patches)
        ]

    def __getitem__(self, index):
        single_img_path = os.path.join(self.image_dir, self.images[index])
        raster = gdal.Open(single_img_path)
        if raster is None:
            raise OSError(f"GDAL could not open {single_img_path}")
        size = self.patch_size
        origins = self._random_origins(
            raster.RasterYSize, raster.RasterXSize, size, self.num_patches
        )
        # Use float32 rather than NumPy's float64 default. That matches typical
        # model weights and halves memory.
        patches = np.empty(
            (self.num_patches, raster.RasterCount, size, size), dtype=np.float32
        )
        for i, (row, col) in enumerate(origins):
            # GDAL windows are (xoff=column, yoff=row, xsize, ysize). The result
            # is (bands, rows, cols); a single-band file gives (rows, cols),
            # which broadcasts into the (1, size, size) slot.
            patches[i] = raster.ReadAsArray(col, row, size, size)
        return torch.from_numpy(patches)

    def __len__(self):
        return len(self.images)