How to train with frozen BatchNorm?

Since PyTorch does not support SyncBN, I want to freeze the mean/var of the BN layers while training. The mean/var from the pretrained model are used, while the weight/bias remain learnable.

In this setting, the computation of bottom_grad (the gradient w.r.t. the BN input) differs from that in the normal training mode. However, we cannot find any flag in the function below that accounts for this difference.

pytorch/torch/csrc/cudnn/BatchNorm.cpp

// Backward pass of cuDNN batch normalization, as quoted from
// pytorch/torch/csrc/cudnn/BatchNorm.cpp.
//
// Computes grad_input, grad_weight and grad_bias with a single
// cudnnBatchNormalizationBackward call. That call uses the statistics saved
// by the forward pass (save_mean / save_var).
//
// NOTE(review): running_mean / running_var are only passed to assertSameGPU
// and are never handed to cuDNN. `training` only affects which cuDNN mode is
// chosen. So this routine by itself always computes the training-mode
// (batch-statistics) gradient. The train/eval branching presumably lives in
// the autograd derivative definitions (tools/autograd/derivatives.yaml),
// not here — confirm against the PyTorch version in use.
void cudnn_batch_norm_backward(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* input, THVoidTensor* grad_output, THVoidTensor* grad_input,
    THVoidTensor* grad_weight, THVoidTensor* grad_bias, THVoidTensor* weight,
    THVoidTensor* running_mean, THVoidTensor* running_var,
    THVoidTensor* save_mean, THVoidTensor* save_var, bool training,
    double epsilon)
{
  // Bind the cuDNN handle to THC's current stream before issuing work.
  CHECK(cudnnSetStream(handle, THCState_getCurrentStream(state)));
  // All tensors involved must live on the same GPU.
  assertSameGPU(dataType, input, grad_output, grad_input, grad_weight, grad_bias, weight,
      running_mean, running_var, save_mean, save_var);
  // 2-D input uses per-activation normalization; higher-rank input uses a
  // spatial mode.
  cudnnBatchNormMode_t mode;
  if (input->nDimension == 2) {
    mode = CUDNN_BATCHNORM_PER_ACTIVATION;
  } else {
    mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7003
    // The persistent spatial variant is used only in training mode and only
    // when compiled against CUDNN_VERSION >= 7003.
    // NOTE(review): presumably this must match the mode the forward pass used
    // to produce save_mean/save_var — confirm.
    if(training)
      mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#endif

  }

  // Require contiguous memory for the tensors handed to cuDNN below.
  THVoidTensor_assertContiguous(input);
  THVoidTensor_assertContiguous(grad_output);
  THVoidTensor_assertContiguous(grad_weight);
  THVoidTensor_assertContiguous(grad_bias);
  THVoidTensor_assertContiguous(save_mean);
  THVoidTensor_assertContiguous(save_var);

  TensorDescriptor idesc;  // input descriptor
  TensorDescriptor odesc;  // output descriptor
  TensorDescriptor gdesc;  // grad_input descriptor
  TensorDescriptor wdesc;  // descriptor for weight, bias, running_mean, etc.
  setInputDescriptor(idesc, dataType, input);
  setInputDescriptor(odesc, dataType, grad_output);
  setInputDescriptor(gdesc, dataType, grad_input);
  setScaleDescriptor(wdesc, scaleDataType(dataType), weight, input->nDimension);

  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  // NOTE(review): per the cuDNN API the four scalars are presumably
  // alphaDataDiff, betaDataDiff, alphaParamDiff and betaParamDiff. That would
  // mean the results overwrite grad_input / grad_weight / grad_bias rather
  // than accumulate into them — confirm.
  CHECK(cudnnBatchNormalizationBackward(
    handle, mode, &one, &zero, &one, &zero,
    idesc.desc, tensorPointer(dataType, input),
    odesc.desc, tensorPointer(dataType, grad_output),
    gdesc.desc, tensorPointer(dataType, grad_input),
    wdesc.desc, tensorPointer(dataType, weight),
    tensorPointer(dataType, grad_weight),
    tensorPointer(dataType, grad_bias),
    epsilon,
    tensorPointer(dataType, save_mean),
    tensorPointer(dataType, save_var)));
}

Can anyone help?

1 Like

My bad, see below…

