BatchNorm1d input (N, C, L) vs (N, L)

I expected the following code to produce True for all entries of a, but it doesn't. Why is that? (I'm trying to understand the difference between the input shapes (N, C, L) and (N, L) for BatchNorm1d.)

  self.bn1 = torch.nn.BatchNorm1d(hidden_channels)

  def bn(self, x):
      # x: [batch_size, num_nodes, num_channels]
      batch_size, num_nodes, num_channels = x.size()
      y = x.clone()
      x = x.view(-1, num_channels)     # [batch_size * num_nodes, num_channels]
      x = self.bn1(x)
      y = y.transpose(2, 1)            # [batch_size, num_channels, num_nodes]
      y = self.bn1(y)
      y = y.reshape(-1, num_channels)
      a = torch.eq(x, y)

The last reshape operation interleaves the output values, so you would need to permute y back before calling reshape:


import torch

hidden_channels = 4
bn1 = torch.nn.BatchNorm1d(hidden_channels)

x = torch.randn(2, 3, 4)  # [batch_size, num_nodes, num_channels]
batch_size, num_nodes, num_channels = x.size()
y = x.clone()
x = x.view(-1, num_channels)
x = bn1(x)

y = y.transpose(2, 1)  # [batch_size, num_channels, num_nodes]
y = bn1(y)
y = y.permute(0, 2, 1).reshape(-1, num_channels)  # back to [batch_size * num_nodes, num_channels]
print((x - y).abs().max())
> tensor(0., grad_fn=<MaxBackward1>)
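
To make the interleaving visible, here is a small self-contained sketch (my own check, not part of the original answer; the names ref and bad are just illustrative) that skips the permute and compares the two flattened results:

import torch

hidden_channels = 4
bn1 = torch.nn.BatchNorm1d(hidden_channels)

x = torch.randn(2, 3, 4)                # [batch_size, num_nodes, num_channels]
ref = bn1(x.view(-1, hidden_channels))  # the (N*L, C) path

# Skipping the permute: reshape reads the [N, C, L] output row-major,
# so consecutive rows hold the time steps of one channel instead of
# the channels of one time step.
bad = bn1(x.transpose(2, 1)).reshape(-1, hidden_channels)
print((ref - bad).abs().max())          # > 0: the rows no longer line up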

y will have the shape [batch_size, channels, seq_len] after the call to bn1.
If you want to flatten the temporal dimension into the batch dimension, you have to permute it to [batch_size, seq_len, channels] first, before calling reshape(-1, num_channels).
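
As a sanity check (a minimal sketch of my own, assuming training mode and the default affine initialization of weight=1 and bias=0), you can reproduce the [N, C, L] behavior manually: BatchNorm1d computes one mean and variance per channel over both the batch and temporal dimensions, which is exactly why the flattened (N*L, C) path gives the same values:

import torch

bn = torch.nn.BatchNorm1d(4)
x = torch.randn(2, 4, 3)           # [N, C, L]
out = bn(x)

# One statistic per channel, computed over the N and L dimensions:
mean = x.mean(dim=(0, 2), keepdim=True)
var = x.var(dim=(0, 2), unbiased=False, keepdim=True)
manual = (x - mean) / (var + bn.eps).sqrt()
print((out - manual).abs().max())  # ~0, up to float precision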