I am working with large remote sensing imagery. Each image has 11 bands and a size of 10980 × 10980 pixels. I am trying to split these images into (11, 512, 512) patches on the fly so that they can be fed to the data loader. However, my current dataset class runs into a memory error, because loading even a single batch creates 441 patches per image. Secondly, I know this cannot be the best approach, as ideally I want to draw a random selection of patches from all the images in the dataset.
```{python}
import glob
import os

import gdal
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
class CustomImageDataset(Dataset):
    """Serve multi-band rasters from a folder as square, band-first patches.

    One dataset item corresponds to one file in ``folder_path``. Two modes
    are available:

    * ``patches_per_image=None`` (default): the whole raster is read and cut
      into a non-overlapping grid of patches; pixels that do not fill a
      complete patch at the right/bottom edge are dropped.
    * ``patches_per_image=k``: only ``k`` randomly placed windows are read
      from the file, so the full raster is never held in memory.

    Either way ``__getitem__`` returns a float32 tensor of shape
    ``(num_patches, bands, patch_size, patch_size)``.
    """

    def __init__(self, folder_path, patch_size=512, patches_per_image=None):
        """Index the files found in ``folder_path``.

        Args:
            folder_path (string): path to image folder.
            patch_size (int): side length, in pixels, of each square patch.
            patches_per_image (int or None): number of random patches to read
                per image; ``None`` returns the full grid of patches.

        Raises:
            ValueError: if ``patch_size`` or ``patches_per_image`` is not
                positive.
        """
        if patch_size <= 0:
            raise ValueError("patch_size must be a positive integer")
        if patches_per_image is not None and patches_per_image <= 0:
            raise ValueError("patches_per_image must be positive or None")
        self.image_dir = folder_path
        # os.listdir gives an arbitrary order; sorting keeps the
        # index -> file mapping reproducible between runs.
        self.images = sorted(os.listdir(folder_path))
        self.patch_size = patch_size
        self.patches_per_image = patches_per_image

    @staticmethod
    def _as_chw(pixels):
        """Return ``pixels`` as a float32 array with a leading band axis."""
        pixels = pixels.astype(np.float32)
        if pixels.ndim == 2:
            # A 2-D array carries no band axis; add one of length 1.
            pixels = pixels[np.newaxis]
        return pixels

    @staticmethod
    def _to_patches(image, patch_size):
        """Tile a (C, H, W) array or tensor into (N, C, patch, patch).

        Patches are ordered row by row over the grid. The array is wrapped
        directly rather than passed through torchvision's ``ToTensor``:
        that transform assumes an H x W x C ndarray and moves the last axis
        to the front, which leaves a band-first array spatially transposed.
        """
        tensor = torch.as_tensor(image)
        bands = tensor.size(0)
        # (C, H, W) -> (C, rows, cols, patch, patch)
        tiles = tensor.unfold(1, patch_size, patch_size)
        tiles = tiles.unfold(2, patch_size, patch_size)
        # Move the grid axes first, then merge them into one patch axis.
        tiles = tiles.permute(1, 2, 0, 3, 4)
        return tiles.reshape(-1, bands, patch_size, patch_size)

    def _random_patches(self, raster):
        """Read ``patches_per_image`` random windows from an open raster."""
        size = self.patch_size
        max_x = raster.RasterXSize - size
        max_y = raster.RasterYSize - size
        if max_x < 0 or max_y < 0:
            raise ValueError(
                f"raster is smaller than the requested patch size {size}"
            )
        patches = []
        for _ in range(self.patches_per_image):
            # torch's generator is used so the offsets follow torch seeding.
            x_off = torch.randint(0, max_x + 1, (1,)).item()
            y_off = torch.randint(0, max_y + 1, (1,)).item()
            # Windowed read: only size x size pixels per band leave the file.
            window = raster.ReadAsArray(x_off, y_off, size, size)
            patches.append(torch.from_numpy(self._as_chw(window)))
        return torch.stack(patches)

    def __getitem__(self, index):
        """Return the patches of the ``index``-th image.

        Returns:
            Tensor of shape (num_patches, bands, patch_size, patch_size).

        Raises:
            RuntimeError: if GDAL cannot open the file.
        """
        single_img_path = os.path.join(self.image_dir, self.images[index])
        raster = gdal.Open(single_img_path)
        if raster is None:
            raise RuntimeError(f"GDAL could not open {single_img_path!r}")
        if self.patches_per_image is not None:
            return self._random_patches(raster)
        # Multi-band rasters are read band-first (CHW).
        pixels = self._as_chw(raster.ReadAsArray())
        return self._to_patches(pixels, self.patch_size)

    def __len__(self):
        """Return the number of files found in the image folder."""
        return len(self.images)
# Folder holding the rasters; replace with the real location.
IMAGE_DIR = "path/to/images"

# The class defined above is CustomImageDataset, and it requires the folder
# path as its argument.
dataset = CustomImageDataset(IMAGE_DIR)
loader = DataLoader(dataset, batch_size=2)

# Pull one batch and report its shape as a quick sanity check.
x = next(iter(loader))
print(x.shape)