I had the same error message, but in my case it was caused by a mismatch between the input shapes expected by the linear layer and the batch normalization layer. As far as I understand from the documentation,
- nn.Linear accepts input of shape [batch_size, *, features], where * can be any number of additional dimensions (https://pytorch.org/docs/stable/nn.html#torch.nn.Linear)
- nn.BatchNorm1d expects input of shape [batch_size, features] or [batch_size, features, sequence_length] (https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm1d)
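To see the mismatch concretely, here is a minimal sketch (the sizes 8, 10, 32, and 64 are arbitrary placeholders):

import torch
import torch.nn as nn

x = torch.randn(8, 10, 32)   # [batch_size, extra dim, features]
linear = nn.Linear(32, 64)
out = linear(x)
print(out.shape)             # torch.Size([8, 10, 64]), features in the last dim

bn = nn.BatchNorm1d(64)
# bn(out) raises a RuntimeError here: BatchNorm1d treats dim 1 (size 10) as the
# feature/channel dimension and complains that it does not match num_features=64.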
So if your input has more than two dimensions and is processed by a linear layer, the features end up in the last dimension, while nn.BatchNorm1d expects them in the second. To pass the output to the batch normalization layer, you can flatten all leading dimensions into one batch dimension and restore the original shape after batch normalization.
Snippet:
x = linear(x)                                        # [batch_size, *, features]
x = batchnorm1d(x.flatten(0, -2)).reshape(x.shape)   # flatten leading dims, normalize, restore shape
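For completeness, a self-contained version of the snippet (again with arbitrary placeholder sizes):

import torch
import torch.nn as nn

batch_size, seq_len, in_features, out_features = 8, 10, 32, 64
linear = nn.Linear(in_features, out_features)
batchnorm1d = nn.BatchNorm1d(out_features)

x = torch.randn(batch_size, seq_len, in_features)
x = linear(x)                                        # [8, 10, 64]
x = batchnorm1d(x.flatten(0, -2)).reshape(x.shape)   # [80, 64] -> normalize -> [8, 10, 64]
print(x.shape)                                       # torch.Size([8, 10, 64])

For three-dimensional input specifically, moving the features into the second dimension with x.transpose(1, 2), applying the norm, and transposing back should give the same result, since BatchNorm1d computes its statistics per feature over all other dimensions either way.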
I know you stated that your batch only has two dimensions, but I came across this thread while searching for a solution to my own problem, and perhaps this answer will be helpful to someone else.