Hi there, I am new to PyTorch. I want to apply my own function to transform pictures, but doing so slows the process down a lot. I think the problem is that, for each image, it calls a class that takes a while to load (but I am not sure). As I said, I am new, so if you think this is the wrong approach, just tell me which solution is better, even if it is far from this one.
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
import numpy as np
from PIL import Image
from mtcnn.mtcnn import MTCNN # recognise key points
# Build the MTCNN key-point detector once at import time (construction is
# slow) so that every call to occlude() reuses this single instance.
detector = MTCNN()
# Define Custom Transformation:
def _blackout(pixels, row_start, row_stop, col_start, col_stop):
    """Zero the rectangle rows [row_start, row_stop) x cols [col_start, col_stop) in place."""
    # Clamp at 0: NumPy reads a negative bound as an offset from the END of the
    # axis, which would black out the wrong region (or nothing) near the border.
    # Bounds past the far edge need no clamping; slicing truncates them.
    r0, r1 = max(int(row_start), 0), max(int(row_stop), 0)
    c0, c1 = max(int(col_start), 0), max(int(col_stop), 0)
    pixels[r0:r1, c0:c1] = 0


def occlude(input, p=0.5, margin=5):
    """Randomly black out one facial key point (mouth, eyes or nose) per face.

    Every face reported by the module-level ``detector`` is occluded
    independently with probability ``p``; when a face is occluded, the key
    point to hide is picked uniformly among mouth, eyes and nose.

    Args:
        input: the picture, either a PIL image or a NumPy array. It is copied,
            so the caller's object is never modified.
            NOTE(review): presumably RGB uint8, which is what the detector and
            ``Image.fromarray`` are normally fed - confirm for other formats.
        p (float): probability that a detected face gets occluded. ``p <= 0``
            never occludes (and skips detection entirely); ``p >= 1`` always
            occludes.
        margin (int): half-size, in pixels, of the square drawn around an eye
            or the nose, and padding added around the mouth corners.

    Returns:
        PIL.Image.Image: the (possibly occluded) picture.
    """
    # np.array() copies and also converts a PIL image into the array form that
    # slicing/assignment below requires.
    pixels = np.array(input)

    # NOTE(review): detection runs on every call and is probably the dominant
    # per-image cost; caching key points per file offline would avoid it.
    if p > 0:
        for face in detector.detect_faces(pixels):
            # Exact Bernoulli(p) draw instead of padding a choice list with
            # int(3 / p) - 3 Nones, which only approximated p and divided by
            # zero for p == 0.
            if np.random.random() >= p:
                continue

            keypoints = face['keypoints']
            target = np.random.choice(('mouth', 'eyes', 'nose'))

            # Key points are (x, y) pairs: index 0 is the column, 1 the row.
            if target == 'mouth':
                left, right = keypoints['mouth_left'], keypoints['mouth_right']
                rows = (left[1], right[1])
                cols = (left[0], right[0])
                # min/max so a tilted or mirrored mouth still yields a
                # non-empty rectangle.
                _blackout(pixels,
                          min(rows) - margin, max(rows) + margin,
                          min(cols) - margin, max(cols) + margin)
            elif target == 'eyes':
                for eye in ('left_eye', 'right_eye'):
                    col, row = keypoints[eye][0], keypoints[eye][1]
                    _blackout(pixels,
                              row - margin, row + margin,
                              col - margin, col + margin)
            else:  # 'nose'
                col, row = keypoints['nose'][0], keypoints['nose'][1]
                _blackout(pixels,
                          row - margin, row + margin,
                          col - margin, col + margin)

    return Image.fromarray(pixels)
# Augmentation pipeline applied, in order, to every training sample.
train_transform = transforms.Compose([
    # Custom step: occlude a facial key point with probability 0.3.
    transforms.Lambda(lambda x: occlude(x, p=0.3)),
    # 30% of the time, shift the image by up to 10% along each axis.
    transforms.RandomApply(
        torch.nn.ModuleList([
            transforms.RandomAffine(degrees=0, translate=(0.10, 0.10)),
        ]),
        p=0.30,
    ),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Mean 0 / std 1 leaves the values unchanged; kept as a placeholder for
    # real dataset statistics.
    transforms.Normalize((0, 0, 0), (1, 1, 1)),
    transforms.Resize((197, 197)),
])

# DataLoader wraps a Dataset; it accepts neither a directory path nor a
# `transformations` keyword. ImageFolder is the dataset that reads the class
# sub-folders under the root and applies the transform to each loaded image.
train = DataLoader(ImageFolder('./FER/train', transform=train_transform))