I’m trying to implement a “BatchScale” Module that is akin to BatchNorm, except that there is no shift/mean/bias, only a per-feature scale/weight and a running variance. Initially I thought I could just copy and modify the _BatchNorm class, but it calls the C version of batch_norm and, as far as I can tell, that function will always compute and apply a running mean.
I tried writing a class and it seems straightforward enough, but the final line in forward(), which computes y = x / sqrt(var + eps) * gamma, raises the following error, and I’m struggling to understand the problem.
*** RuntimeError: div() received an invalid combination of arguments - got (torch.FloatTensor), but expected one of:
 * (float other)
      didn’t match because some of the arguments have invalid types: (torch.FloatTensor)
 * (Variable other)
      didn’t match because some of the arguments have invalid types: (torch.FloatTensor)
Here’s my code. Any advice would be much appreciated!
import torch
from torch.nn import Module, Parameter

class Scale(Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, linear=True):
        super(Scale, self).__init__()
        self.num_features = num_features
        self.linear = linear
        self.eps = eps
        self.momentum = momentum
        if linear:
            self.weight = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_var.fill_(1)
        if self.linear:
            self.weight.data.uniform_()

    def forward(self, input):
        if self.training:
            # Update the running variance estimate in training mode
            batch_var = input.var(dim=0).data
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
        # This is the line that raises the RuntimeError above
        return input / (self.running_var + self.eps).sqrt() * self.weight

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
                ' linear={linear})'
                .format(name=self.__class__.__name__, **self.__dict__))
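For context, here is the tensor arithmetic that final line is meant to perform, sketched standalone with plain tensors (the shapes and values here are just illustrative, not from my actual model):

```python
import torch

num_features = 4
x = torch.ones(8, num_features)             # a batch of 8 samples
running_var = torch.ones(num_features)      # the running-variance buffer
weight = torch.full((num_features,), 2.0)   # the per-feature scale (gamma)
eps = 1e-5

# y = x / sqrt(var + eps) * gamma, broadcast over the batch dimension
y = x / (running_var + eps).sqrt() * weight
print(y.shape)  # torch.Size([8, 4])
```

With var = 1 and a tiny eps, each output element is roughly x * gamma, which is the behavior I expect from the module in eval mode.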