Problems Using BatchNorm1d

I had the same error message, but in my case it was caused by a mismatch between the dimension layouts expected by the linear layer and the batch normalization layer. As far as I understand from the documentation:

If your input has more than two dimensions and is processed by a linear layer, the features sit in the last dimension. BatchNorm1d, however, expects the channel (feature) dimension to come right after the batch dimension. To pass the output of the linear layer to the batch normalization layer, you can flatten all leading dimensions into one batch dimension and reshape back afterwards.
Snippet:

x = linear(x)                                        # features in the last dim
x = batchnorm1d(x.flatten(0, -2)).reshape(x.shape)   # flatten leading dims to (N, C), normalize, restore shape
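A minimal runnable sketch of that pattern, assuming a hypothetical 3-D input of shape (batch, seq, features) and equal in/out feature sizes (the shapes and layer sizes here are just for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical shapes for illustration
batch, seq, feats = 4, 10, 32
linear = nn.Linear(feats, feats)
batchnorm1d = nn.BatchNorm1d(feats)  # expects (N, C) or (N, C, L) input

x = torch.randn(batch, seq, feats)
x = linear(x)  # shape (4, 10, 32), features in the last dim
# Flatten all leading dims so BatchNorm1d sees (batch*seq, feats) = (40, 32),
# then restore the original shape.
y = batchnorm1d(x.flatten(0, -2)).reshape(x.shape)
print(y.shape)  # torch.Size([4, 10, 32])
```

Since the flattened tensor has shape (40, 32), BatchNorm1d treats every (batch, seq) position as a separate sample, which is exactly the per-feature normalization you usually want here.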

I know you stated that your batch only has two dimensions, but I came across this thread while searching for a solution to my own problem, so perhaps this answer will be helpful to someone.