Torchvision models biased towards green?

Hello everyone :nerd_face:

As some people might know, the human visual system is more sensitive to green than to red and blue, which is why we adjust for this when capturing images with cameras.

It’s necessary to include more information from the green pixels in order to create an image that the eye will perceive as a “true color.” source

In a similar fashion, do our conv-nets show greater interest in, or derive more information from, red, green, or blue? I don’t know the answer to this and invite anyone to share thoughts or links to more information.

We will focus on the most commonly used models, where each kernel in the first conv layer is 3-dimensional (height, width, depth). I wanted to see how these kernels differ along the depth / RGB dimension. I wrote a script that sums the absolute values of the kernel weights per channel, making it possible to see which channel ‘carries more weight’, under the assumption that larger weights are more important. Here are the results and the code.

alexnet                   RGB: [0.35563377 0.35734916 0.28701708]
densenet121               RGB: [0.3414047  0.4036838  0.25491145]
densenet161               RGB: [0.33822682 0.40176123 0.2600119 ]
densenet169               RGB: [0.34025708 0.4014946  0.25824833]
densenet201               RGB: [0.3429028 0.3985608 0.2585365]
googlenet                 RGB: [0.35481286 0.37088254 0.2743046 ]
inception_v3              RGB: [0.39560044 0.31095004 0.29344955]
mnasnet0_5                RGB: [0.3219102  0.44260073 0.2354891 ]
mnasnet1_0                RGB: [0.30014    0.48660594 0.21325403]
mobilenet_v2              RGB: [0.3212116  0.45090452 0.22788389]
resnet101                 RGB: [0.33598453 0.40465072 0.2593648 ]
resnet152                 RGB: [0.33805275 0.3991825  0.26276478]
resnet18                  RGB: [0.3450004  0.390642   0.26435766]
resnet34                  RGB: [0.3454328  0.38308167 0.2714856 ]
resnet50                  RGB: [0.34040007 0.38215744 0.27744249]
resnext101_32x8d          RGB: [0.33099824 0.4089949  0.26000684]
resnext50_32x4d           RGB: [0.3365792  0.39434975 0.269071  ]
shufflenet_v2_x0_5        RGB: [0.330948   0.45226046 0.21679151]
shufflenet_v2_x1_0        RGB: [0.326566   0.46278057 0.21065338]
squeezenet1_0             RGB: [0.3383747  0.38222563 0.27939966]
squeezenet1_1             RGB: [0.3423823  0.40102834 0.25658935]
vgg11                     RGB: [0.34595174 0.37333512 0.2807131 ]
vgg11_bn                  RGB: [0.32353705 0.41581088 0.26065207]
vgg13                     RGB: [0.3471459 0.3867222 0.2661319]
vgg13_bn                  RGB: [0.32355314 0.4105902  0.26585665]
vgg16                     RGB: [0.3416876  0.37477094 0.2835415 ]
vgg16_bn                  RGB: [0.32831633 0.4132697  0.25841403]
vgg19                     RGB: [0.3444211  0.37442288 0.2811561 ]
vgg19_bn                  RGB: [0.3258713  0.41411158 0.26001713]
wide_resnet101_2          RGB: [0.3324678  0.41679963 0.25073257]
wide_resnet50_2           RGB: [0.33688802 0.39650932 0.26660258]

Mean RGB: [0.3378277  0.40201578 0.26015648]
STD RGB: [0.01547325 0.03337919 0.02063317]

import torch
import torchvision.models as models
from types import FunctionType


def check_first_layer(model):
    # The first parameter of these models is the weight of the first conv layer,
    # with shape (out_channels, 3, kernel_h, kernel_w), so only the first
    # iteration of the loop is needed.
    for name, weights in model.named_parameters():
        w = weights.abs()
        # Sum the absolute weights over every dim except the input-channel (RGB) dim
        chn = w.sum(dim=0).sum(-1).sum(-1)
        # Normalize so that R + G + B = 1
        chn = chn / chn.sum()
        chn[torch.isnan(chn)] = 0
        return chn.detach().numpy()


chn_info = torch.tensor([])
for model_name in dir(models):
    # Pretrained model constructors are lowercase functions (e.g. resnet18)
    if model_name[0].islower():
        attr = getattr(models, model_name)
        if isinstance(attr, FunctionType):
            try:
                model = attr(pretrained=True)
                rgb_vec = check_first_layer(model)
                print(f'{model_name: <25} RGB: {rgb_vec}')
                rgb_vec = torch.tensor(rgb_vec).view(1, -1)
                chn_info = torch.cat((chn_info, rgb_vec), dim=0)
            except Exception:
                # Skip constructors without pretrained weights (or other failures)
                pass

mean = chn_info.mean(dim=0)
std = chn_info.std(dim=0)

print(f'Mean RGB: {mean.numpy()}')
print(f'STD RGB: {std.numpy()}')


From these results, we can see that the weights convolved with the green channel are the largest and those for the blue channel are the smallest.

Does this mean that the green layer is more important for the networks? What do you think?

What could be some other reasons for these differences?

These are some interesting results, and I’m not sure if the amplitude of a channel’s weights corresponds to its importance.
Did you check if the green color channel of the input images was “adjusted”?
E.g. wouldn’t the weight in the green channel be larger if the signal from the green channel were smaller in amplitude? :thinking:

However, as a red-green color blind person, I cannot confirm the assumption of a higher sensitivity towards red and green. :wink:
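To illustrate that amplitude argument with a toy example (a sketch of my own, not from the experiment above): two inputs are equally informative, but one is attenuated, and a plain least-squares fit assigns the attenuated one the larger weight.

import torch

torch.manual_seed(0)
s1 = torch.randn(1000)           # two underlying signals, equally informative
s2 = torch.randn(1000)
y = s1 + s2                      # target depends on both equally
x1 = s1                          # input channel 1: full amplitude
x2 = 0.5 * s2                    # input channel 2: same signal, half the amplitude
X = torch.stack([x1, x2], dim=1)
# ordinary least squares via the normal equations
w = torch.linalg.solve(X.T @ X, X.T @ y)
print(w)                         # ~[1.0, 2.0]: the weaker channel gets the larger weight

So a larger weight could simply be compensating for a weaker signal rather than indicating higher importance.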

I can’t speak about the relation between kernel weights and colors, but it’s possible that the Bayer filters in cameras play a role in it.

Most RGB images result from demosaicing, basically interpolation from the existing “color” values, and the green channel has twice as many pixels as red or blue, therefore being more “accurate” than the red and blue channels…


Good idea. The models are trained on ImageNet, but I don’t have this dataset so I can’t (yet) do this. I guess you could compare the relative values of the image channels, similarly to what was done with the weights. And it would be of interest to do this after the transformations/normalizations that are applied to the input images, as this is what actually goes into the models.
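Not a full answer, but for anyone who does have a suitable image folder, something along these lines should measure the relative channel amplitudes after the standard preprocessing. This is just a rough sketch: the 'path/to/images' directory is a placeholder, and the abs-mean is meant to mirror the abs-sum used on the kernel weights.

import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Placeholder directory; ImageNet validation images would be the natural choice
dataset = ImageFolder(
    'path/to/images',
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]))
loader = DataLoader(dataset, batch_size=64, num_workers=4)

channel_abs = torch.zeros(3)
n_batches = 0
for images, _ in loader:
    # mean absolute value per channel, averaged over batch and spatial dims
    channel_abs += images.abs().mean(dim=(0, 2, 3))
    n_batches += 1

channel_abs /= n_batches
print('Relative RGB signal:', (channel_abs / channel_abs.sum()).numpy())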

To my knowledge, all the models were trained with this normalization.

transforms.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])

I’m not sure, but I believe that these parameter values were derived from the ImageNet dataset so that normalized images would follow a certain distribution, e.g. mean=0, std=1. If we assume that this is true, then →

The normalization parameters are quite similar for the different channels, which suggests that ImageNet doesn’t have large differences between the RGB channels. I don’t know if my reasoning holds here, but it makes sense to me.
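For what it’s worth, those values look like the per-channel mean/std of pixel intensities in [0, 1]. A rough sketch of how one could recompute them (again with a placeholder path, and slightly approximate because the smaller final batch is weighted like the others):

import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

dataset = ImageFolder('path/to/imagenet/train',   # placeholder path
                      transform=transforms.Compose([transforms.Resize(256),
                                                    transforms.CenterCrop(224),
                                                    transforms.ToTensor()]))
loader = DataLoader(dataset, batch_size=64, num_workers=4)

mean = torch.zeros(3)
sq_mean = torch.zeros(3)
n_batches = 0
for images, _ in loader:
    mean += images.mean(dim=(0, 2, 3))
    sq_mean += images.pow(2).mean(dim=(0, 2, 3))
    n_batches += 1

mean /= n_batches
std = (sq_mean / n_batches - mean ** 2).sqrt()
print('mean:', mean.numpy())    # would be expected to land near [0.485, 0.456, 0.406]
print('std :', std.numpy())     # would be expected to land near [0.229, 0.224, 0.225]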

:laughing::rofl::joy:

I agree, this sounds like a somewhat plausible explanation. We already know that the nets largely take their cues from local image features (fur, eyes, etc.) source. But would ‘image correctness’ at the pixel level matter that much to a network?

I guess we have seen that ‘pixel correctness’ matters a lot in some adversarial attacks source1 source2

Is there a way to test for this? The only way I can think of right now would be to train on images where we delete green pixels corresponding to 50% of the original green pixels in the Bayer filter, and interpolate them back. I see a few problems with this approach though…
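The degradation itself is easy to prototype before worrying about retraining. Here is a rough sketch of my own, assuming a checkerboard pattern for the dropped green samples and a simple 4-neighbour average as the ‘re-interpolation’:

import torch
import torch.nn.functional as F


def degrade_green(img):
    """Drop half of the green pixels (checkerboard pattern) and fill them
    back in from their 4 neighbours. img: float tensor of shape (3, H, W)."""
    _, h, w = img.shape
    keep = torch.zeros(h, w, dtype=torch.bool)
    keep[0::2, 0::2] = True
    keep[1::2, 1::2] = True          # kept green samples form a checkerboard
    green = img[1]
    # every dropped pixel's 4 neighbours are kept pixels, so this average
    # only uses surviving green values (image edges are handled crudely)
    kernel = torch.tensor([[0., 1., 0.],
                           [1., 0., 1.],
                           [0., 1., 0.]]).view(1, 1, 3, 3) / 4
    filled = F.conv2d(green[None, None], kernel, padding=1)[0, 0]
    out = img.clone()
    out[1] = torch.where(keep, green, filled)
    return out


# e.g. degraded = degrade_green(some_image_tensor) before feeding it to the model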