Am I too dumb to implement Instance Norm?

I tried to implement something related to Layer/Group norm from scratch (without using F.batch_norm) and it would not run properly, so I dumbed it down all the way to InstanceNorm, and even that does not work. I know there is a PyTorch implementation and that one works, but for what I would like to do I need a custom implementation. It must be something obvious that I am missing, because it is a really simple thing to do, but I just can't figure out what's wrong…

```python
import torch.nn as nn


class InstanceNorm(nn.Module):
    def __init__(self, *args, **kwargs):
        super(InstanceNorm, self).__init__()

    def forward(self, x):
        """
        :param x: has either shape (b, c, x, y) or shape (b, c, x, y, z)
        """
        x2 = x.view(*x.shape[:2], -1)  # x2.shape = (b, c, x*y(*z))

        # detach to stop gradients
        mn = x2.mean(-1).detach()  # mn.shape = (b, c)
        sd = x2.std(-1).detach()  # sd.shape = (b, c)
        mn = mn.view(*mn.shape[:2], *((len(x.shape) - 2) * [1]))  # mn.shape = (b, c, 1, 1(, 1))
        sd = sd.view(*sd.shape[:2], *((len(x.shape) - 2) * [1]))  # sd.shape = (b, c, 1, 1(, 1))

        x = (x - mn) / (sd + 1e-8)
        return x
```

Thanks a lot!


When you say it does not work, what does that mean? Does it crash, or does it fail to train?

I am not an expert at all, but I think that in batchnorm at least, you actually want to backprop through the mean and std computation; not doing so makes performance much worse.

Hi Alban,
sorry, I should have been more specific. If I train a model that uses this normalization instead of the PyTorch implementation, it does not train properly (the loss does not go down).
If I do not detach() mn and sd, the loss becomes nan (which I don't understand, because I am adding 1e-8 to sd to prevent exactly that).

For single precision, 1e-8 is too small, I think; you would need at least 1e-6.
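Whether 1e-8 is "too small" depends on what it is added to. A quick check in float32 (PyTorch's default dtype), using the made-up tensors `one` and `zero`: next to a value of order 1 the constant is rounded away entirely, while next to an exact 0 it survives. So in float32 it only changes the division when sd is itself very small.

```python
import torch

one = torch.tensor(1.0)   # float32 by default
zero = torch.tensor(0.0)

# machine epsilon: the spacing of float32 numbers just above 1 (about 1.19e-07)
print(torch.finfo(torch.float32).eps)

# 1e-8 is below half that spacing, so 1 + 1e-8 rounds back to exactly 1
print((one + 1e-8) == one)

# next to 0 there is no such rounding, so the constant is kept
print((zero + 1e-8) > zero)
```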

Swapped 1e-8 for 1e-6, still nan :frowning:

You can use the register_hook() function to have a function called when the gradient is computed.
I would use that to check exactly where the nans appear. Is it possible that std returns nan gradients for certain inputs?
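A minimal sketch of that check. The helper `report_non_finite`, the list `bad` and the all-ones input (which has exactly zero variance) are hypothetical choices for illustration. It uses `var().sqrt()` rather than `std()` because whether `std()`'s own backward produces nan at zero variance may depend on the PyTorch version; `var().sqrt()` goes through the plain square-root gradient.

```python
import torch

bad = []  # names of tensors whose incoming gradient contained nan or inf


def report_non_finite(name):
    # returns a hook that records `name` if the gradient it receives is not finite
    def hook(grad):
        if not torch.isfinite(grad).all():
            bad.append(name)
        # returning None leaves the gradient unchanged
    return hook


x = torch.ones(1, 2, 3, 3, requires_grad=True)  # constant input: variance is exactly 0
x2 = x.view(*x.shape[:2], -1)
mn = x2.mean(-1, keepdim=True)
var = x2.var(-1, keepdim=True)
sd = var.sqrt()
for name, t in [("x", x), ("mn", mn), ("var", var), ("sd", sd)]:
    t.register_hook(report_non_finite(name))

y = (x2 - mn) / (sd + 1e-6)
y.sum().backward()
print(bad)  # tensors that received a non-finite gradient, in backward order
```

Here the gradient reaching `sd` is still finite (0), but it turns into nan on its way through `sqrt()` into `var`, and from there into `x`.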

Hooks revealed that the further back the gradients were propagated, the larger they got (pretty rapidly), until they overflowed.
Interestingly, replacing .std() with .var() fixes the problem and the network trains properly. Any ideas on why that is?

The problem seems to be with the square root of the variance.
`x = (x - mn) / (var.sqrt() + 1e-6)`
gives nans, while
`x = (x - mn) / (var + 1e-6).sqrt()`
works just fine. Strange.
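Putting the working variant together, here is a minimal sketch of the module with the statistics left in the autograd graph and the constant inside the square root. Two choices go beyond the thread and are assumptions: it uses the biased variance (`unbiased=False`) and `eps=1e-5`, so that the result can be compared with `torch.nn.functional.instance_norm`, which normalizes with the biased variance and defaults to `eps=1e-5`. It also uses `reshape` instead of `view` so that non-contiguous inputs are accepted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class InstanceNorm(nn.Module):
    """Instance norm without affine parameters; statistics stay in the graph."""

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # x: (b, c, x, y) or (b, c, x, y, z); flatten the spatial dims
        x2 = x.reshape(*x.shape[:2], -1)
        mn = x2.mean(-1, keepdim=True)                  # (b, c, 1)
        var = x2.var(-1, unbiased=False, keepdim=True)  # biased, like F.instance_norm
        # eps goes inside the square root, so sqrt never sees an exact 0
        x2 = (x2 - mn) / (var + self.eps).sqrt()
        return x2.reshape(x.shape)


x = torch.randn(4, 3, 8, 8)
print(torch.allclose(InstanceNorm()(x), F.instance_norm(x, eps=1e-5), atol=1e-5))
```

Using `keepdim=True` keeps the singleton dimension, so the statistics broadcast against `x2` without the extra `view` calls.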

This is what I expected: the derivative of the square root, 1 / (2 * sqrt(v)), is infinite at v = 0, and autograd ends up computing inf * 0 or 0 / 0, both of which are nan. So in your case, if the variance of a channel is 0, you will get nan gradients.
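Written out for one (b, c) slice with N elements, mean μ and the unbiased variance v that `var()` and `std()` use by default, the chain rule gives:

```latex
\sigma = \sqrt{v}, \qquad
v = \frac{1}{N-1}\sum_{j=1}^{N}(x_j - \mu)^2, \qquad
\frac{\partial v}{\partial x_i} = \frac{2\,(x_i - \mu)}{N-1}

\frac{\partial \sigma}{\partial x_i}
  = \frac{1}{2\sqrt{v}} \cdot \frac{2\,(x_i - \mu)}{N-1}
```

If every x_i in the slice is equal, then v = 0 and x_i − μ = 0, so floating point evaluates the first factor to inf and the product to inf · 0 = nan. With the constant inside the root, the first factor becomes 1/(2√(v + ε)), which is finite, and the product is 0.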

Hi Alban,
thank you so much for your help. I did not think about what the gradient of sqrt would be at 0.
What I do not fully understand is why I should not detach the mean and standard deviation. Could you explain that a little more (if you don't mind)?
Thanks a lot again!

Well, I do not want to be associated with this at all! :smiley:
It's just what they do in the paper. And if you do it in practice, it works better (even though it makes no sense in theory).
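One concrete difference the detach makes can be seen directly in the gradients. The normalized output of a (b, c) slice does not change if the same constant is added to every element of that slice. When the mean and variance stay in the graph, the input gradient respects this: its sum over each slice is zero up to rounding. With detached statistics, the gradient generally has a component along that direction, which the normalization removes again in the next forward pass. The helper `instance_norm`, its `detach_stats` flag and the random upstream gradient `w` are hypothetical. This only illustrates a difference in the gradients; it is not a proof of why training works better.

```python
import torch


def instance_norm(x, eps=1e-5, detach_stats=False):
    # hypothetical helper: normalize each (b, c) slice over its spatial dims
    x2 = x.reshape(*x.shape[:2], -1)
    mean = x2.mean(-1, keepdim=True)
    var = x2.var(-1, unbiased=False, keepdim=True)
    if detach_stats:
        # treat the statistics as constants, as the question's code does
        mean, var = mean.detach(), var.detach()
    return ((x2 - mean) / (var + eps).sqrt()).reshape(x.shape)


torch.manual_seed(0)
x = torch.randn(2, 3, 4, 4, requires_grad=True)
w = torch.randn(2, 3, 4, 4)  # stands in for an arbitrary upstream gradient

max_abs_sum = {}
for detach_stats in (False, True):
    x.grad = None
    (instance_norm(x, detach_stats=detach_stats) * w).sum().backward()
    # largest |sum of the input gradient| over any (b, c) slice
    max_abs_sum[detach_stats] = x.grad.sum(dim=(2, 3)).abs().max().item()

print(max_abs_sum)  # near 0 for False, clearly nonzero for True
```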

Alrighty then. I did not hear or read anything.

Do you know how the internal PyTorch implementation works? Mine uses more memory, and I need to reduce the batch size by about 30%, which is quite a lot. Making x = (x - mn) / (sd) in-place does not work :-/

I'm afraid this comes down to low-level optimizations.
One simple thing you can do is implement your own Function (see the docs on how to do it). That lets you keep only a minimal set of intermediate tensors, but it means you have to reimplement the backward of your function by hand.
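A minimal sketch of such a Function for the no-affine case, assuming the biased variance and the constant inside the square root (as in `F.instance_norm`). The class name `InstanceNormFn` is hypothetical. The backward uses the standard normalization gradient dx = rstd · (g − mean(g) − y · mean(g · y)), with means taken over each (b, c) slice, and `ctx.save_for_backward` keeps only the normalized output y and the per-(b, c) inverse standard deviation. How much memory this saves compared with letting autograd record every step has not been measured here.

```python
import torch
import torch.nn.functional as F


class InstanceNormFn(torch.autograd.Function):
    """Hypothetical instance norm (no affine params) with a hand-written backward."""

    @staticmethod
    def forward(ctx, x, eps):
        x2 = x.reshape(*x.shape[:2], -1)                # (b, c, N)
        mean = x2.mean(-1, keepdim=True)
        var = x2.var(-1, unbiased=False, keepdim=True)  # biased variance
        rstd = (var + eps).rsqrt()                       # 1 / sqrt(var + eps)
        y = (x2 - mean) * rstd
        # only the normalized output and the per-(b, c) inverse std are kept
        ctx.save_for_backward(y, rstd)
        return y.reshape(x.shape)

    @staticmethod
    def backward(ctx, grad_out):
        y, rstd = ctx.saved_tensors
        g = grad_out.reshape(y.shape)
        # dx = rstd * (g - mean(g) - y * mean(g * y)), means taken over N
        grad_x = rstd * (g - g.mean(-1, keepdim=True)
                         - y * (g * y).mean(-1, keepdim=True))
        return grad_x.reshape(grad_out.shape), None  # no gradient for eps


x = torch.randn(2, 3, 5, 5, dtype=torch.double, requires_grad=True)
# compare the hand-written backward against finite differences
print(torch.autograd.gradcheck(lambda t: InstanceNormFn.apply(t, 1e-5), (x,)))
```

`gradcheck` needs double precision to be reliable, so it is worth running on a small double tensor like this before swapping the Function into a float32 model.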

Thanks a lot for the reply. I think implementing my own function will be too much effort for what I want to do :slight_smile:
Thanks again for your help. I appreciate that a lot!