[SOLVED] Dataloader: runs out of RAM for a small dataset

I have a dataset with 100 images which occupy around 120 MB and their masks occupy around 4.2 MB. I have written a dataloader. When I try to plot some examples using the following code, the process is killed because of running out of memory.

import torchvision
import matplotlib.pyplot as plt
from augmentations import * 


path = '/path_to_data_set'
batch_size = 1
augs = Compose([RandomRotate(10),
						RandomHorizontallyFlip()])
root = SegDataset(root=path, split='train', is_transform=True, augmentations = augs)
trainloader = data.DataLoader(root, batch_size=batch_size, num_workers=0)
for i, data in enumerate(trainloader):
	print(i)
	imgs, masks = data
	imgs = imgs.numpy()[:, ::-1, :, :]
	imgs = np.transpose(imgs, (0,2,3,1))
	f, axarr = plt.subplots(batch_size,2)
	for j in range(batch_size):
		axarr[j][0].imshow(imgs[j])
		axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
	plt.show()
	a = input()
	if a =='ex':
		break
	else:
		plt.close()

I have tried num_workers=0 and other numbers , still same behaviour.

Can you show your dataset code and a traceback?
Does the error occur directly if you execute the script or only after a certain time?

The error occurs after a certain amount of time, the RAM slowly increases. I get process killed by signal error.

This is the dataset code

class SegDataset(data.Dataset):
	def __init__(self, root, split="train", is_transform = False, augmentations=None):
		self.root=root
		self.split = split
		self.augmentations = augmentations
		self.is_transform = is_transform
		self.n_classes = 3
		self.files = collections.defaultdict(list)
		self.images_base = os.path.join(self.root, self.split)
		self.masks_base = os.path.join(self.root, 'masks' + '/' + self.split)

		for split in ['train', 'val', 'test']:
			file_list = recursive_glob(rootdir=self.root + '/' + split , suffix='.jpg')
			self.files[split] = file_list

	def __len__(self):
		return len(self.files[self.split])

	def __getitem__(self, index):
		img_path = self.files[self.split][index].rstrip()
		img_name = img_path.split(os.sep)[-1].split('.')[0]
		mask_name = img_name + 	'_mask.png'
		mask_path = self.masks_base + '/' + mask_name
		
		img = m.imread(img_path)
		img = np.array(img, dtype=np.uint8)

		mask = m.imread(mask_path)
		mask = np.array(mask, dtype=np.uint8)

		img, window, scale, padding, crop = self.resize_image(img, min_dim=image_size[0], max_dim=image_size[1])
		mask = self.resize_mask(mask, padding=padding, crop=crop)

		if self.augmentations is not None:
			img, mask = self.augmentations(img, mask)

		if self.is_transform:
			img, mask = self.transform(img, mask)

		return img, mask

I have removed resize_image and resize_mask lines and ran the code and still I’m getting Killed as output after the RAM slowly filled completely. Any help would be much appreciated.

I have tested the code removing all the augmentations and transforms. I have tried different image reading methods from scipy, skimage, PIL, cv2. Still the dataloader is hogging RAM and causing it to crash due to running out of memory. Can some one assist me.

I think the problem is inside __getitem__ method. Remaining parts are working as they should.

This is the code after removing augmentations and transformations.

class SegDataset(data.Dataset):
	def __init__(self, root, split="train"):
		self.root=root
		self.split = split
		self.n_classes = 3
		self.files = collections.defaultdict(list)
		self.images_base = os.path.join(self.root, self.split)
		self.masks_base = os.path.join(self.root, 'masks' + '/' + self.split)
		for split in ['train', 'val', 'test']:
			file_list = recursive_glob(rootdir=self.root + '/' + split , suffix='.jpg')
			self.files[split] = file_list

	def __len__(self):
		return len(self.files[self.split])

	def __getitem__(self, index):
		img_path = self.files[self.split][index].rstrip()
		img_name = img_path.split(os.sep)[-1].split('.')[0]
		mask_name = img_name + 	'_mask.png'
		mask_path = self.masks_base + '/' + mask_name
		
		img = io.imread(img_path)
		mask = io.imread(mask_path)
		return img, mask

