Custom transforms don't work?

I’m trying to create a transform that pads a PIL image to be square. (I wish this was one of the included transforms. Anybody know why it isn’t?)

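import numpy as np
import torchvision.transforms.functional as F  # assuming F refers to this module

def square_pad(image):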
    w, h = image.size
    if w==h:
        return image
    max_wh = np.max([w, h])
    hp = int((max_wh - w) // 2)
    vp = int((max_wh - h) // 2)
    hp2 = max_wh-w-hp
    vp2 = max_wh-h-vp
    padding = (hp, vp, hp2, vp2)
    return F.pad(image, padding, 255, 'constant')

How can I add this to a list of transforms for transforms.Compose() and put it into a DataLoader with multiple workers?

I always get the following error:

AttributeError: Can't pickle local object 'main_worker.<locals>.square_pad'

I’ve tried adding it to the list directly:
train_transforms = [...,square_pad,...]

I’ve tried using Lambda:
train_transforms = [...,transforms.Lambda(square_pad),...]

I’ve tried making it a class with a __call__ method.

In every case I get the same “Can’t pickle local object” error.

BTW, it does work when I use num_workers=0, but that is not an option for my application.

Any suggestions? It seems like there really is no way to use a custom transform with multiple workers, and there is also no way to do this with the built-in transforms.

Your code works fine for me:

import numpy as np
import torch
from torchvision import transforms

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, transform=None):
        self.data = torch.randn(10, 3, 220, 224)
        self.transform = transform
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index]
        if self.transform:
            x = self.transform(x)
        return x

def square_pad(image):
    w, h = image.size
    max_wh = np.max([w, h])
    hp = int((max_wh - w) // 2)
    vp = int((max_wh - h) // 2)
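    # Put any odd leftover pixel on the right/bottom so the result is exactly square.
    padding = (hp, vp, max_wh - w - hp, max_wh - h - vp)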
    return transforms.functional.pad(image, padding, 0, 'constant')

transform = transforms.Compose([
    transforms.ToPILImage(),
    # Pass the module-level function directly; a bare lambda cannot be pickled.
    transforms.Lambda(square_pad),
    transforms.ToTensor()
])

dataset = MyDataset(transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=2)

for data in loader:
    print(data.shape)

@ptrblck Could this be because you have defined the Dataset in the same file as the lambda?

I don’t know what might be causing your error. Is my code snippet working for you?

Your code works for me. It even works if I replace MyDataset with a built-in dataset such as EuroSAT.

Ok, for everyone having problems, I’ve discovered the cause of my issue.

The custom transform function must be defined at global (module) scope. If the function or class is defined inside your main function (main_worker in my case) or inside any other function, it cannot be pickled, so it cannot be handed to the DataLoader worker processes. I simply moved it out of my main function to global scope, and the exception is gone.
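
For anyone who wants to see the difference without a DataLoader, here is a minimal sketch that uses only the standard pickle module. The names square_pad_stub, make_local_transform, and can_pickle are made up for illustration, and the stub just returns its input instead of padding anything. It assumes the file is run directly as a script, so the top-level function can be looked up by name.

```python
import pickle


def square_pad_stub(image):
    """Module-level stand-in for a custom transform (hypothetical name)."""
    return image


def make_local_transform():
    """Return a function defined inside another function, as in the failing case."""
    def local_pad(image):
        return image
    return local_pad


def can_pickle(obj):
    """Return True if pickle.dumps accepts obj, False if it refuses."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        # Local functions are rejected with "Can't pickle local object ...".
        return False


# Functions are pickled by reference (module name plus qualified name),
# so only the one reachable from the top level of the module can be pickled.
module_level_ok = can_pickle(square_pad_stub)
local_ok = can_pickle(make_local_transform())

print("module-level function picklable:", module_level_ok)
print("local function picklable:", local_ok)
```

The same reasoning should apply to a class with a __call__ method: define the class at the top of the file and only create the instance inside your main function.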

Thanks for your help @ptrblck. Your simplified example really helped me figure it out!