Does `torchvision.transforms.ColorJitter` apply the same transform to each batch element?

I want to apply `torchvision.transforms.ColorJitter` to a video,
but I need to make sure the same random transform is applied to every frame.

I have a function like:

# vid_t of shape [batch_size, num_channels, num_frames, height, width]
def rgb_vid_color_jitter(vid_t, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2):
    """Apply one torchvision ColorJitter call to a batch of videos.

    The video is temporarily rearranged to [batch, frames, channels, H, W] so
    that channels sit directly before the spatial dims, jittered, and then
    rearranged back to the input layout.

    NOTE(review): ColorJitter is invoked once on the whole tensor, so it likely
    samples a single set of factors shared by every frame *and* every video in
    the batch. Confirm against the torchvision version in use.

    Args:
        vid_t: video tensor of shape [batch_size, num_channels, num_frames,
            height, width], with values in range [0, 1] (not [0, 255.0]).
        brightness, contrast, saturation, hue: forwarded unchanged to
            ``torchvision.transforms.ColorJitter``.

    Returns:
        The jittered video in the same [batch, channels, frames, H, W] layout.
    """
    jitter = torchvision.transforms.ColorJitter(
        brightness=brightness,
        contrast=contrast,
        saturation=saturation,
        hue=hue,
    )
    # Swapping dims 1 and 2 is the same as permute (0, 2, 1, 3, 4).
    frames_first = vid_t.transpose(1, 2)
    return jitter(frames_first).transpose(1, 2)

Any insights appreciated.

The transformations should be applied to each image separately, and you can check this quickly by applying your code to a static input tensor (e.g. just `torch.ones`). If you need to apply the same "random" transformation to multiple inputs, you can use the functional API and create the random parameters once.

1 Like

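I ended up with code like the following, which samples one set of jitter parameters per video and applies it to all of that video's frames: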

#vid of shape [batch_size, num_channels, num_frames, height, width]
def random_color_jitter(vid, brightness, contrast, saturation, hue):
    batch_size = vid.shape[0]
    for i in range(batch_size): # a different transform per video
        if brightness > 0:
            brightness_factor = random.uniform(
                max(0, 1 - brightness), 1 + brightness)
        else:
            brightness_factor = None
        if contrast > 0:
            contrast_factor = random.uniform(
                max(0, 1 - contrast), 1 + contrast)
        else:
            contrast_factor = None
        if saturation > 0:
            saturation_factor = random.uniform(
                max(0, 1 - saturation), 1 + saturation)
        else:
            saturation_factor = None
        if hue > 0:
            hue_factor = random.uniform(-hue, hue)
        else:
            hue_factor = None
        vid_transforms = []
        if brightness is not None:
            vid_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness_factor))
        if saturation is not None:
            vid_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation_factor))
        if hue is not None:
            vid_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue_factor))
        if contrast is not None:
            vid_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast_factor))
        random.shuffle(vid_transforms)
        v = vid[i]
        v = torch.permute(v, (1, 0, 2, 3))
        for transform in vid_transforms:
            v = transform(v)
        v = torch.permute(v, (1, 0, 2, 3))
        vid[i] = v
    return vid