Confusion about how InstanceNorm1d operates

As far as I understand, InstanceNorm1d normalizes the features for each channel.
I am not sure why InstanceNorm1d works in both of the following examples and does not give me an error…

import torch
import torch.nn as nn

m = nn.InstanceNorm1d(5)
batch, sentence_length, embedding_dim = 2, 3, 5
embedding1 = torch.randn(batch, sentence_length, embedding_dim)  # shape (2, 3, 5)
embedding2 = torch.randn(batch, embedding_dim, sentence_length)  # shape (2, 5, 3)

# case 1
out1 = m(embedding1)
# case 2
out2 = m(embedding2)

What is the difference between case 1 and case 2, and how is InstanceNorm1d computed in each case?

If I’m understanding your question right, you are asking why the number you pass into InstanceNorm plays no role?
This is because it is only used when you

  • either collect running statistics or
  • have a weight and/or bias,

neither of which instance norm does by default.
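
As a quick illustration (a minimal sketch; the exact exception type and message depend on the PyTorch version): with the defaults, the module has no parameters or buffers whose size depends on num_features, but with affine=True the mismatch surfaces immediately.

import torch
import torch.nn as nn

m = nn.InstanceNorm1d(5)                        # affine=False, track_running_stats=False by default
print(list(m.parameters()), list(m.buffers()))  # [] [] -- nothing here depends on the 5

m_affine = nn.InstanceNorm1d(5, affine=True)    # weight and bias of shape (5,)
try:
    m_affine(torch.randn(2, 3, 5))              # channel dim is 3, but weight has 5 entries
except (RuntimeError, ValueError) as e:
    print("shape mismatch:", e)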

Best regards

Thomas

I am not sure I follow; can you provide an example, please?

My question is what is happening in out1 = m(embedding1) and out2 = m(embedding2).
I don’t get any error for out1 = m(embedding1), although embedding1 does not have the (N, C, L) format suggested by the PyTorch docs.
I think in my example InstanceNorm is actually doing something:
if I print (out1 == embedding1).all(), (out2 == embedding2).all(), the result is False, False.
But I am not sure what is actually happening.

Also, it is not clear to me what you mean by “collect statistics”:

even if I use nn.InstanceNorm1d(in_channels, track_running_stats=False) it still works okay for both cases… 🙂

It’s when you pass the (non-default) track_running_stats=True that you get an error.
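
For example (a small sketch of that failure mode, reusing the tensors from the question; the exact exception varies across PyTorch versions):

m_stats = nn.InstanceNorm1d(5, track_running_stats=True)
out = m_stats(embedding2)   # fine: channel dim is 5 and matches the running stats
try:
    m_stats(embedding1)     # channel dim is 3, but running_mean has size 5
except (RuntimeError, ValueError) as e:
    print("error:", e)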

The normalization should be over the last dimension, so
m(embedding1).mean(-1) and m(embedding1).std(-1, unbiased=False) give values very close to 0 and 1.
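
To make this concrete, a quick check reusing m, embedding1, and embedding2 from the question (both layouts end up normalized over their own last dimension):

out1 = m(embedding1)
out2 = m(embedding2)
print(out1.mean(-1), out1.std(-1, unbiased=False))  # ~0 and ~1 for each (batch, channel) slice
print(out2.mean(-1), out2.std(-1, unbiased=False))  # same, just with 5 channels of length 3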

Best regards

Thomas

I see, yes, that is correct, thanks for the clarification (I think you meant m(embedding2).mean(-1) and m(embedding2).std(-1, unbiased=False), though).
So, when track_running_stats=False, do you know what is computed in m(embedding1)?

It is precisely (input - input.mean(-1, keepdim=True)) / input.std(-1, keepdim=True, unbiased=False), up to the small eps added to the variance for numerical stability.
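
One way to convince yourself, using the module’s eps attribute (a quick sanity check, not an official reference implementation):

manual = (embedding1 - embedding1.mean(-1, keepdim=True)) / torch.sqrt(
    embedding1.var(-1, keepdim=True, unbiased=False) + m.eps
)
print(torch.allclose(m(embedding1), manual))  # True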

Yes, now it is clear to me!
Thanks a lot, Tom!