Hi,
I have remote sensing images that I am training a SimSiam model on. My issue is that these images are large (11, 10980, 10980). My current workflow is to load an image in the dataset, take a random crop, augment it, and return two augmented versions; for the next sample, the same image (or another random one) is loaded again and a different random crop is taken.
The problem is that loading the entire image takes a long time, and this is repeated for every sample: with a batch size of 32 I end up loading 32 full images per batch, only to grab a small 256x256x11 crop from each one.
My question is: is there a better way to do this, without pre-processing all the images into smaller crops? I was thinking there could be a way to cache a subset of images for each epoch (a rough sketch of what I mean is after the code). I will post my Dataset and DataLoader code below; currently I am only testing this with two images, but for the final training there will be hundreds.
```python
import os
from os.path import join
from random import randrange
from typing import List

import numpy as np
from osgeo import gdal
from torch.utils.data import DataLoader, Dataset


class CustomImageDataset(Dataset):
    def __init__(self, folder_path, valid_exts: List[str] = ['tif', 'tiff']):
        # Build the file list; each image is repeated 16 times so that
        # 16 random crops are drawn from it per epoch.
        self.files = []
        for file in os.listdir(folder_path):
            ext = file.split('.')[-1]
            if ext in valid_exts:
                file = join(folder_path, file)
                for i in range(16):
                    self.files.append(file)
        self.transforms = create_simsiam_transforms(size=256)

    def __getitem__(self, i: int):
        single_img_path = self.files[i]
        # gdal reads the image as (channels, height, width)
        image = gdal.Open(single_img_path).ReadAsArray().astype(np.float32)
        height = image.shape[1]
        width = image.shape[2]
        # Pick a random top-left corner so the crop stays inside the image
        xstart = randrange(width - 256)
        ystart = randrange(height - 256)
        array = image[:, ystart:ystart + 256, xstart:xstart + 256]
        # Transpose to (height, width, channels) for the transforms
        array = np.transpose(array, (1, 2, 0)).astype(np.float32)
        x1 = self.transforms(image=array)
        x2 = self.transforms(image=array)
        return x1, x2

    def __len__(self):
        return len(self.files)
def create_simsiam_dataloader(folder_path,
                              batch_size: int = 32,
                              num_workers: int = 8):
    """
    Returns a DataLoader over CustomImageDataset.

    Args:
        folder_path: directory containing the .tif/.tiff images
        batch_size: samples per batch
        num_workers: number of DataLoader worker processes
    """
    dataset = CustomImageDataset(folder_path)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=True, num_workers=num_workers)
    return dataloader
```
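
This is roughly the kind of per-epoch caching I had in mind; a minimal, untested sketch, where `cache_size`, `crops_per_image` and `resample_epoch` are just placeholder names I made up, and `create_simsiam_transforms` is the same helper as above:

```python
from random import randrange, sample

import numpy as np
from osgeo import gdal
from torch.utils.data import Dataset


class CachedCropDataset(Dataset):
    """Sketch: keep `cache_size` full images in RAM and crop only from those,
    re-sampling which images are cached at the start of every epoch."""

    def __init__(self, files, cache_size: int = 8, crops_per_image: int = 16):
        self.files = files
        self.cache_size = cache_size
        self.crops_per_image = crops_per_image
        self.resample_epoch()
        self.transforms = create_simsiam_transforms(size=256)  # defined elsewhere in my code

    def resample_epoch(self):
        # Load a random subset of images once; every __getitem__ during the
        # epoch only crops from arrays that are already in memory.
        chosen = sample(self.files, min(self.cache_size, len(self.files)))
        self.cache = [gdal.Open(f).ReadAsArray().astype(np.float32) for f in chosen]

    def __len__(self):
        return len(self.cache) * self.crops_per_image

    def __getitem__(self, i: int):
        image = self.cache[i % len(self.cache)]
        height, width = image.shape[1], image.shape[2]
        ystart, xstart = randrange(height - 256), randrange(width - 256)
        crop = image[:, ystart:ystart + 256, xstart:xstart + 256]
        crop = np.transpose(crop, (1, 2, 0)).astype(np.float32)
        return self.transforms(image=crop), self.transforms(image=crop)
```

The idea is to call `dataset.resample_epoch()` in the training loop before each epoch. I am not sure how well this plays with `num_workers > 0`, since each worker process would hold its own copy of the cache.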
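
Separately, would it make sense to read only the crop window instead of the whole array? gdal's `ReadAsArray` accepts `xoff`/`yoff`/`xsize`/`ysize`, so I assume something like this would only pull the 256x256 block from disk (again a rough, untested sketch):

```python
from random import randrange

import numpy as np
from osgeo import gdal


def read_random_window(path: str, size: int = 256) -> np.ndarray:
    """Read a random size x size window (all bands) without loading the full image."""
    ds = gdal.Open(path)
    xoff = randrange(ds.RasterXSize - size)
    yoff = randrange(ds.RasterYSize - size)
    # ReadAsArray(xoff, yoff, xsize, ysize) returns a (bands, size, size) array
    return ds.ReadAsArray(xoff, yoff, size, size).astype(np.float32)
```

I do not know whether repeated windowed reads like this would actually be faster than what I have now, so that is part of my question.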