When I try to sample one datapoint from the dataset, it still runs out of RAM.

code used for sampling.

path = '\path\to_dataset\
dst = FootDataset(root=path, split='train')
data = dst[22]

The process dies after this code snippet runs for some time. Never seen this kind of behaviour in dtaloaders.

Can you print the length of self.files (and it’s content) at the end of the __init__?
Edit: or better the length of every split?

Yes, I can print the outputs. The outputs are what I expected.

Could you change the __getitem__ to just return img_path and mask_paths?
Then you could observe your RAM usage and load the images outside the Dataset:

img_path, mask_path = dataset[0]
img = io.imread(img_path)
mask = io.imread(mask_path)

It is working for now. But if I include any image transformations in __getitem__ the RAM is running out.

It looks like you are using an own implementation for random rotation and horizontal flip.
Could you post the code for these functions?

Also, let’s see if one of these functions are buggy, by applying them one by one:

image = RandomRotate(10)(image)
# check ram
image = RandomHorizontallyFlip()(image)
# check ram

The torchvision package provides these functions as well.
You can find them here and here.

I have removed them, this is up-to-date ram-hogging code.

class SegDataset(data.Dataset):
	def __init__(self, root, split="train", is_transform = False):
		self.root=root
		self.split = split
		self.augmentations = augmentations
		self.is_transform = is_transform
		self.n_classes = 3
		self.files = collections.defaultdict(list)
		self.images_base = os.path.join(self.root, self.split)
		self.masks_base = os.path.join(self.root, 'masks' + '/' + self.split)

		for split in ['train', 'val', 'test']:
			file_list = recursive_glob(rootdir=self.root + '/' + split , suffix='.jpg')
			self.files[split] = file_list

	def __len__(self):
		return len(self.files[self.split])

	def __getitem__(self, index):
		img_path = self.files[self.split][index].rstrip()
		img_name = img_path.split(os.sep)[-1].split('.')[0]
		mask_name = img_name + 	'_mask.png'
		mask_path = self.masks_base + '/' + mask_name
		
		img = io.imread(img_path)
		mask = io.imread(mask_path)
		if not (len(img.shape) == 3 and len(mask.shape) == 2):
			return self.__getitem__(np.random.randint(0, self.__len__()))

		if self.is_transform:
			img, mask = self.transform(img, mask)

		return img, mask
    
	def transform(self, img, mask):
		img = img[:, :, ::-1]
		img = img.astype(np.float64)
		img = img.transpose(2, 0, 1)
		classes = np.unique(mask)
		mask = mask.astype(int)
		assert(np.all(classes == np.unique(mask)))
                labels = encode_mask(mask)
		img = torch.from_numpy(img).float()
		mask = torch.from_numpy(mask).long()
		

		return img, mask

	def get_colors(self):
		return np.asarray([[0, 0, 0],
						[125, 0, 0],
						[0, 125, 0]])

	def encode_mask(self, mask):
		mask = mask.astype(int)
		labels = np.zeros(mask.shape[0], mask.shape[1], dtype=np.int16)
		for i, label in enumerate(self.get_colors()):
			labels[np.where(np.all(mask == label, axis=-1))[:2]] = 1
		labels = labels.astype(int)
		return labels

The transform method returns only the mask not the encoded labels and I set the is_transform to False. When I run it the process gets killed after some time. If I remove the transform then the code works as expected.

Thanks for the info. I’ll try to debug your code.
In the meanwhile could you remove the recursion in __getitem__:

return self.__getitem__(np.random.randint(0, self.__len__()))

and just return random data instead:

return torch.randn(YOUR_SIZE), torch.randn(YOUR_SIZE)