Transforms for images and masks

Dear community,

I'm relatively new to machine learning in general and PyTorch in particular. I just wanted to briefly discuss something I encountered while implementing a custom Dataset class as the basis for my project, which includes simple classification (ResNet-34), object detection (Faster R-CNN) and instance segmentation (Mask R-CNN). I do research in the medical domain and work with 10k images from the HAM10000 dataset.

For data augmentation I apply some random transformations at the time images and masks are loaded. These include random rotations, random flips, RandomCrop and ColorJitter. All of them are implemented in torchvision.transforms, but they are designed for a single input image. Every random geometric change to the image, e.g. a rotation, must also be applied in the same way to the masks. So I decided to stick with the existing structure of torchvision.transforms and implement classes that inherit from those transforms but a) take both the image and the masks and b) first obtain the random parameters and then apply the same transformation to both the image and the mask.
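
The "sample the parameters once, then apply them to both" idea can be sketched without any library. This is only an illustration under assumptions: the class name PairedRandomFlipRotate and the helpers hflip and rot90 are hypothetical, and nested lists stand in for images. With torchvision, the same pattern would call functions from torchvision.transforms.functional (e.g. hflip, rotate) in place of the helpers.

```python
import random


def hflip(grid):
    """Mirror a 2D nested list left-to-right (stand-in for an image op)."""
    return [row[::-1] for row in grid]


def rot90(grid, k):
    """Rotate a 2D nested list by k * 90 degrees counter-clockwise."""
    for _ in range(k % 4):
        # Transpose, then reverse the row order: one counter-clockwise turn.
        grid = [list(row) for row in zip(*grid)][::-1]
    return grid


class PairedRandomFlipRotate:
    """Hypothetical paired transform: sample once, apply to image and mask."""

    def __init__(self, p_flip=0.5, rng=None):
        self.p_flip = p_flip
        self.rng = rng or random.Random()

    def get_params(self):
        # Draw the random parameters exactly once per (image, mask) pair.
        do_flip = self.rng.random() < self.p_flip
        k = self.rng.randrange(4)
        return do_flip, k

    def __call__(self, img, mask):
        do_flip, k = self.get_params()
        if do_flip:
            img, mask = hflip(img), hflip(mask)
        # The same k is used for both, so image and mask stay aligned.
        return rot90(img, k), rot90(mask, k)
```

Because both inputs go through one call that shares a single draw, there is no way for the image and the mask to end up with different rotations or flips.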

What puzzles me is that this problem can't be unique to me, so I don't understand why it isn't implemented already. Also, some projects I've come across perform data augmentation before the training process and only load already augmented datasets.

So, if there are better ways to do this, let me know.
Thank you all :wink:

For reference, I did something like this for all of the transformations I use:

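from torchvision.transforms import Compose, ColorJitter

# The paired versions deliberately reuse (shadow) the torchvision class names.
class Compose(Compose):

    def __call__(self, img, mask):
        # Every transform in the list must accept and return (img, mask).
        for t in self.transforms:
            img, mask = t(img, mask)
        return img, mask

class ColorJitter(ColorJitter):

    def __call__(self, img, mask):
        # Colour changes only apply to the image; the mask passes through.
        # get_params returned a transform in older torchvision releases but
        # returns a parameter tuple in newer ones, so delegate to the parent.
        return super().__call__(img), mask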

I did something along those lines: https://github.com/JuanFMontesinos/flerken/blob/master/flerken/dataloaders/transforms/transforms.py
It's still a beta repo, though.
The core idea is:

import torch

class Compose(object):
    """Composes several transforms together.
    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms, dim=None):
        self.transforms = transforms
        self.dim = dim

    def __call__(self, inpt):
        if isinstance(inpt, (list, tuple)):
            return self.apply_sequence(inpt)
        else:
            return self.apply_img(inpt)

    def apply_img(self, img):
        for t in self.transforms:
            img = t(img)
        return img

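    def apply_sequence(self, seq):
        # Each transform keeps its sampled parameters across these calls,
        # so every element of seq receives the same transformation.
        output = list(map(self.apply_img, seq))
        if self.dim is not None:
            assert isinstance(self.dim, int)
            output = torch.stack(output, dim=self.dim)
        # Reset the randomness so the next sequence gets fresh parameters.
        for t in self.transforms:
            t.reset_params()
        return output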

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

This way, if you pass a list, the same transformation is applied to all of its elements. As you can see in apply_sequence, each transformation must have a reset_params method that resets its parameters, i.e. its randomness.

So if you are rotating, every image in the list is rotated the same way; once the whole list has been processed, the coin is flipped again, and so on.
Most of the transforms are already implemented, but there are bugs, since I copy-pasted them and didn't test them (lack of time).
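
To make the reset_params contract concrete, here is a minimal sketch of a transform that caches its random draw until it is reset. It is not taken from the linked repo: CachedRandomShift and apply_to_sequence are hypothetical names, and plain lists stand in for images.

```python
import random


class CachedRandomShift:
    """Hypothetical transform that keeps its random draw until reset."""

    def __init__(self, max_shift, rng=None):
        self.max_shift = max_shift
        self.rng = rng or random.Random()
        self.shift = None  # no parameters sampled yet

    def __call__(self, row):
        if self.shift is None:
            # First call since the last reset: flip the coin.
            self.shift = self.rng.randint(-self.max_shift, self.max_shift)
        s = self.shift % len(row)
        # Cyclically shift the row to the right by the cached amount.
        return row[-s:] + row[:-s] if s else list(row)

    def reset_params(self):
        # Forget the cached draw so the next call samples again.
        self.shift = None


def apply_to_sequence(transforms, seq):
    """Apply every transform to every element, then reset the randomness."""
    out = []
    for item in seq:
        for t in transforms:
            item = t(item)
        out.append(item)
    for t in transforms:
        t.reset_params()
    return out
```

Passing [image, mask] as the sequence means both elements see the same cached shift, and the reset at the end ensures the next pair gets a fresh draw.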


This looks like a very generic and elegant solution. Thank you for sharing :slight_smile:

Here is a blog post on how to use the Albumentations library to apply the same set of transforms to a set of images, masks, keypoints, and bounding boxes. https://towardsdatascience.com/multi-target-in-albumentations-16a777e9006e


This didn't work for me: the rotation applied to the image was different from the one applied to the mask.
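
One way to check whether a paired transform really keeps image and mask in sync is to pass the same data as both inputs and compare the outputs. The sketch below is an assumption-laden illustration: check_in_sync, independent_flip and shared_flip are hypothetical names, and nested lists stand in for images. With tensors, the comparison would need something like torch.equal instead of !=.

```python
import random

RNG = random.Random(0)  # fixed seed so the check is repeatable


def check_in_sync(paired_transform, sample, trials=50):
    """Return True if image and mask outputs match on every trial.

    The same object is passed as both image and mask, so any difference
    in the outputs means the two were transformed with different draws.
    """
    for _ in range(trials):
        img_out, mask_out = paired_transform(sample, sample)
        if img_out != mask_out:
            return False
    return True


def flip_rows(grid):
    """Mirror a 2D nested list left-to-right."""
    return [row[::-1] for row in grid]


def independent_flip(img, mask):
    """Deliberately broken: draws a separate coin for image and mask."""
    if RNG.random() < 0.5:
        img = flip_rows(img)
    if RNG.random() < 0.5:
        mask = flip_rows(mask)
    return img, mask


def shared_flip(img, mask):
    """Correct: one coin decides the flip for both."""
    if RNG.random() < 0.5:
        img, mask = flip_rows(img), flip_rows(mask)
    return img, mask
```

Running check_in_sync on independent_flip should report a mismatch within a few trials, while shared_flip should always pass. The sample must not be symmetric under the transform, otherwise a mismatch goes unnoticed.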