BatchNorm2d mean and variance of pre-trained model

Hello everyone,

I have a pre-trained model called SoundNet, whose weights are available in TensorFlow. I recreated the model in PyTorch and loaded the weights. The implementation is available here.
The preprocessing and many parameters were taken from the TensorFlow repository.

The output of the first convolutional layer exactly matches the TensorFlow one; however, the following BatchNorm2d layer yields different numbers.

There is an ambiguity in how the mean and variance of a BatchNorm layer should be used in PyTorch. I suspect the problem is in the way I load the BatchNorm weights and the way I use them:

def put_weights(batchnorm, conv, params_w, batch_norm=True):
	if batch_norm:
		bn_bs = params_w['beta']
		batchnorm.bias = torch.nn.Parameter(torch.from_numpy(bn_bs))
		bn_ws = params_w['gamma']
		batchnorm.weight = torch.nn.Parameter(torch.from_numpy(bn_ws))
		bn_mean = params_w['mean']
		# these two assignments are the part I am unsure about:
		batchnorm.mean = torch.nn.Parameter(torch.from_numpy(bn_mean))
		bn_var = params_w['var']
		batchnorm.variance = torch.nn.Parameter(torch.from_numpy(bn_var))

	conv_bs = params_w['biases']
	conv.bias = torch.nn.Parameter(torch.from_numpy(conv_bs))
	conv_ws = params_w['weights']
	conv_ws = torch.from_numpy(conv_ws).permute(3, 2, 0, 1)
	conv.weight = torch.nn.Parameter(conv_ws)
	
	return batchnorm, conv

When I use the model to extract features, I put it in eval() mode:

model.eval()  # this ensures that the pre-calculated means and variances are used during feature extraction
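To confirm what eval() changes, a quick sanity check (my own sketch, not part of the original code) shows the two behaviours of BatchNorm2d:

import torch

bn = torch.nn.BatchNorm2d(3)
x = torch.randn(4, 3, 8, 8)

bn.train()
_ = bn(x)               # train mode: normalizes with batch statistics and updates the running ones
print(bn.running_mean)  # no longer all zeros

bn.eval()
y = bn(x)               # eval mode: normalizes with the stored running statistics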

Does anybody have an idea why I get this kind of output? Am I using the model weights properly?

I also computed the BatchNorm2d output manually with NumPy broadcasting, following the formula in the PyTorch reference. The values I got match the TensorFlow implementation, which confirms my suspicion about the way I use BatchNorm2d in my implementation.
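For reference, here is a sketch of that manual computation (the function name is mine; eps is PyTorch's default of 1e-5):

import numpy as np

def batchnorm2d_numpy(x, gamma, beta, running_mean, running_var, eps=1e-5):
	# x has shape (N, C, H, W); the statistics are per-channel vectors of length C,
	# reshaped to (1, C, 1, 1) so they broadcast over N, H and W
	shape = (1, -1, 1, 1)
	x_hat = (x - running_mean.reshape(shape)) / np.sqrt(running_var.reshape(shape) + eps)
	return gamma.reshape(shape) * x_hat + beta.reshape(shape)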

Any idea how to use these parameters from a pre-trained model for feature extraction?

.mean and .variance do not exist in nn.BatchNorm.
Try using .running_mean and .running_var instead.

EDIT: Also, I think you should register them as tensors, not nn.Parameters.
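To see the distinction, a small illustration of mine: BatchNorm2d registers the running statistics as buffers, which are saved in state_dict() but excluded from parameters(), so an optimizer never updates them. Wrapping them in nn.Parameter would break that.

import torch

bn = torch.nn.BatchNorm2d(16)
# the learnable affine terms are Parameters...
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
# ...while the running statistics are buffers (plus num_batches_tracked on recent versions)
print([name for name, _ in bn.named_buffers()])     # ['running_mean', 'running_var', ...]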

After spending a few hours on this as a beginner in PyTorch, I will post the answer; perhaps it will be useful for some people out there:

  • For the mean and variance we should assign to running_mean.data and running_var.data when loading the pre-trained weights. This is strange, since the other variables accept weights assigned directly, without an explicit .data.
  • In general it is safer to assign weights through variable.data. The function to load the weights then looks like this:
import torch

def put_weights(batchnorm, conv, params_w, batch_norm=True):
	if batch_norm:
		bn_bs = params_w['beta']
		batchnorm.bias.data = torch.from_numpy(bn_bs)
		bn_ws = params_w['gamma']
		batchnorm.weight.data = torch.from_numpy(bn_ws)
		# the running statistics are buffers; assign to their .data as well
		bn_mean = params_w['mean']
		batchnorm.running_mean.data = torch.from_numpy(bn_mean)
		bn_var = params_w['var']
		batchnorm.running_var.data = torch.from_numpy(bn_var)

	conv_bs = params_w['biases']
	conv.bias.data = torch.from_numpy(conv_bs)
	conv_ws = params_w['weights']
	# TensorFlow stores conv kernels as (H, W, in, out); PyTorch expects (out, in, H, W)
	conv.weight.data = torch.from_numpy(conv_ws).permute(3, 2, 0, 1)
	return batchnorm, conv
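For completeness, a hypothetical usage (model, its conv1/batchnorm1 attributes, and inputs are placeholder names; params_w is assumed to be the dict of NumPy arrays loaded from the TensorFlow checkpoint as above):

model.batchnorm1, model.conv1 = put_weights(model.batchnorm1, model.conv1, params_w)
model.eval()  # use the loaded running statistics instead of batch statistics
with torch.no_grad():
	features = model(inputs)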

Thank you very much,

That is true; I should have assigned them as tensors through tensor.data.