I’m thinking of applying the transform torchvision.transforms.ColorJitter to a video,
but I need to make sure the same transform is applied to each frame.
I have a function like:
import torch
import torchvision

# vid_t of shape [batch_size, num_channels, num_frames, height, width]
def rgb_vid_color_jitter(vid_t, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2):
    # note this transform expects the video/image to be in range [0, 1], not [0, 255]
    transform = torchvision.transforms.ColorJitter(
        brightness=brightness, contrast=contrast,
        saturation=saturation, hue=hue)
    # move frames in front of channels: [B, T, C, H, W]
    vid_t = torch.permute(vid_t, (0, 2, 1, 3, 4))
    vid_t = transform(vid_t)
    # restore the original layout: [B, C, T, H, W]
    vid_t = torch.permute(vid_t, (0, 2, 1, 3, 4))
    return vid_t
The transformations should be applied to each image separately. You can check this quickly by running your code on a static input tensor (e.g. torch.ones). If you need to apply the same "random" transformation to multiple inputs, you can use the functional API and create the random parameters once.