Hi, @SimonW
Thanks for your reply.
Does the BN module use a different backpropagation method in TRAIN/EVAL mode, with or without affine weights? I cannot find the code for this difference.

Regards,

Yes, it will use the proper backward calculation.

Thanks @SimonW
I tested BN in 4 modes (with or without affine, train or eval) and found that BN uses a different backward calculation for TRAIN and EVAL mode, regardless of the affine weights. See the code below.

import torch
import torch.backends.cudnn as cudnn


if __name__ == '__main__':
    print("CUDNN Version: {}".format(cudnn.version()))
    cudnn.enabled = True
    cudnn.benchmark = True
    print("############## Check BN with Frozen Param")
    input_tensor = torch.rand(1, 64, 100, 100).cuda() * 100
    weight_init = torch.ones(64)
    bias_init = torch.zeros(64)
    mean_init = torch.rand(64) * 100
    var_init = torch.rand(64) * 100

    # torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)
    BN0 = torch.nn.BatchNorm2d(64, affine=True)
    state_dict_0 = BN0.state_dict()
    state_dict_0['weight'].copy_(weight_init)
    state_dict_0['bias'].copy_(bias_init)
    state_dict_0['running_mean'].copy_(mean_init)
    state_dict_0['running_var'].copy_(var_init)
    BN0 = BN0.cuda()

    BN0.weight.requires_grad = False
    BN0.bias.requires_grad = False
    BN0.train()
    input_tensor_0 = input_tensor.clone()
    input_var_0 = torch.autograd.Variable(input_tensor_0, requires_grad=True)
    output_var_0 = BN0(input_var_0)
    loss_0 = output_var_0.sum()
    loss_0.backward()

    BN1 = torch.nn.BatchNorm2d(64, affine=True)
    state_dict_1 = BN1.state_dict()
    state_dict_1['weight'].copy_(weight_init)
    state_dict_1['bias'].copy_(bias_init)
    state_dict_1['running_mean'].copy_(mean_init)
    state_dict_1['running_var'].copy_(var_init)
    BN1 = BN1.cuda()

    BN1.train()
    input_tensor_1 = input_tensor.clone()
    input_var_1 = torch.autograd.Variable(input_tensor_1, requires_grad=True)
    output_var_1 = BN1(input_var_1)
    loss_1 = output_var_1.sum()
    loss_1.backward()

    BN2 = torch.nn.BatchNorm2d(64, affine=True)
    state_dict_2 = BN2.state_dict()
    state_dict_2['weight'].copy_(weight_init)
    state_dict_2['bias'].copy_(bias_init)
    state_dict_2['running_mean'].copy_(mean_init)
    state_dict_2['running_var'].copy_(var_init)
    BN2 = BN2.cuda()

    BN2.eval()
    input_tensor_2 = input_tensor.clone()
    input_var_2 = torch.autograd.Variable(input_tensor_2, requires_grad=True)
    output_var_2 = BN2(input_var_2)
    loss_2 = output_var_2.sum()
    loss_2.backward()

    BN3 = torch.nn.BatchNorm2d(64, affine=False)
    state_dict_3 = BN3.state_dict()
    state_dict_3['running_mean'].copy_(mean_init)
    state_dict_3['running_var'].copy_(var_init)
    BN3 = BN3.cuda()

    BN3.train()
    input_tensor_3 = input_tensor.clone()
    input_var_3 = torch.autograd.Variable(input_tensor_3, requires_grad=True)
    output_var_3 = BN3(input_var_3)
    loss_3 = output_var_3.sum()
    loss_3.backward()

    BN4 = torch.nn.BatchNorm2d(64, affine=False)
    state_dict_4 = BN4.state_dict()
    state_dict_4['running_mean'].copy_(mean_init)
    state_dict_4['running_var'].copy_(var_init)
    BN4 = BN4.cuda()

    BN4.eval()
    input_tensor_4 = input_tensor.clone()
    input_var_4 = torch.autograd.Variable(input_tensor_4, requires_grad=True)
    output_var_4 = BN4(input_var_4)
    loss_4 = output_var_4.sum()
    loss_4.backward()

    print((input_var_2.grad - input_var_1.grad).abs().max().data[0])
    print((input_var_4.grad - input_var_3.grad).abs().max().data[0])
    print((input_var_4.grad - input_var_2.grad).abs().max().data[0])
    print((input_var_3.grad - input_var_1.grad).abs().max().data[0])
    print((input_var_1.grad - input_var_0.grad).abs().max().data[0])

