I have a dataset with 100 images which occupy around 120 MB and their masks occupy around 4.2 MB. I have written a dataloader. When I try to plot some examples using the following code, the process is killed because of running out of memory.
import torchvision
import matplotlib.pyplot as plt
from augmentations import *
path = '/path_to_data_set'
batch_size = 1
augs = Compose([RandomRotate(10),
RandomHorizontallyFlip()])
root = SegDataset(root=path, split='train', is_transform=True, augmentations = augs)
trainloader = data.DataLoader(root, batch_size=batch_size, num_workers=0)
for i, data in enumerate(trainloader):
print(i)
imgs, masks = data
imgs = imgs.numpy()[:, ::-1, :, :]
imgs = np.transpose(imgs, (0,2,3,1))
f, axarr = plt.subplots(batch_size,2)
for j in range(batch_size):
axarr[j][0].imshow(imgs[j])
axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
plt.show()
a = input()
if a =='ex':
break
else:
plt.close()
I have tried num_workers=0
and other numbers , still same behaviour.
Can you show your dataset code and a traceback?
Does the error occur directly if you execute the script or only after a certain time?
The error occurs after a certain amount of time, the RAM slowly increases. I get process killed by signal
error.
This is the dataset code
class SegDataset(data.Dataset):
def __init__(self, root, split="train", is_transform = False, augmentations=None):
self.root=root
self.split = split
self.augmentations = augmentations
self.is_transform = is_transform
self.n_classes = 3
self.files = collections.defaultdict(list)
self.images_base = os.path.join(self.root, self.split)
self.masks_base = os.path.join(self.root, 'masks' + '/' + self.split)
for split in ['train', 'val', 'test']:
file_list = recursive_glob(rootdir=self.root + '/' + split , suffix='.jpg')
self.files[split] = file_list
def __len__(self):
return len(self.files[self.split])
def __getitem__(self, index):
img_path = self.files[self.split][index].rstrip()
img_name = img_path.split(os.sep)[-1].split('.')[0]
mask_name = img_name + '_mask.png'
mask_path = self.masks_base + '/' + mask_name
img = m.imread(img_path)
img = np.array(img, dtype=np.uint8)
mask = m.imread(mask_path)
mask = np.array(mask, dtype=np.uint8)
img, window, scale, padding, crop = self.resize_image(img, min_dim=image_size[0], max_dim=image_size[1])
mask = self.resize_mask(mask, padding=padding, crop=crop)
if self.augmentations is not None:
img, mask = self.augmentations(img, mask)
if self.is_transform:
img, mask = self.transform(img, mask)
return img, mask
I have removed resize_image
and resize_mask
lines and ran the code and still I’m getting Killed
as output after the RAM slowly filled completely. Any help would be much appreciated.
I have tested the code removing all the augmentations and transforms. I have tried different image reading methods from scipy
, skimage
, PIL
, cv2
. Still the dataloader is hogging RAM and causing it to crash due to running out of memory. Can some one assist me.
I think the problem is inside __getitem__
method. Remaining parts are working as they should.
This is the code after removing augmentations
and transformations
.
class SegDataset(data.Dataset):
def __init__(self, root, split="train"):
self.root=root
self.split = split
self.n_classes = 3
self.files = collections.defaultdict(list)
self.images_base = os.path.join(self.root, self.split)
self.masks_base = os.path.join(self.root, 'masks' + '/' + self.split)
for split in ['train', 'val', 'test']:
file_list = recursive_glob(rootdir=self.root + '/' + split , suffix='.jpg')
self.files[split] = file_list
def __len__(self):
return len(self.files[self.split])
def __getitem__(self, index):
img_path = self.files[self.split][index].rstrip()
img_name = img_path.split(os.sep)[-1].split('.')[0]
mask_name = img_name + '_mask.png'
mask_path = self.masks_base + '/' + mask_name
img = io.imread(img_path)
mask = io.imread(mask_path)
return img, mask
When I try to sample one datapoint from the dataset, it still runs out of RAM.
code used for sampling.
path = '\path\to_dataset\
dst = FootDataset(root=path, split='train')
data = dst[22]
The process dies after this code snippet runs for some time. Never seen this kind of behaviour in dtaloaders.
Can you print the length of self.files
(and it’s content) at the end of the __init__
?
Edit: or better the length of every split?
Yes, I can print the outputs. The outputs are what I expected.
Could you change the __getitem__
to just return img_path
and mask_paths
?
Then you could observe your RAM usage and load the images outside the Dataset
:
img_path, mask_path = dataset[0]
img = io.imread(img_path)
mask = io.imread(mask_path)
It is working for now. But if I include any image transformations in __getitem__
the RAM is running out.
It looks like you are using an own implementation for random rotation and horizontal flip.
Could you post the code for these functions?
Also, let’s see if one of these functions are buggy, by applying them one by one:
image = RandomRotate(10)(image)
# check ram
image = RandomHorizontallyFlip()(image)
# check ram
The torchvision
package provides these functions as well.
You can find them here and here.
I have removed them, this is up-to-date ram-hogging code.
class SegDataset(data.Dataset):
def __init__(self, root, split="train", is_transform = False):
self.root=root
self.split = split
self.augmentations = augmentations
self.is_transform = is_transform
self.n_classes = 3
self.files = collections.defaultdict(list)
self.images_base = os.path.join(self.root, self.split)
self.masks_base = os.path.join(self.root, 'masks' + '/' + self.split)
for split in ['train', 'val', 'test']:
file_list = recursive_glob(rootdir=self.root + '/' + split , suffix='.jpg')
self.files[split] = file_list
def __len__(self):
return len(self.files[self.split])
def __getitem__(self, index):
img_path = self.files[self.split][index].rstrip()
img_name = img_path.split(os.sep)[-1].split('.')[0]
mask_name = img_name + '_mask.png'
mask_path = self.masks_base + '/' + mask_name
img = io.imread(img_path)
mask = io.imread(mask_path)
if not (len(img.shape) == 3 and len(mask.shape) == 2):
return self.__getitem__(np.random.randint(0, self.__len__()))
if self.is_transform:
img, mask = self.transform(img, mask)
return img, mask
def transform(self, img, mask):
img = img[:, :, ::-1]
img = img.astype(np.float64)
img = img.transpose(2, 0, 1)
classes = np.unique(mask)
mask = mask.astype(int)
assert(np.all(classes == np.unique(mask)))
labels = encode_mask(mask)
img = torch.from_numpy(img).float()
mask = torch.from_numpy(mask).long()
return img, mask
def get_colors(self):
return np.asarray([[0, 0, 0],
[125, 0, 0],
[0, 125, 0]])
def encode_mask(self, mask):
mask = mask.astype(int)
labels = np.zeros(mask.shape[0], mask.shape[1], dtype=np.int16)
for i, label in enumerate(self.get_colors()):
labels[np.where(np.all(mask == label, axis=-1))[:2]] = 1
labels = labels.astype(int)
return labels
The transform
method returns only the mask not the encoded labels and I set the is_transform
to False
. When I run it the process gets killed after some time. If I remove the transform then the code works as expected.
Thanks for the info. I’ll try to debug your code.
In the meanwhile could you remove the recursion in __getitem__
:
return self.__getitem__(np.random.randint(0, self.__len__()))
and just return random data instead:
return torch.randn(YOUR_SIZE), torch.randn(YOUR_SIZE)