Since PyTorch does not support SyncBN, I would like to freeze the mean/var of the BN layers during training: the mean/var from the pretrained model are used, while the weight/bias remain learnable.
In this case, the calculation of bottom_grad in BN differs from that of the normal training mode. However, I cannot find any flag in the function below that marks this difference.
pytorch/torch/csrc/cudnn/BatchNorm.cpp
// Backward pass for batch normalization, dispatched to cuDNN.
//
// NOTE(review): the `training` flag below only selects the faster
// CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode on cuDNN >= 7.3; it does NOT
// change which statistics the gradient computation uses. The call to
// cudnnBatchNormalizationBackward is unconditional and is handed
// save_mean/save_var (the per-batch statistics saved by the forward pass),
// so there is no flag in this function that distinguishes a
// frozen-statistics ("eval-style") backward from the normal training
// backward — presumably that case must be handled before reaching this
// function; confirm against the cuDNN API documentation.
void cudnn_batch_norm_backward(
THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
THVoidTensor* input, THVoidTensor* grad_output, THVoidTensor* grad_input,
THVoidTensor* grad_weight, THVoidTensor* grad_bias, THVoidTensor* weight,
THVoidTensor* running_mean, THVoidTensor* running_var,
THVoidTensor* save_mean, THVoidTensor* save_var, bool training,
double epsilon)
{
// Bind the cuDNN handle to the current THC stream so this call is ordered
// with other work on the same stream.
CHECK(cudnnSetStream(handle, THCState_getCurrentStream(state)));
// All tensors must live on the same GPU.
assertSameGPU(dataType, input, grad_output, grad_input, grad_weight, grad_bias, weight,
running_mean, running_var, save_mean, save_var);
// Choose the normalization mode from the input rank: 2-D inputs (N, C) use
// per-activation normalization; higher-rank inputs use spatial (per-channel).
cudnnBatchNormMode_t mode;
if (input->nDimension == 2) {
mode = CUDNN_BATCHNORM_PER_ACTIVATION;
} else {
mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7003
// cuDNN 7.3+ offers a faster persistent spatial kernel, valid only in
// training mode. This is the ONLY use of `training` in this function.
if(training)
mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#endif
}
// cuDNN requires contiguous tensors for these arguments.
THVoidTensor_assertContiguous(input);
THVoidTensor_assertContiguous(grad_output);
THVoidTensor_assertContiguous(grad_weight);
THVoidTensor_assertContiguous(grad_bias);
THVoidTensor_assertContiguous(save_mean);
THVoidTensor_assertContiguous(save_var);
TensorDescriptor idesc; // input descriptor
TensorDescriptor odesc; // output descriptor
TensorDescriptor gdesc; // grad_input descriptor
TensorDescriptor wdesc; // descriptor for weight, bias, running_mean, etc.
setInputDescriptor(idesc, dataType, input);
setInputDescriptor(odesc, dataType, grad_output);
setInputDescriptor(gdesc, dataType, grad_input);
// Scale/bias tensors may need a wider data type (e.g. float for half inputs).
setScaleDescriptor(wdesc, scaleDataType(dataType), weight, input->nDimension);
// Blending factors: alpha = 1, beta = 0 means "overwrite the outputs".
Constant one(dataType, 1);
Constant zero(dataType, 0);
// Note that running_mean/running_var are validated above but never passed to
// cuDNN here; only the saved per-batch statistics are used for the gradients.
CHECK(cudnnBatchNormalizationBackward(
handle, mode, &one, &zero, &one, &zero,
idesc.desc, tensorPointer(dataType, input),
odesc.desc, tensorPointer(dataType, grad_output),
gdesc.desc, tensorPointer(dataType, grad_input),
wdesc.desc, tensorPointer(dataType, weight),
tensorPointer(dataType, grad_weight),
tensorPointer(dataType, grad_bias),
epsilon,
tensorPointer(dataType, save_mean),
tensorPointer(dataType, save_var)));
}
Can anyone offer some help?