How to train with frozen BatchNorm?


(Zhang Yi) #1

Since PyTorch does not support synchronized BatchNorm (syncBN), I want to freeze the mean/var of the BN layers while training: the mean/var from the pretrained model are used, while weight/bias remain learnable.
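For concreteness, a minimal sketch of this setup (assuming the standard nn.BatchNorm2d API):

import torch.nn as nn

# Sketch of the desired setup: eval() makes BN use (and stop updating)
# running_mean/running_var, while weight/bias keep requires_grad=True
# by default and remain learnable.
bn = nn.BatchNorm2d(64, affine=True)
bn.eval()
assert bn.weight.requires_grad and bn.bias.requires_grad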

In this way, the calculation of bottom_grad in BN will differ from that of the normal training mode. However, we cannot find any flag in the function below that marks this difference.

pytorch/torch/csrc/cudnn/BatchNorm.cpp

void cudnn_batch_norm_backward(
    THCState* state, cudnnHandle_t handle, cudnnDataType_t dataType,
    THVoidTensor* input, THVoidTensor* grad_output, THVoidTensor* grad_input,
    THVoidTensor* grad_weight, THVoidTensor* grad_bias, THVoidTensor* weight,
    THVoidTensor* running_mean, THVoidTensor* running_var,
    THVoidTensor* save_mean, THVoidTensor* save_var, bool training,
    double epsilon)
{
  CHECK(cudnnSetStream(handle, THCState_getCurrentStream(state)));
  assertSameGPU(dataType, input, grad_output, grad_input, grad_weight, grad_bias, weight,
      running_mean, running_var, save_mean, save_var);
  cudnnBatchNormMode_t mode;
  if (input->nDimension == 2) {
    mode = CUDNN_BATCHNORM_PER_ACTIVATION;
  } else {
    mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7003
    if(training)
      mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#endif
  }

  THVoidTensor_assertContiguous(input);
  THVoidTensor_assertContiguous(grad_output);
  THVoidTensor_assertContiguous(grad_weight);
  THVoidTensor_assertContiguous(grad_bias);
  THVoidTensor_assertContiguous(save_mean);
  THVoidTensor_assertContiguous(save_var);

  TensorDescriptor idesc;  // input descriptor
  TensorDescriptor odesc;  // output descriptor
  TensorDescriptor gdesc;  // grad_input descriptor
  TensorDescriptor wdesc;  // descriptor for weight, bias, running_mean, etc.
  setInputDescriptor(idesc, dataType, input);
  setInputDescriptor(odesc, dataType, grad_output);
  setInputDescriptor(gdesc, dataType, grad_input);
  setScaleDescriptor(wdesc, scaleDataType(dataType), weight, input->nDimension);

  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  CHECK(cudnnBatchNormalizationBackward(
    handle, mode, &one, &zero, &one, &zero,
    idesc.desc, tensorPointer(dataType, input),
    odesc.desc, tensorPointer(dataType, grad_output),
    gdesc.desc, tensorPointer(dataType, grad_input),
    wdesc.desc, tensorPointer(dataType, weight),
    tensorPointer(dataType, grad_weight),
    tensorPointer(dataType, grad_bias),
    epsilon,
    tensorPointer(dataType, save_mean),
    tensorPointer(dataType, save_var)));
}

Can anyone give some help?


(Simon Wang) #2

My bad. See below…


(Zhang Yi) #3

Hi, @SimonW
Thanks for your reply.
Will the BN module use a different back-propagation method for TRAIN/EVAL mode without affine weights? I cannot find the code for this difference.

Regards,


(Simon Wang) #4

Yes, it will use the proper backward calculation.


(Zhang Yi) #5

Thanks @SimonW
I tested BN in four modes (with or without affine, train or eval) and found that BN uses a different backward calculation for TRAIN vs. EVAL mode, regardless of the affine weights. See the code below.

import torch
import torch.backends.cudnn as cudnn


if __name__ == '__main__':
    print("CUDNN Version: {}".format(cudnn.version()))
    cudnn.enabled = True
    cudnn.benchmark = True
    print("############## Check BN with Frozen Param")
    input_tensor = torch.rand(1, 64, 100, 100).cuda() * 100
    weight_init = torch.ones(64)
    bias_init = torch.zeros(64)
    mean_init = torch.rand(64) * 100
    var_init = torch.rand(64) * 100

    # torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)
    # BN0: train mode, affine parameters present but frozen (requires_grad=False)
    BN0 = torch.nn.BatchNorm2d(64, affine=True)
    state_dict_0 = BN0.state_dict()
    state_dict_0['weight'].copy_(weight_init)
    state_dict_0['bias'].copy_(bias_init)
    state_dict_0['running_mean'].copy_(mean_init)
    state_dict_0['running_var'].copy_(var_init)
    BN0 = BN0.cuda()

    BN0.weight.requires_grad = False
    BN0.bias.requires_grad = False
    BN0.train()
    input_tensor_0 = input_tensor.clone()
    input_var_0 = torch.autograd.Variable(input_tensor_0, requires_grad=True)
    output_var_0 = BN0(input_var_0)
    loss_0 = output_var_0.sum()
    loss_0.backward()

    # BN1: train mode, affine learnable (baseline)
    BN1 = torch.nn.BatchNorm2d(64, affine=True)
    state_dict_1 = BN1.state_dict()
    state_dict_1['weight'].copy_(weight_init)
    state_dict_1['bias'].copy_(bias_init)
    state_dict_1['running_mean'].copy_(mean_init)
    state_dict_1['running_var'].copy_(var_init)
    BN1 = BN1.cuda()

    BN1.train()
    input_tensor_1 = input_tensor.clone()
    input_var_1 = torch.autograd.Variable(input_tensor_1, requires_grad=True)
    output_var_1 = BN1(input_var_1)
    loss_1 = output_var_1.sum()
    loss_1.backward()

    # BN2: eval mode, affine learnable
    BN2 = torch.nn.BatchNorm2d(64, affine=True)
    state_dict_2 = BN2.state_dict()
    state_dict_2['weight'].copy_(weight_init)
    state_dict_2['bias'].copy_(bias_init)
    state_dict_2['running_mean'].copy_(mean_init)
    state_dict_2['running_var'].copy_(var_init)
    BN2 = BN2.cuda()

    BN2.eval()
    input_tensor_2 = input_tensor.clone()
    input_var_2 = torch.autograd.Variable(input_tensor_2, requires_grad=True)
    output_var_2 = BN2(input_var_2)
    loss_2 = output_var_2.sum()
    loss_2.backward()

    # BN3: train mode, no affine
    BN3 = torch.nn.BatchNorm2d(64, affine=False)
    state_dict_3 = BN3.state_dict()
    state_dict_3['running_mean'].copy_(mean_init)
    state_dict_3['running_var'].copy_(var_init)
    BN3 = BN3.cuda()

    BN3.train()
    input_tensor_3 = input_tensor.clone()
    input_var_3 = torch.autograd.Variable(input_tensor_3, requires_grad=True)
    output_var_3 = BN3(input_var_3)
    loss_3 = output_var_3.sum()
    loss_3.backward()

    # BN4: eval mode, no affine
    BN4 = torch.nn.BatchNorm2d(64, affine=False)
    state_dict_4 = BN4.state_dict()
    state_dict_4['running_mean'].copy_(mean_init)
    state_dict_4['running_var'].copy_(var_init)
    BN4 = BN4.cuda()

    BN4.eval()
    input_tensor_4 = input_tensor.clone()
    input_var_4 = torch.autograd.Variable(input_tensor_4, requires_grad=True)
    output_var_4 = BN4(input_var_4)
    loss_4 = output_var_4.sum()
    loss_4.backward()

    # eval vs. train: the backward calculation differs
    print((input_var_2.grad - input_var_1.grad).abs().max().data[0])
    print((input_var_4.grad - input_var_3.grad).abs().max().data[0])
    # affine on vs. off does not change the input gradient
    print((input_var_4.grad - input_var_2.grad).abs().max().data[0])
    print((input_var_3.grad - input_var_1.grad).abs().max().data[0])
    # freezing the affine params (requires_grad=False) does not change it either
    print((input_var_1.grad - input_var_0.grad).abs().max().data[0])

The result is

CUDNN Version: 6021
############## Check BN with Frozen Param
0.8908671140670776
0.8908671140670776
0.0
1.646934677523859e-08
0.0

Could you please show me how the BN module uses a different backward calculation depending on TRAIN/EVAL mode? I cannot find it in the code.

Regards,


(Simon Wang) #6

My bad on the previous reply. You only need to set eval mode. No need to make a separate affine transform.

The branching that calls the training/eval backward is at https://github.com/pytorch/pytorch/blob/04ad23252a3ce592c5e5c30c6fd87000f8d178cf/tools/autograd/derivatives.yaml#L1038 (if you have cuDNN installed).
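For reference, in eval mode BN is a fixed per-channel affine map, so the input gradient is just grad_output scaled by weight / sqrt(running_var + eps). A minimal numerical check (assuming PyTorch ≥ 0.4 tensor autograd rather than the old Variable API used above):

import torch
import torch.nn as nn

torch.manual_seed(0)

# Eval mode: the forward pass uses running_mean/running_var, so BN is a
# fixed per-channel affine map and its backward reduces to a constant scale.
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)  # buffers, so in-place edits are fine
bn.running_var.uniform_(0.5, 2.0)

x = torch.randn(4, 8, 5, 5, requires_grad=True)
bn(x).sum().backward()

# dL/dx = grad_output * weight / sqrt(running_var + eps);
# the grad_output of sum() is all ones.
scale = (bn.weight.detach() / (bn.running_var + bn.eps).sqrt()).view(1, -1, 1, 1)
print(torch.allclose(x.grad, scale.expand_as(x.grad)))  # True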


(Zhang Yi) #7

@SimonW, Thanks very much for your reply. It is very helpful~


(Måns Larsson) #8

Hi!
Did you manage to solve this?
I’m trying to do the same thing: training with fixed mean/var for the batchnorm layers.
I set all batchnorm layers to eval mode during training with this snippet:

net.train()
for module in net.modules():
    if isinstance(module, (torch.nn.modules.BatchNorm1d,
                           torch.nn.modules.BatchNorm2d,
                           torch.nn.modules.BatchNorm3d)):
        module.eval()

However, I get really bad training results: the validation score goes to zero straight away. I’ve tried the same training without setting the batchnorm layers to eval, and that works fine.
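Note that any later call to net.train() (e.g. at the start of each epoch) silently flips the BN layers back to training mode, so the snippet above has to be re-applied after every such call:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
for m in net.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
print(net[1].training)  # False: BN is frozen

net.train()  # called again, e.g. at the start of the next epoch
print(net[1].training)  # True: the freeze is silently undone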


(Zhang Yi) #9

I override the train() method of my model:

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters.
        """
        super(MyNet, self).train(mode)
        if self.freeze_bn:
            print("Freezing Mean/Var of BatchNorm2D.")
            if self.freeze_bn_affine:
                print("Freezing Weight/Bias of BatchNorm2D.")
            for m in self.backbone.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    if self.freeze_bn_affine:
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False

This works for me.
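A hypothetical usage sketch (freeze_bn and freeze_bn_affine are assumed to be constructor flags of MyNet that set the attributes used in train() above); it also helps to exclude the frozen parameters from the optimizer:

import torch

# MyNet and its flags are assumed from the snippet above.
net = MyNet(freeze_bn=True, freeze_bn_affine=True).cuda()
net.train()  # BN layers stay in eval mode; their affine params stay frozen

# Pass only the trainable parameters to the optimizer.
params = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9)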


(Aishwarya Unnikrishnan) #10

Hi! I think you can remove the learnable affine transform by setting the affine parameter to False.
What that does is fix the learnable parameters “gamma” (the scale) and “beta” (the shift) to 1 and 0. Effectively, you’ll then only be using the running mean and variance (an exponentially weighted average that keeps track of the batch statistics) at test time, i.e., when you call eval().

I too have a similar issue with the train/eval scenario you’re describing, but no benchmark to compare against (RL problems). The two modes give me results of different orders of magnitude on the first iteration. I’m hoping it’s because the running mean has only been updated with the first sample. If someone knows the answer to this, please enlighten me…
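For reference, a small check of what affine=False actually does: it removes the learnable gamma/beta, but in train() mode the running statistics are still updated, so eval() is still needed to freeze them:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4, affine=False)  # no learnable gamma/beta
print(bn.weight, bn.bias)             # None None

bn.train()
before = bn.running_mean.clone()
bn(torch.randn(16, 4, 8, 8))          # forward pass in train mode
print(torch.equal(before, bn.running_mean))  # False: stats still update

bn.eval()
before = bn.running_mean.clone()
bn(torch.randn(16, 4, 8, 8))          # forward pass in eval mode
print(torch.equal(before, bn.running_mean))  # True: stats are frozen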