The result is:

CUDNN Version: 6021
############## Check BN with Frozen Param
0.8908671140670776
0.8908671140670776
0.0
1.646934677523859e-08
0.0

Could you please show me how the BN module uses a different backward calculation depending on TRAIN/EVAL mode? I cannot find it in the code.

Regards,

My bad on the previous reply. You only need to set eval mode. No need to make a separate affine transform.

The branching to call training/eval backward is at https://github.com/pytorch/pytorch/blob/04ad23252a3ce592c5e5c30c6fd87000f8d178cf/tools/autograd/derivatives.yaml#L1038 (if you have cudnn installed)

@SimonW, Thanks very much for your reply. It is very helpful~

Hi!
Did you manage to solve this?
I’m trying to do the same thing, training with fixed mean/var for the batchnorm layer.
I set all batchnorm layers to eval mode during training using this code:

# Put the whole network in training mode, then switch every 1d/2d/3d
# batch-norm layer back to eval mode so it normalizes with its running
# statistics instead of per-batch statistics.
# NOTE(review): the affine weight/bias of these layers remain trainable here.
net.train()
for module in net.modules():
    if isinstance(module, (torch.nn.modules.BatchNorm1d,
                           torch.nn.modules.BatchNorm2d,
                           torch.nn.modules.BatchNorm3d)):
        module.eval()

However I get really bad training results. The validation score goes to zero straight away. I’ve tried doing the same training without setting the batchnorm layers to eval and that works fine.

I override the train() function of my model.

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters.

        Calls ``nn.Module.train(mode)`` first. If ``self.freeze_bn`` is set,
        every ``nn.BatchNorm2d`` inside ``self.backbone`` is then put back
        into eval mode, so its running mean/var are used and not updated. If
        ``self.freeze_bn_affine`` is also set, the weight/bias of those
        layers (when present) get ``requires_grad = False``.

        Returns:
            ``self``, matching ``nn.Module.train()`` so calls can be chained.
        """
        super(MyNet, self).train(mode)
        if self.freeze_bn:
            print("Freezing Mean/Var of BatchNorm2D.")
            if self.freeze_bn_affine:
                print("Freezing Weight/Bias of BatchNorm2D.")
            for m in self.backbone.modules():
                if isinstance(m, nn.BatchNorm2d):
                    # super().train() just switched BN to training mode;
                    # undo that so the running statistics stay frozen.
                    m.eval()
                    # BN layers built with affine=False have weight/bias None.
                    if self.freeze_bn_affine and m.affine:
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False
        return self

This works for me.

15 Likes

Hi! I think you can fix the mean and variance by setting the affine parameter to False.
What that does is fix the learnable affine parameters “gamma” (scale) and “beta” (shift) to 1 and 0. Note that it does not freeze the running mean and variance. Those are still updated as an exponentially weighted average in train mode, and they are used for normalization only at test time (i.e. after you call eval()).

I have a similar issue with the train/eval scenario you’re talking about, but no benchmark to compare it against (RL problems). The two modes give me results of different orders of magnitude on the first iteration. I’m hoping it’s because the running mean has only been updated from the first sample. If someone knows the answer to this, please enlighten me…

Warning: proceed with caution. Overriding nn.Module.train() can be helpful to control the frozen parameters, but it can backfire if you forget that self.freeze is set to True when you intend to train the entire model.

As @EthanZhangYi did, I recommend overriding your model.train(). When you call the original, batchnorm is switched back on (the running mean/var buffers only, not the params). It is safer this way.

    def train(self, mode=True):
        """
        Switch train/eval mode, then re-apply any requested freezing.

        ``nn.Module.train()`` re-enables batch-norm statistic updates (buffers,
        not parameters). When ``self.freeze`` is set, ``freeze_stuff()`` is
        called afterwards to restore the frozen state.

        Returns:
            ``self``, like ``nn.Module.train()``.
        """
        super(YourModel, self).train(mode=mode)
        if not self.freeze:
            return self
        self.freeze_stuff()  # turn batch-norm (and other frozen parts) back off
        return self

self.freeze_stuff() should take care of parameters, batchnorm, and dropout (at least).

1 Like