Any idea on how to deal with BatchNorm2d when accumulating gradients?
It seems that BatchNorm2d updates its running mean and variance during the forward pass (see here).
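To illustrate what I mean, here is a quick check (not code from my model, just a minimal sketch): in training mode the running statistics change on every forward call, even for a single sample.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)            # train mode by default
x = torch.randn(1, 3, 8, 8)       # one sample, as in my accumulation setup

print(bn.running_mean.clone())    # all zeros at initialization
bn(x)                             # forward pass updates the running stats
print(bn.running_mean)            # already changed after this single sample
```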
I have a mini-batch of n samples and I forward one sample at a time (the loss is divided by n). In this setup I get worse performance than when I forward several samples at once (8, for instance). I expected the same result, since the final accumulated gradient should be the same, so I suspect this has something to do with the BatchNorm2d layers in my model.
I use nn.CrossEntropyLoss(reduction='sum') as the loss, and I divide it by the mini-batch size (i.e., n) when I call it.
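Roughly, my accumulation loop looks like the sketch below (simplified; `model`, `optimizer`, `samples`, and `targets` are placeholder names, not my actual code):

```python
import torch.nn as nn

criterion = nn.CrossEntropyLoss(reduction='sum')

def train_step(model, optimizer, samples, targets):
    """samples/targets hold the n samples of one mini-batch."""
    n = len(samples)
    optimizer.zero_grad()
    for x, y in zip(samples, targets):
        # x has shape (1, C, H, W): one sample forwarded at a time,
        # so BatchNorm2d sees single-sample batch statistics and
        # updates its running stats n times per mini-batch.
        out = model(x)
        loss = criterion(out, y) / n
        loss.backward()           # gradients accumulate in .grad
    optimizer.step()              # one update with the accumulated gradient
```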
Thank you!