Hello,
I recently upgraded from libtorch 1.2 to 1.4 and, among other things, replaced torch::nn::BatchNorm with torch::nn::BatchNorm2d in my code, since the former is now deprecated. I started seeing vastly different results with the new library and tracked them down to batch norm. Below is a minimal example that already demonstrates the issue.
I would expect at most a small delta between the two classes, but the differences I'm seeing in large runs are huge, ranging from training fine to not converging at all. I did not see anything in the documentation that would explain such a change in behavior.
Could someone please advise whether there was an intentional (possibly undocumented) change in behavior between torch::nn::BatchNorm and torch::nn::BatchNorm2d in libtorch 1.4? Thanks!
#include <torch/torch.h>
#include <iostream>

int main()
{
  int64_t const C(3), H(4), W(4);

  torch::manual_seed(1);
  auto const input = torch::rand({1, C, H, W});

  /*** The two outputs below should be identical, but are vastly different! ***/
  // Reseed before each module so both start from identical initial parameters.
  torch::manual_seed(1);
  std::cout << torch::nn::BatchNorm(C)->forward(input) << std::endl;

  torch::manual_seed(1);
  std::cout << torch::nn::BatchNorm2d(C)->forward(input) << std::endl;

  return 0;
}
Here’s the output of the above:
Warning: torch::nn::BatchNorm module is deprecated and will be removed in 1.5. Use BatchNorm{1,2,3}d instead. (BatchNormImpl at ../../torch/csrc/api/src/nn/modules/batchnorm.cpp:21)
(1,1,.,.) =
0.9024 -0.8736 -0.4140 0.8172
-1.8019 1.0592 -0.4361 0.8903
0.2040 -0.2815 0.4608 0.0375
0.6239 -0.7776 -0.1895 -0.2213

(1,2,.,.) =
0.0272 -0.0509 0.4098 0.1145
-0.2442 -0.3657 -0.1367 -0.2751
-0.2169 -0.0237 0.2639 0.2363
-0.5616 0.2764 0.0978 0.4488

(1,3,.,.) =
0.4244 -0.5917 -0.6174 -0.3144
0.5654 -0.1102 0.2743 -0.3030
0.6321 -0.2761 0.4825 0.0052
-0.3432 0.3300 -0.3458 0.1877
[ CPUFloatType{1,3,4,4} ]

(1,1,.,.) =
1.1911 -1.1530 -0.5465 1.0787
-2.3783 1.3981 -0.5756 1.1752
0.2692 -0.3715 0.6082 0.0494
0.8235 -1.0264 -0.2501 -0.2921

(1,2,.,.) =
0.0975 -0.1822 1.4672 0.4098
-0.8742 -1.3091 -0.4895 -0.9849
-0.7766 -0.0848 0.9450 0.8462
-2.0108 0.9897 0.3500 1.6069

(1,3,.,.) =
1.0529 -1.4680 -1.5317 -0.7800
1.4028 -0.2734 0.6805 -0.7517
1.5682 -0.6850 1.1971 0.0129
-0.8514 0.8186 -0.8578 0.4658
[ CPUFloatType{1,3,4,4} ]
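One more data point that may help narrow this down: from the printouts above, the two outputs appear to differ by a constant factor within each channel (roughly 1.32, 3.58, and 2.48 for channels 1 through 3). Here is a minimal sketch, using the same setup as the repro above, that prints the element-wise ratio of the two outputs; if that observation holds, every entry within a channel should print the same value:

#include <torch/torch.h>
#include <iostream>

int main()
{
  int64_t const C(3), H(4), W(4);

  torch::manual_seed(1);
  auto const input = torch::rand({1, C, H, W});

  torch::manual_seed(1);
  auto const out_old = torch::nn::BatchNorm(C)->forward(input);

  torch::manual_seed(1);
  auto const out_new = torch::nn::BatchNorm2d(C)->forward(input);

  // Element-wise ratio of the two outputs. For this particular seed every
  // entry happens to be nonzero, so the division is safe.
  std::cout << (out_new / out_old) << std::endl;
  return 0;
}

If the ratio really is constant per channel, my guess is that the two modules are normalizing with different per-channel statistics, but I haven't been able to confirm that from the docs.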