Does anyone know a way to implement the mean-only batch norm from https://arxiv.org/abs/1602.07868? I.e., I want to subtract the mean from the activations but not divide them by the standard deviation.
Maybe you can just subclass this:
https://pytorch.org/docs/master/_modules/torch/nn/modules/batchnorm.html
and override the forward function, calling self.running_var.fill_(1) before calling F.batch_norm(blah blah).
Thanks! But I want this mean-only behavior during training as well, not just for inference. So fixing the running variance would not help?
It should also work while training. gamma and beta are still trainable parameters; the only difference is that they are applied to inputs normalized with the batch mean and a fixed variance.
The issue is not beta and gamma. At training time, the statistics of the current batch are used, not the running averages.
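A quick way to see this, as a minimal sketch assuming a recent PyTorch where `torch.nn.functional.batch_norm` takes a `training` flag: fill `running_var` with ones and compare the two modes. The tensor shape and the scale of the input are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Input whose per-channel standard deviation is far from 1 (illustrative values).
x = torch.randn(64, 3, 8, 8) * 5.0 + 2.0

running_mean = torch.zeros(3)
running_var = torch.ones(3)  # the "fixed variance" trick

# Eval mode: the running buffers are used, so with running_var == 1
# the input is only shifted by running_mean (up to eps).
y_eval = F.batch_norm(x, running_mean, running_var, training=False)

# Training mode: the batch mean AND the batch variance are used,
# so the output is still divided by the batch standard deviation.
# (This call also updates the running buffers in place.)
y_train = F.batch_norm(x, running_mean, running_var, training=True)

eval_std = y_eval.std(dim=(0, 2, 3))
train_std = y_train.std(dim=(0, 2, 3))
print("std in eval mode:    ", eval_std)   # stays near the input std
print("std in training mode:", train_std)  # pulled to roughly 1
```

So fixing `running_var` only changes what happens at inference time; in training mode the division by the batch standard deviation still takes place.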
Oh I see, sorry. The normalization is done in the CUDA code, so I think the simplest way is to port the OpenAI code: https://github.com/openai/weightnorm
The Lasagne code is easy.
I'm trying to implement it for my task. The loss is decreasing, but I'm not sure whether it is correct:
import torch
import torch.nn as nn
from torch.nn import Parameter


class MeanOnlyBatchNorm(nn.Module):
    def __init__(self, num_features, momentum=0.1):
        super(MeanOnlyBatchNorm, self).__init__()
        self.num_features = num_features
        self.momentum = momentum
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.weight.data.uniform_()
        self.bias.data.zero_()

    def forward(self, inp):
        # inp is expected to have shape (N, C, H, W) with C == num_features.
        gamma = self.weight.view(1, self.num_features, 1, 1)
        beta = self.bias.view(1, self.num_features, 1, 1)
        if self.training:
            # Per-channel mean over the batch and spatial dimensions.
            avg = inp.mean(dim=(0, 2, 3))
            # Update the running mean in place, outside of autograd.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * avg)
        else:
            # At inference time, use the accumulated running mean.
            avg = self.running_mean
        # Subtract the mean only; there is no division by the standard deviation.
        output = inp - avg.view(1, self.num_features, 1, 1)
        output = output * gamma + beta
        return output
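Since the question is whether this is correct, here is a small sanity check one could run. It is only a sketch: `mean_only_bn` is a hypothetical stand-alone helper written for this check (no gamma or beta), not part of PyTorch or of the OpenAI code. It tests three properties I would expect from mean-only batch norm in training mode: the per-channel mean of the output is about zero, the per-channel standard deviation is unchanged, and the running mean moves toward the batch mean by a factor of `momentum`.

```python
import torch


def mean_only_bn(x, running_mean, momentum=0.1, training=True):
    """Hypothetical helper: subtract the per-channel mean of an (N, C, H, W) tensor."""
    if training:
        # Mean over the batch and spatial dimensions, one value per channel.
        mean = x.mean(dim=(0, 2, 3))
        with torch.no_grad():
            running_mean.mul_(1 - momentum).add_(momentum * mean)
    else:
        mean = running_mean
    return x - mean.view(1, -1, 1, 1)


torch.manual_seed(0)
x = torch.randn(16, 4, 5, 5) * 3.0 + 7.0  # arbitrary shape, scale and offset
running_mean = torch.zeros(4)

y = mean_only_bn(x, running_mean, momentum=0.1, training=True)

# 1) The output is centered per channel.
out_mean = y.mean(dim=(0, 2, 3))
# 2) The spread is untouched, unlike full batch norm.
same_std = torch.allclose(y.std(dim=(0, 2, 3)), x.std(dim=(0, 2, 3)), atol=1e-4)
# 3) Starting from zeros, the running mean is momentum * batch mean.
expected_running = 0.1 * x.mean(dim=(0, 2, 3))

print("output mean per channel:", out_mean)
print("std preserved:", same_std)
print("running mean:", running_mean)
```

The same three checks can be applied to the module by setting its weight to 1 and bias to 0 first.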