Batch size and cache


I am training a SimSiam model on remote sensing images. My issue is that these images are large (11, 10980, 10980). My current workflow is: load an image in the dataset, take a random crop, augment it, and return two augmented versions. Then load the same image (or another random one) and take a different random crop.

The problem is that loading the entire image takes a long time, and this is repeated for every sample, i.e. I need to load 32 full images for each batch, only to take a small 256x256x11 crop from each.
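For scale, here is the arithmetic on the shapes above (assuming float32 in memory, as in my dataset code; the rasters on disk may well be uint16, but the ratio is the point):

```python
import numpy as np

# Rough cost of the current workflow, using the shapes from above
full_bytes = 11 * 10980 * 10980 * np.dtype(np.float32).itemsize
crop_bytes = 11 * 256 * 256 * np.dtype(np.float32).itemsize

print(f"full image : {full_bytes / 1e9:.1f} GB")   # ~5.3 GB read per sample
print(f"one crop   : {crop_bytes / 1e6:.1f} MB")   # ~2.9 MB actually used
print(f"ratio      : {full_bytes // crop_bytes}x more data read than used")
```

So each sample reads roughly 1800x more data than it keeps.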

My question is: is there a better way to do this, without pre-processing all the images into smaller crops? I was thinking there could be a way to cache a subset of images for each epoch. I will post my dataset class and data loader code below. Currently I am only testing this code with two images, but for the final training there will be hundreds.
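To make the caching idea concrete, here is a minimal sketch of what I have in mind, using `functools.lru_cache` around the image loader. The loader body and shape here are placeholders so the sketch is self-contained; in the real dataset it would be the `gdal.Open(path).ReadAsArray()` call:

```python
from functools import lru_cache

import numpy as np

load_count = 0  # counts actual "disk reads", just to show the cache working

@lru_cache(maxsize=4)  # tune maxsize to how many full images fit in RAM
def load_image(path):
    # In the real dataset this would be the expensive read:
    #   gdal.Open(path).ReadAsArray().astype(np.float32)
    # Small placeholder array here so the sketch runs on its own.
    global load_count
    load_count += 1
    return np.zeros((11, 512, 512), dtype=np.float32)

# Repeated requests for the same image hit the in-memory cache,
# so only two actual loads happen for these four lookups
for p in ["a.tif", "a.tif", "b.tif", "a.tif"]:
    crop_source = load_image(p)
```

One thing I am unsure about: with `num_workers > 0` each DataLoader worker is a separate process with its own cache, so the memory cost would be roughly `maxsize` images per worker.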

import os
from os.path import join
from random import randrange
from typing import List

import numpy as np
from osgeo import gdal
from torch.utils.data import Dataset, DataLoader


class CustomImageDataset(Dataset):
    def __init__(self, folder_path, valid_exts: List[str] = ['tif', 'tiff']):
        # Build the file list; each image is listed 16 times so that one
        # epoch draws 16 random crops from it
        self.files = []
        for file in os.listdir(folder_path):
            ext = file.split('.')[-1]
            if ext in valid_exts:
                path = join(folder_path, file)
                for i in range(16):
                    self.files.append(path)

        self.transforms = create_simsiam_transforms(size=256)

    def __getitem__(self, i: int):
        single_img_path = self.files[i]
        image = gdal.Open(single_img_path).ReadAsArray().astype(np.float32)  # gdal reads CHW
        height = image.shape[1]
        width = image.shape[2]
        # Pick a random 256x256 window, making sure it stays inside the image
        ystart = randrange(height - 256)
        xstart = randrange(width - 256)
        array = image[:, ystart:ystart + 256, xstart:xstart + 256]
        array = np.transpose(array, (1, 2, 0)).astype(np.float32)  # CHW -> HWC for the transforms

        # Two independently augmented views of the same crop
        x1 = self.transforms(image=array)
        x2 = self.transforms(image=array)
        return x1, x2

    def __len__(self):
        return len(self.files)

def create_simsiam_dataloader(folder_path,
                              batch_size: int = 32,
                              num_workers: int = 8):
    """Returns a DataLoader built on CustomImageDataset."""
    dataset = CustomImageDataset(folder_path)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=True, num_workers=num_workers)
    return dataloader