PyTorch color jitter

From the documentation: “brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]”

By default, brightness is set to 0. This means the brightness factor is chosen uniformly from [1, 1], i.e. brightness_factor = 1. The other parameters (contrast, saturation, hue) also appear to be constant under the default arguments. Does this mean that if ColorJitter is applied to the same image twice, the output will be the same?

If not, is there a way to perform the same color jitter twice on a pair of images? Thanks!

You could use the staticmethod get_params to apply the same “random” transformation via:

import numpy as np
import torch
from torchvision import transforms

img = transforms.ToPILImage()(torch.randn(3, 224, 224))

color_jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
transform = transforms.ColorJitter.get_params(
    color_jitter.brightness, color_jitter.contrast, color_jitter.saturation,
    color_jitter.hue)

img_trans1 = transform(img)
img_trans2 = transform(img)

print((np.array(img_trans1) == np.array(img_trans2)).all())
> True

or alternatively use the functional API directly via:

import torchvision.transforms.functional as TF

# brightness_factor, contrast_factor, saturation_factor, and hue_factor are
# values you sample (or choose) once and can then reuse on every image
img = TF.adjust_brightness(img, brightness_factor)
img = TF.adjust_contrast(img, contrast_factor)
img = TF.adjust_saturation(img, saturation_factor)
img = TF.adjust_hue(img, hue_factor)
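
For example, here is a rough sketch (the 0.1 ranges and the random placeholder images are just for illustration) of sampling the factors once and then applying the identical jitter to a pair of images:

import random

import torch
import torchvision.transforms.functional as TF
from torchvision import transforms

# two placeholder images of different sizes, just for illustration
img1 = transforms.ToPILImage()(torch.randn(3, 224, 224))
img2 = transforms.ToPILImage()(torch.randn(3, 256, 256))

# sample the factors once, mirroring ColorJitter's documented ranges
brightness = contrast = saturation = hue = 0.1
brightness_factor = random.uniform(max(0., 1. - brightness), 1. + brightness)
contrast_factor = random.uniform(max(0., 1. - contrast), 1. + contrast)
saturation_factor = random.uniform(max(0., 1. - saturation), 1. + saturation)
hue_factor = random.uniform(-hue, hue)

def jitter(img):
    # apply the same sampled factors to any image
    img = TF.adjust_brightness(img, brightness_factor)
    img = TF.adjust_contrast(img, contrast_factor)
    img = TF.adjust_saturation(img, saturation_factor)
    img = TF.adjust_hue(img, hue_factor)
    return img

out1, out2 = jitter(img1), jitter(img2)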

Thanks for your response!

I tried brightness=1, contrast=1, saturation=1, hue=0 with both methods you suggested. Judging by the comments in the adjust_brightness function in the source (and the functions below it), this should theoretically return the original image, but neither method does, although both methods do return the same image as each other.

I am using the following to display the output image:

plt.imshow(img)
fig = plt.gcf()
fig.set_size_inches(14, 10)
plt.show()

Could there be an issue here? Or did I misunderstand something?

The passed arguments are used to determine the new factors by sampling from the posted uniform distribution over the range [max(0, 1 - arg), 1 + arg], where arg is brightness, contrast, or saturation.
For hue, [-hue, +hue] will be used.

If you want to get the original image without any transformation, you should set all arguments to 0.0.
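
To make the ranges concrete, here is a small illustration (not the library's internal code) of how the factors would be drawn:

import random

# brightness=1 -> the factor is drawn from [max(0, 1 - 1), 1 + 1] = [0, 2],
# so a factor of exactly 1 (no change) is very unlikely
brightness = 1.0
brightness_factor = random.uniform(max(0., 1. - brightness), 1. + brightness)

# brightness=0 -> the range collapses to [1, 1], i.e. no change
# hue=0 -> the range [-0, 0] also collapses, i.e. no change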

When I set all arguments to 0.0, I still do not get the same output.

I thought the problem was that I was not converting the output images to tensors (transforms.ToTensor()) and normalizing them (transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])) to get them into exactly the same format as the original displayed image, but even when I do this, I still do not get the original output.

Any ideas?

That’s weird, as I get the exact same outputs:

img = transforms.ToPILImage()(torch.randn(3, 224, 224))

color_jitter = transforms.ColorJitter(brightness=0.0, contrast=0., saturation=0., hue=0.)
transform = transforms.ColorJitter.get_params(
    color_jitter.brightness, color_jitter.contrast, color_jitter.saturation,
    color_jitter.hue)

img_trans1 = transform(img)
img_trans2 = transform(img)

print((np.array(img) == np.array(img_trans1)).all())
> True
print((np.array(img) == np.array(img_trans2)).all())
> True
print((np.array(img_trans1) == np.array(img_trans2)).all())
> True

Could you check your code for other transformations, which could change the output?

Interesting.

I’m not sure what I’m doing wrong:

I have a data loader, which does the following transformation:

def imageNetTransformPIL(size=224):
    return transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

I get a torch tensor of images (called k; k.shape = (batch_size, 3, 224, 224)) from my data loader and display it using the following code:

plt.imshow(k[0].permute(1, 2, 0))
fig = plt.gcf()
fig.set_size_inches(14, 10)
plt.show()

This displays the image as expected.

I then use your code and display the images at various points:

img = transforms.ToPILImage()(k[0])

color_jitter = transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0)
transform = transforms.ColorJitter.get_params(
    color_jitter.brightness, color_jitter.contrast, color_jitter.saturation,
    color_jitter.hue)
img = transform(img)

plt.imshow(img)
fig = plt.gcf()
fig.set_size_inches(14, 10)
plt.show()

transform_tensor = transforms.ToTensor()
img = transform_tensor(img)

transform_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
img = transform_normalize(img)

plt.imshow(img.permute(1, 2, 0))
fig = plt.gcf()
fig.set_size_inches(14, 10)
plt.show()

Surprisingly, even the first image displayed is different (same structure but wildly different colors), which suggests that the difference lies in how I am displaying these images. This led me to try converting the output into a torch tensor and displaying the image again, but this also gives a very different image (same structure, very different colors).

Do you have any ideas what I might be doing wrong? Thanks a bunch for your help thus far.

What’s odd is that even if I comment out the color jittering, I don’t get the same output. It seems the problem is not with your provided solution but rather with the way I am displaying these images.

What is the right way to convert the PIL Image back into a torch tensor for both viewing purposes as well as to use in the forward() function?

ToPILImage() will not reconstruct the original image if you’ve normalized it, as seen here:

img = transforms.ToPILImage()(torch.randn(3, 224, 224))

norm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # comment this line
])


x_norm = norm(img)
img_reconstructed = transforms.ToPILImage()(x_norm)
print((np.array(img_reconstructed) == np.array(img)).all())
> False

If you comment out the mentioned line it’ll work.

Alternatively, you could also “unnormalize” the image, but this would most likely yield small rounding errors, so that some pixel in the uint8 format might be off by 1:

x_norm = norm(img)
x = x_norm * torch.tensor([0.229, 0.224, 0.225])[:, None, None] + torch.tensor([0.485, 0.456, 0.406])[:, None, None]

img_reconstructed = transforms.ToPILImage()(x)
print((np.array(img_reconstructed).astype(np.int64) - np.array(img).astype(np.int64)))  # per-pixel differences; expect 0 or ±1

Unfortunately, I have to use the original transformation I referenced.

Is the norm() function in your code torch.norm()? I just added:

x_norm = norm(img)
x = x_norm * torch.tensor([0.229, 0.224, 0.225])[:, None, None] + torch.tensor([0.485, 0.456, 0.406])[:, None, None]

in my code before ToPILImage(), yet this yields an image that is entirely blue. Sorry for all this back and forth. I’m still unsure why I’m not getting the expected output…

norm is defined as the ToTensor and Normalize transformation in my code.
Could you post the current code you are using, which creates these blue images, so that I can take another look, please?

torch.norm() was the issue. Thank you so much for your help! Very much appreciated. :)

Hi @ptrblck,

Your code snippet worked fine for me until I upgraded to PyTorch 1.8.1.
Now I get a type error: TypeError: 'tuple' object is not callable

Any suggestions?

Could you post an executable code snippet to reproduce this issue in the latest PyTorch and torchvision release, please?

import torch
import PIL
from torchvision import transforms

img = transforms.ToPILImage()(torch.randn(3, 224, 224))

color_jitter = transforms.ColorJitter(brightness=0.0, contrast=0., saturation=0., hue=0.)
transform = transforms.ColorJitter.get_params(
    color_jitter.brightness, color_jitter.contrast, color_jitter.saturation,
    color_jitter.hue)

img_trans1 = transform(img)
img_trans2 = transform(img)

Using PyTorch 1.8.1 and torchvision 0.9.1
build: py3.8_cuda11.1_cudnn8.0.5_0

Thanks for the code. The get_params method now returns a tuple with the applied parameters for each transformation as well as their order, and thus cannot be called as a function.

From the docs:

Returns:
The parameters used to apply the randomized transform along with their random order.

If you want to apply the transformation to the image, call it via out = color_jitter(img).

Hi ptrblck,

I understand that color_jitter(img) will apply the transformation to the img.
However, I want to apply the exact same transformation to multiple images (they have different sizes, so I can’t batch them together). This used to work with the code snippet you provided, but since torch 1.9.1 transforms.ColorJitter.get_params seems to return the order and values of the transformations instead of the lambda functions it returned in previous torch versions.
I guess a workaround would be to take the values from transforms.ColorJitter.get_params and use them in functions like torchvision.transforms.functional.adjust_brightness, but this seems like a roundabout way of doing things.

Thank you for your help so far.

The behavior was changed in this PR and is BC-breaking.
I think returning the parameters instead of transformations sounds reasonable given the get_params name and the behavior of other transformations (which also return the parameters, not transformations).
I’m not sure why it was implemented the previous way.

Yes, you could use the functional API to apply the parameters, or you could reuse the old implementation.
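
A rough sketch of that workaround, assuming get_params now returns the applied order followed by the four factors, i.e. (fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor); check the docs of your torchvision version for the exact return format:

import torch
import torchvision.transforms.functional as TF
from torchvision import transforms

img1 = transforms.ToPILImage()(torch.randn(3, 224, 224))
img2 = transforms.ToPILImage()(torch.randn(3, 256, 256))

jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
fn_idx, b, c, s, h = transforms.ColorJitter.get_params(
    jitter.brightness, jitter.contrast, jitter.saturation, jitter.hue)

def apply_params(img):
    # replay the sampled factors in the sampled order on any image
    for fn_id in fn_idx:
        if fn_id == 0 and b is not None:
            img = TF.adjust_brightness(img, b)
        elif fn_id == 1 and c is not None:
            img = TF.adjust_contrast(img, c)
        elif fn_id == 2 and s is not None:
            img = TF.adjust_saturation(img, s)
        elif fn_id == 3 and h is not None:
            img = TF.adjust_hue(img, h)
    return img

out1, out2 = apply_params(img1), apply_params(img2)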

Ahh, I didn’t think to just go get the code in the previous release.
And I agree that the naming conventions make more sense now.
Thank you for the suggestion!