How to perform data augmentation for segmentation in PyTorch?

Here is what I do for data augmentation in semantic segmentation.

First, I define a composed transform such as:

```python
transf_aug = tf.Compose([tf.RandomHorizontalFlip(), tf.RandomResizedCrop((height, width), scale=(0.7, 1.0))])
```

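Then, during training, I apply the transformation to each image and its mask. Because every call to transf_aug draws a new random transformation, transforming the image and the mask with two separate calls would generally flip and crop them differently. To make sure both get the same transformation, I use the following trick (based on this comment): save the state of PyTorch's random number generator before transforming the image, and restore it before transforming the mask.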

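```python
state = torch.get_rng_state()  # save the RNG state before augmenting the image
img = transf_aug(img)
torch.set_rng_state(state)     # rewind so the mask receives the same random draws
mask = transf_aug(mask)
```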

Hope it is useful,
Pablo.
