How to write a "BatchScale" Module?

I’m trying to implement a “BatchScale” Module that is akin to BatchNorm, except that there is no shift/mean/bias, just a scaling/weighting and a running variance. Initially I thought I could just copy and modify the _BatchNorm class, but it calls the C version of batch_norm, and as far as I can tell, that function always computes and applies a running mean.
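For comparison, here are the two per-feature transforms side by side. The notation is assumed here for illustration: γ is the learnable weight, β is BatchNorm's bias, m is the momentum, and s² is the current batch's variance. The running-variance update mirrors the one in the code below.

```latex
\begin{align*}
\text{BatchNorm:}  \quad & y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}\,\gamma + \beta \\
\text{BatchScale:} \quad & y = \frac{x}{\sqrt{\sigma^2_{\text{run}} + \epsilon}}\,\gamma,
  \qquad \sigma^2_{\text{run}} \leftarrow (1 - m)\,\sigma^2_{\text{run}} + m\,s^2
\end{align*}
```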

Any suggestions?

I tried writing a class, and it seems straightforward enough, but the final line of forward(), which computes y = x / sqrt(var + eps) * gamma, gives the following error, and I’m struggling to understand the problem.

*** RuntimeError: div() received an invalid combination of arguments - got (torch.FloatTensor), but expected one of:

  • (float other)
    didn’t match because some of the arguments have invalid types: (torch.FloatTensor)
  • (Variable other)
    didn’t match because some of the arguments have invalid types: (torch.FloatTensor)

Here is my code. Any advice would be much appreciated!

import torch
from torch.autograd import Variable
from torch.nn import Module, Parameter

class Scale(Module):
    """BatchNorm-like layer that only rescales: no mean subtraction, no bias."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1, linear=True):
        super(Scale, self).__init__()
        self.num_features = num_features
        self.linear = linear
        self.eps = eps
        self.momentum = momentum
        if linear:
            # Learnable per-feature scale (gamma)
            self.weight = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
        # Buffer: saved in state_dict and moved with the module, but not
        # returned by parameters(), so the optimizer never updates it
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_var.fill_(1)
        if self.linear:
            self.weight.data.uniform_()

    def forward(self, input):
        if self.training:
            # Update variance estimate if in training mode
            batch_var = input.var(dim=0).data
            self.running_var = (1-self.momentum)*self.running_var + self.momentum*batch_var
        return input / (self.running_var + self.eps).sqrt() * self.weight

    def __repr__(self):
        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
                ' linear={linear})'
                .format(name=self.__class__.__name__, **self.__dict__))

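Oops. I posted too soon. Wrapping the variance in a Variable fixed the problem. The reason is that running_var is a registered buffer, so it is a plain FloatTensor, while input and self.weight are Variables, and this version of PyTorch does not allow mixing the two in arithmetic (from PyTorch 0.4 on, Tensor and Variable are merged and the original line works unchanged). I just had to change the last line of forward() to this: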

return input / Variable((self.running_var + self.eps).sqrt()) * self.weight
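For newer PyTorch (0.4 and later), a minimal sketch of the same idea without any Variable wrapping might look like the following. The class name BatchScale, the ones-initialised weight (as current nn.BatchNorm uses), and the assumption that channels sit on dim 1 (so (N, C, H, W) inputs work, as with nn.BatchNorm2d) are choices made for this illustration, not part of the original code. It also handles linear=False, where the original last line would try to multiply by None.

```python
import torch
from torch import nn


class BatchScale(nn.Module):
    """Scale-only batch norm sketch: y = x / sqrt(running_var + eps) * weight."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1, linear=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.linear = linear
        if linear:
            # Assumption: start at ones, as current nn.BatchNorm does
            self.weight = nn.Parameter(torch.ones(num_features))
        else:
            self.register_parameter('weight', None)
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        if self.training:
            # Per-channel variance over every dim except the channel dim (1)
            dims = tuple(d for d in range(input.dim()) if d != 1)
            batch_var = input.detach().var(dim=dims)
            # In-place update keeps the same registered buffer object
            self.running_var.mul_(1 - self.momentum).add_(self.momentum * batch_var)
        # Reshape per-channel values so they broadcast over (N, C, ...)
        shape = (1, -1) + (1,) * (input.dim() - 2)
        out = input / (self.running_var + self.eps).sqrt().view(shape)
        if self.weight is not None:
            out = out * self.weight.view(shape)
        return out

    def extra_repr(self):
        return (f'{self.num_features}, eps={self.eps}, '
                f'momentum={self.momentum}, linear={self.linear}')


m = BatchScale(3)
y = m(torch.randn(8, 3))
m.eval()  # running_var is now frozen
```

Like the original Scale, this divides by the running variance in training mode too. nn.BatchNorm instead normalizes with the current batch's statistics while training and switches to the running estimates only after eval(); to mimic that, divide by batch_var in the training branch.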