Custom Data Augmentation function does not align augmentation on image and mask

I’m not sure that the data augmentation I’m using for my semantic segmentation task is working properly. I define my own PyTorch Dataset, PositiveOnly, of images and their masks, but when I plot an instance of the dataset, the image and mask do not align. I have 3-channel images and 1-channel masks.

Here is my transform that applies simple flips and rotations to a segmentation sample, along with my custom ToTensor transform and the overall wrapper for my dataset:

class Randomize(object):
    def __call__(self, sample):

        imgdata = sample['img']
        fptdata = sample['fpt']

        # mirror horizontally (width axis: 2 for the CHW image, 1 for the HW mask)
        mirror = np.random.randint(0, 2)
        if mirror:
            imgdata = np.flip(imgdata, 2)
            fptdata = np.flip(fptdata, 1)

        # flip vertically (height axis: 1 for the CHW image, 0 for the HW mask)
        flip = np.random.randint(0, 2)
        if flip:
            imgdata = np.flip(imgdata, 1)
            fptdata = np.flip(fptdata, 0)

        # rotate by [0,1,2,3]*90 deg in the H/W plane
        rot = np.random.randint(0, 4)
        imgdata = np.rot90(imgdata, rot, axes=(1,2))
        fptdata = np.rot90(fptdata, rot, axes=(0,1))

        return {'img': imgdata.copy(),
                'fpt': fptdata.copy()}

class ToTensor(object):
    def __call__(self, sample):
        out = {'img': torch.from_numpy(sample['img'].copy()),
               'fpt': torch.from_numpy(sample['fpt'].copy())}
        return out
    
def create_dataset(*args, apply_transforms=True, **kwargs):

    if apply_transforms:
        data_transforms = transforms.Compose([
            Randomize(),
            ToTensor(),
        ])
    else:
        data_transforms = None

    data = PositiveOnly(*args, **kwargs, transform=data_transforms)
    return data

I’ve stepped through the Randomize transform line by line and it does appear to work, but once I call it through my dataset (before the DataLoader) the image and mask don’t align.
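For example, this is roughly how I tested it in isolation, with the random draws pinned to fixed choices just for the test (the toy arrays here are made up, shaped like my real CHW data), and the image and mask stayed aligned:

import numpy as np

# CHW image with a bright block, and a matching HW mask
img = np.zeros((3, 8, 8), dtype=np.uint8)
fpt = np.zeros((8, 8), dtype=np.uint8)
img[:, 1:4, 5:7] = 255
fpt[1:4, 5:7] = 1

# same operations as Randomize, with mirror=1 and rot=1 pinned
img2 = np.flip(img, 2)                 # horizontal mirror: W axis of the CHW image
fpt2 = np.flip(fpt, 1)                 # horizontal mirror: W axis of the HW mask
img2 = np.rot90(img2, 1, axes=(1, 2))  # rotate in the H/W plane
fpt2 = np.rot90(fpt2, 1, axes=(0, 1))

# the mask should still cover exactly the bright pixels
assert ((img2[0] > 0) == (fpt2 > 0)).all()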

My dataset is defined as:

class PositiveOnly(Dataset):
    def __init__(self, p_pkl, transform=None):
        self.transform = transform

        # load the list of {'img', 'fpt'} samples from the pickle
        with open(p_pkl, "rb") as fp:
            self.p_pkl = pkl.load(fp)

    def __len__(self):
        return len(self.p_pkl)

    def __getitem__(self, idx):
        sample = self.p_pkl[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

Here is the code I run to see that they don’t align:

p = create_dataset('data.pkl',
                   apply_transforms=True)

pic = np.moveaxis(np.array(p[0]['img']), 0, 2)  # tensor is (3, 288, 288); move to (288, 288, 3) for imshow
mak = p[0]['fpt']

plt.imshow(pic)
plt.show()
plt.imshow(mak)
plt.show()

Is there something wrong with my Randomize transform, or with the way I am plotting them? I plan to use this in a deep learning model and need the image and mask to align so the results aren’t corrupted.

Are you saying that calling your transform manually works, but using it as part of a DataLoader does not? If you could provide a short snippet containing only the pieces necessary to reproduce the problem, that would be helpful (e.g., craft a synthetic image of a square plus its mask, apply the transform to both, and plot both to show the discrepancy).

Actually, the opposite. When I manually call the transform, the image and mask do not align, so I am unsure whether it will behave properly in the DataLoader. Here is a self-contained example to show my point:

import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms

# synthetic RGB image (height x width x 3)
width, height = 30, 30
image = np.zeros((height, width, 3), dtype=np.uint8)
x, y = 15, 5
rect_width, rect_height = 8, 15

image[y:y+rect_height, x:x+rect_width, 0] = 50   # Red channel
image[y:y+rect_height, x:x+rect_width, 1] = 20   # Green channel
image[y:y+rect_height, x:x+rect_width, 2] = 150  # Blue channel

# synthetic single-channel (height x width) mask over the same rectangle
mask = np.zeros((height, width), dtype=np.uint8)
mask[y:y+rect_height, x:x+rect_width] = 1

class Randomize(object):
    def __call__(self, sample):

        imgdata = sample['img']
        fptdata = sample['fpt']

        # mirror horizontally
        mirror = np.random.randint(0, 2)
        if mirror:
            imgdata = np.flip(imgdata, 2)
            fptdata = np.flip(fptdata, 1)
            
        # flip vertically
        flip = np.random.randint(0, 2)
        if flip:
            imgdata = np.flip(imgdata, 1)
            fptdata = np.flip(fptdata, 0)
            
        # rotate by [0,1,2,3]*90 deg
        rot = np.random.randint(0, 4)
        imgdata = np.rot90(imgdata, rot, axes=(0,2))
        fptdata = np.rot90(fptdata, rot, axes=(0,1))

        return {'img': imgdata.copy(),
                'fpt': fptdata.copy()}

sample = {'img':image, 'fpt':mask}
f, a = plt.subplots(1,2)
a[0].imshow(sample['img'])
a[1].imshow(sample['fpt'], cmap='gray')

transform = transforms.Compose([
            Randomize()])

sample = transform(sample)

f, a = plt.subplots(1,2)
a[0].imshow(sample['img'])
a[1].imshow(sample['fpt'], cmap='gray')

After calling the Randomize transform, the mask and image do not align. When I break Randomize apart and run its body line by line it works, but as a transform object it doesn’t. Thank you for your response! I hope this helps.

I’ve made a few changes to your repro snippet that appear to fix the issue:

import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms

# synthetic RGB image (height x width x 3)
width, height = 30, 30
image = np.zeros((height, width, 3), dtype=np.uint8)
x, y = 15, 5
rect_width, rect_height = 8, 15

image[y:y+rect_height, x:x+rect_width, 0] = 50   # Red channel
image[y:y+rect_height, x:x+rect_width, 1] = 20   # Green channel
image[y:y+rect_height, x:x+rect_width, 2] = 150  # Blue channel

# synthetic single-channel (height x width) mask over the same rectangle
mask = np.zeros((height, width), dtype=np.uint8)
mask[y:y+rect_height, x:x+rect_width] = 1

class Randomize(object):
    def __call__(self, sample):

        imgdata = sample['img']
        fptdata = sample['fpt']

        # mirror horizontally (width is axis 1 for both the HWC image and the HW mask)
        mirror = np.random.randint(0, 2)
        if mirror:
            imgdata = np.flip(imgdata, 1)
            fptdata = np.flip(fptdata, 1)

        # flip vertically (height is axis 0 for both)
        flip = np.random.randint(0, 2)
        if flip:
            imgdata = np.flip(imgdata, 0)
            fptdata = np.flip(fptdata, 0)

        # rotate by [0,1,2,3]*90 deg in the H/W plane
        rot = np.random.randint(0, 4)
        imgdata = np.rot90(imgdata, rot, axes=(0,1))
        fptdata = np.rot90(fptdata, rot, axes=(0,1))

        return {'img': imgdata.copy(),
                'fpt': fptdata.copy()}

sample = {'img':image, 'fpt':mask}
f, a = plt.subplots(1,2)
a[0].imshow(sample['img'])
a[1].imshow(sample['fpt'], cmap='gray')

transform = transforms.Compose([
            Randomize()])

sample = transform(sample)

f, a = plt.subplots(1,2)
a[0].imshow(sample['img'])
a[1].imshow(sample['fpt'], cmap='gray')

It looks like the flip and rot90 calls were operating on the wrong axes: in this snippet the data is in a “channels-last” (HWC) format, while the transform indexed the axes as if it were channels-first. A hint that the wrong axis was being transformed is that the colors of the image were sometimes shifted, which should never happen with purely spatial transformations. I would double-check your assumptions about where the image is in HWC vs. CHW format in your original code.
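If it helps, one way to make that assumption explicit is to fail loudly when the layout is not what the transform expects. A minimal sketch (assuming, as in the fixed snippet, an HWC image and an HW mask; check_hwc is just a name I made up):

def check_hwc(imgdata, fptdata):
    # expect an (H, W, C) image and an (H, W) mask with matching spatial dims
    assert imgdata.ndim == 3 and fptdata.ndim == 2, \
        f"unexpected ndim: image {imgdata.ndim}, mask {fptdata.ndim}"
    assert imgdata.shape[:2] == fptdata.shape, \
        f"spatial dims differ: image {imgdata.shape[:2]} vs mask {fptdata.shape}"

Calling this at the top of __call__ would turn a silent misalignment into an immediate error.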


Thank you for your help! I just wanted to share a snippet for my RGB images, which come out of the dataset as CHW tensors and need a little rearranging before they can be displayed:

p = create_dataset('data.pkl',
                   apply_transforms=True)

h = p[188]  # an arbitrary sample
im = np.moveaxis(np.array(h['img']), 0, 2)  # tensor shape is (3, 288, 288); move to (288, 288, 3)
im.shape

f,a = plt.subplots(1,2)
a[0].imshow(im)
a[1].imshow(h['fpt'])

I do have a follow-up question though. Is it correct that every time I index my dataset (p in this instance), it re-randomizes the augmentation, and that is why code like:

f,a = plt.subplots(1,2)
a[0].imshow(np.moveaxis(np.array(p[0]['img']),0,2))
a[1].imshow(p[0]['fpt'])

would produce a mismatched image and mask despite using the same index?

Yes, if your __getitem__ method applies random transforms, then each access will produce a different augmentation unless you manually reset the seeds, e.g. via torch.manual_seed and numpy.random.seed (and any other libraries whose random state you depend on). In your snippet, p[0]['img'] and p[0]['fpt'] are two separate accesses, so they come from two different random draws, which is why they don’t match.
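For example, a quick sketch using your p dataset from above (assuming NumPy and torch are the only sources of randomness in the pipeline):

import numpy as np
import torch

np.random.seed(0)
torch.manual_seed(0)
a = p[0]  # first access draws one set of random flips/rotations

np.random.seed(0)
torch.manual_seed(0)
b = p[0]  # resetting the seeds reproduces the same draws

assert torch.equal(a['img'], b['img'])
assert torch.equal(a['fpt'], b['fpt'])

Note that for training you generally do not want to do this; the per-access randomness is exactly what makes the augmentation useful.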

Thank you for that explanation. I have run into another problem: my revised augmentation makes the axis order of my tensors inconsistent, swapping between torch.Size([288, 3, 288]) and torch.Size([3, 288, 288]) depending on which flip or rotation is applied. Do you know how to resolve this? The error shows up as soon as I iterate the DataLoader in my training function.

Are you somehow calling permute during the augmentation in a data-dependent/random way? I would try isolating the problem by removing transformations until it goes away and checking why a given transformation is permuting the axes.

It seems unlikely, but I would also sanity-check that the axes of the inputs you load are always ordered the same way.
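One concrete way to isolate it is to interleave a small debugging transform between your real ones and watch where the shape changes. A sketch (ShapeCheck is just a hypothetical helper, not a torchvision class; it reuses your Randomize and ToTensor from above):

from torchvision import transforms

class ShapeCheck(object):
    """Debugging transform: print the sample's shapes as it flows through Compose."""
    def __init__(self, tag):
        self.tag = tag

    def __call__(self, sample):
        print(self.tag, tuple(sample['img'].shape), tuple(sample['fpt'].shape))
        return sample

data_transforms = transforms.Compose([
    ShapeCheck('input'),
    Randomize(),
    ShapeCheck('after Randomize'),
    ToTensor(),
    ShapeCheck('after ToTensor'),
])

The first stage whose printed shapes differ from what you expect points at the transform (or the loaded sample) responsible for the permuted axes.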