Computing the gradients for batch renormalization

I tried to implement Batch Renormalization (arXiv 1702.03275) in PyTorch. The program stops when computing the gradients. The traceback is attached below:

Traceback (most recent call last):
  File "cifar.py", line 187, in <module>
    loss.backward()
  File "/usr/local/lib/python3.5/dist-packages/torch/autograd/variable.py", line 158, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
RuntimeError: could not compute gradients for some functions (View, ConvNd)

My implementation of batch renormalization is shown below:

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

class BatchRenorm2d(nn.Module):
    def __init__(self, channels, eps = 1e-5, rmax=3, dmax=5, lr=0.001):
        super(BatchRenorm2d, self).__init__()
        self.is_train = True
        self.is_unlock = False
        self.eps = eps 
        self.channels = channels
        self.rmax = rmax
        self.dmax = dmax
        self.lr = lr
        # same (1, channels) shape as self.mean and sigma_b
        self.sigma = torch.from_numpy(np.zeros((1, channels), dtype=np.float32)).cuda()
        self.mean = torch.from_numpy(np.zeros((1,channels), dtype=np.float32)).cuda()

    def forward(self, x): 
        if self.is_train:
            batch_size = x.size()[0]
            feature_shape_size = x.size()[2] * x.size()[3]
            sig_sqr_sum = Variable(torch.zeros(batch_size, self.channels)).cuda()
            mu_b = x.mean(0).mean(2).mean(3).view(1, self.channels)
            xview = x.view(batch_size, self.channels, feature_shape_size)

            for j in range(self.channels):
                mu_b_0_j = mu_b[0, j].repeat(feature_shape_size)
                for i in range(batch_size):
                    sig_sqr_sum[i,j] = ((xview[i,j] - mu_b_0_j) ** 2).mean()
            sigma_b = sig_sqr_sum.mean(0)
            sigma_b += self.eps
            sigma_b = torch.sqrt(sigma_b)
            if self.is_unlock:
                r = sigma_b.data / self.sigma
                r.clamp_(1.0/self.rmax, self.rmax)
                d = (mu_b.data - self.mean) / (self.sigma + self.eps ** 0.5)
                d.clamp_(-self.dmax, self.dmax)
            else:
                r = torch.zeros(1, self.channels) + 1.0 
                d = torch.zeros(1, self.channels)
            x_hat = Variable(torch.zeros(x.size()).cuda())
            for j in range(self.channels):
                mu_b_0_j = mu_b[0, j].repeat(feature_shape_size).view(x.size()[2], x.size()[3])
                sigma_b_0_j = sigma_b[0, j].repeat(feature_shape_size).view(x.size()[2], x.size()[3])
                for i in range(batch_size):
                    x_hat_i_j = x[i,j,:,:].clone()
                    x_hat_i_j -= mu_b_0_j
                    x_hat_i_j /= sigma_b_0_j
                    x_hat_i_j *= r[0, j]
                    x_hat_i_j += d[0, j]
                    x_hat[i,j,:,:] = x_hat_i_j
            # update the running mean once per batch, not once per element
            self.mean += self.lr * (mu_b.data - self.mean)
            self.sigma += self.lr * (sigma_b.data - self.sigma)
        else:
            batch_size = x.size()[0]
            feature_shape_size = x.size()[2] * x.size()[3]
            mu_b = Variable(self.mean)
            sigma_b = Variable(self.sigma)
            x_hat = Variable(torch.zeros(x.size()).cuda())
            for j in range(self.channels):
                mu_b_0_j = mu_b[0, j].repeat(feature_shape_size).view(x.size()[2], x.size()[3])
                sigma_b_0_j = sigma_b[0, j].repeat(feature_shape_size).view(x.size()[2], x.size()[3])
                for i in range(batch_size):
                    # inference uses the running statistics only; r and d are training-time corrections
                    x_hat_i_j = x[i,j,:,:].clone()
                    x_hat_i_j -= mu_b_0_j
                    x_hat_i_j /= sigma_b_0_j
                    x_hat[i,j,:,:] = x_hat_i_j
        return x_hat

What should I do to solve this problem? Thanks.
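The nested loops in the posted forward() compute a per-channel batch mean and standard deviation one element at a time and write the results into preallocated Variables by indexed assignment. A minimal sketch of the same statistics using broadcasting instead, assuming a recent PyTorch (1.x or later, where Variable is merged into Tensor and reductions accept a tuple of dims with keepdim); batch_stats is a hypothetical helper name. Because nothing is written in place, autograd records a single ordinary graph:

```python
import torch

def batch_stats(x, eps=1e-5):
    """Per-channel mean and std of an (N, C, H, W) tensor, each shaped (1, C, 1, 1)."""
    # Reduce over batch and spatial dims in one call; keepdim preserves rank for broadcasting.
    mu_b = x.mean(dim=(0, 2, 3), keepdim=True)
    # Biased variance (divide by N*H*W), matching the loop's .mean() of squared deviations.
    var_b = ((x - mu_b) ** 2).mean(dim=(0, 2, 3), keepdim=True)
    sigma_b = torch.sqrt(var_b + eps)
    return mu_b, sigma_b

x = torch.randn(8, 3, 4, 4, requires_grad=True)
mu_b, sigma_b = batch_stats(x)
# Broadcasting replaces the repeat()/view() calls and the per-element writes into x_hat.
x_hat = (x - mu_b) / sigma_b
x_hat.sum().backward()
```

Keeping the statistics at shape (1, C, 1, 1) is what lets them broadcast against every feature map without repeat() or view().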

Hi,
Are you running the latest version of PyTorch? Does this error still occur with the latest release?
Some bugs that caused this error have been fixed recently.

Hi, what is the clamp_ function doing? Is it differentiable? Also, can you show the whole implementation of the BatchRenorm layer?
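For reference: clamp_ is the in-place form of torch.clamp, which limits every element to [min, max]. In the posted code it is applied to tensors taken from .data, so autograd never sees it. When clamp is part of the graph, it is differentiable almost everywhere: the gradient is 1 where the input lies strictly inside the range and 0 where it was clipped. A quick check, assuming a recent PyTorch where tensors take requires_grad directly:

```python
import torch

# Three inputs: one below the range, one inside it, one above it.
x = torch.tensor([-10.0, 0.5, 10.0], requires_grad=True)
# Out-of-place clamp, so autograd records it.
y = torch.clamp(x, min=-1.0, max=1.0)
y.sum().backward()
print(x.grad)  # tensor([0., 1., 0.])
```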

Looking at the BatchRenorm paper, the author states that gradients are not propagated through the clip functions; Algorithm 1 denotes this with a stop_gradient function. I'm not sure how to do this in PyTorch, though. I would think it should be pretty easy.

Reference: https://arxiv.org/pdf/1702.03275.pdf
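A minimal sketch of the training-mode forward pass from Algorithm 1, assuming a recent PyTorch; batch_renorm_train and its arguments are hypothetical names, running_mean and running_std are plain (1, C, 1, 1) tensors kept outside autograd, and the learned scale and shift (gamma, beta) are left out for brevity. Here detach() stands in for stop_gradient and torch.clamp for clip:

```python
import torch

def batch_renorm_train(x, running_mean, running_std, rmax=3.0, dmax=5.0,
                       lr=0.001, eps=1e-5):
    # Per-channel batch statistics over N, H, W, kept as (1, C, 1, 1) for broadcasting.
    mu_b = x.mean(dim=(0, 2, 3), keepdim=True)
    sigma_b = torch.sqrt(((x - mu_b) ** 2).mean(dim=(0, 2, 3), keepdim=True) + eps)

    # r and d are built from detached statistics, so no gradient flows through them
    # (the stop_gradient in Algorithm 1); torch.clamp plays the role of clip.
    r = torch.clamp(sigma_b.detach() / running_std, 1.0 / rmax, rmax)
    d = torch.clamp((mu_b.detach() - running_mean) / running_std, -dmax, dmax)

    x_hat = (x - mu_b) / sigma_b * r + d

    # Move the running statistics toward the batch statistics, outside the graph.
    with torch.no_grad():
        running_mean += lr * (mu_b - running_mean)
        running_std += lr * (sigma_b - running_std)
    return x_hat
```

A call looks like y = batch_renorm_train(x, running_mean, running_std), here with running_mean initialized to zeros and running_std to ones; at inference the paper normalizes with the running statistics alone.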

You can detach a Variable with Variable.detach(). It returns a new Variable that shares the same data but is cut off from the graph, so no gradient flows back through it.
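A small illustration of what detach() does, assuming a recent PyTorch where tensors carry requires_grad directly: the detached factor is treated as a constant, so its contribution to the gradient disappears.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Without detach: d(x * x)/dx = 2x.
(x * x).sum().backward()
full_grad = x.grad.clone()

x.grad.zero_()
# With detach: the second factor is a constant c, so d(x * c)/dx = c, which equals x here.
(x * x.detach()).sum().backward()
stopped_grad = x.grad.clone()

print(full_grad)     # tensor([2., 4., 6.])
print(stopped_grad)  # tensor([1., 2., 3.])
```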