I tried to implement Batch Renormalization (arXiv:1702.03275) in PyTorch. The program stops when computing the gradients. The traceback is attached below:
Traceback (most recent call last):
  File "cifar.py", line 187, in <module>
    loss.backward()
  File "/usr/local/lib/python3.5/dist-packages/torch/autograd/variable.py", line 158, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
RuntimeError: could not compute gradients for some functions (View, ConvNd)
My implementation of Batch Renormalization is shown below:
import torch
import torch.nn as nn
from torch.autograd import Variable

class BatchRenorm2d(nn.Module):
    def __init__(self, channels, eps=1e-5, rmax=3, dmax=5, lr=0.001):
        super(BatchRenorm2d, self).__init__()
        self.is_train = True
        self.is_unlock = False  # when True, apply the r/d correction from the paper
        self.eps = eps
        self.channels = channels
        self.rmax = rmax
        self.dmax = dmax
        self.lr = lr
        # running estimates of the per-channel std and mean; plain tensors,
        # not Variables, because they are updated outside of autograd
        self.sigma = torch.zeros(1, channels).cuda()
        self.mean = torch.zeros(1, channels).cuda()

    def forward(self, x):
        batch_size = x.size()[0]
        feature_shape_size = x.size()[2] * x.size()[3]
        x_hat = Variable(torch.zeros(x.size()).cuda())
        if self.is_train:
            # per-channel minibatch mean; reductions keep dims in this version
            mu_b = x.mean(0).mean(2).mean(3).view(1, self.channels)
            xview = x.view(batch_size, self.channels, feature_shape_size)
            # per-sample, per-channel variances, averaged over the batch below
            sig_sqr_sum = Variable(torch.zeros(batch_size, self.channels).cuda())
            for j in range(self.channels):
                mu_b_0_j = mu_b[0, j].repeat(feature_shape_size)
                for i in range(batch_size):
                    sig_sqr_sum[i, j] = ((xview[i, j] - mu_b_0_j) ** 2).mean()
            sigma_b = torch.sqrt(sig_sqr_sum.mean(0) + self.eps)
            if self.is_unlock:
                # r and d come from .data, so autograd treats them as constants
                r = sigma_b.data / self.sigma
                r.clamp_(1.0 / self.rmax, self.rmax)
                d = (mu_b.data - self.mean) / (self.sigma + self.eps)
                d.clamp_(-self.dmax, self.dmax)
            else:
                # warm-up phase: r = 1 and d = 0 reduce to ordinary batch norm
                r = torch.ones(1, self.channels).cuda()
                d = torch.zeros(1, self.channels).cuda()
            for j in range(self.channels):
                mu_b_0_j = mu_b[0, j].repeat(feature_shape_size).view(x.size()[2], x.size()[3])
                sigma_b_0_j = sigma_b[0, j].repeat(feature_shape_size).view(x.size()[2], x.size()[3])
                for i in range(batch_size):
                    x_hat_i_j = x[i, j, :, :].clone()
                    x_hat_i_j -= mu_b_0_j
                    x_hat_i_j /= sigma_b_0_j
                    x_hat_i_j *= r[0, j]
                    x_hat_i_j += d[0, j]
                    x_hat[i, j, :, :] = x_hat_i_j
            # update the running statistics with an exponential moving average
            self.mean += self.lr * (mu_b.data - self.mean)
            self.sigma += self.lr * (sigma_b.data - self.sigma)
        else:
            # inference: normalize with the running statistics; no r/d correction
            mu_b = Variable(self.mean)
            sigma_b = Variable(self.sigma)
            for j in range(self.channels):
                mu_b_0_j = mu_b[0, j].repeat(feature_shape_size).view(x.size()[2], x.size()[3])
                sigma_b_0_j = sigma_b[0, j].repeat(feature_shape_size).view(x.size()[2], x.size()[3])
                for i in range(batch_size):
                    x_hat_i_j = x[i, j, :, :].clone()
                    x_hat_i_j -= mu_b_0_j
                    x_hat_i_j /= sigma_b_0_j
                    x_hat[i, j, :, :] = x_hat_i_j
        return x_hat
What should I do to solve this problem? Thanks.
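For reference, here is a minimal vectorized sketch of the training-path normalization that I could try instead. It is my own rewrite and untested; the function name batch_renorm_train and the (1, channels, 1, 1) shapes for r and d are assumptions, not from the paper's reference code. It builds x_hat entirely out of autograd ops with expand_as and never assigns into a pre-allocated Variable, which is one way to keep the backward graph intact in this version of PyTorch:

import torch
from torch.autograd import Variable

def batch_renorm_train(x, r, d, eps=1e-5):
    # x: Variable of shape (N, C, H, W)
    # r, d: plain tensors of shape (1, C, 1, 1), computed from .data
    # so that autograd treats them as constants
    mu_b = x.mean(0).mean(2).mean(3)            # (1, C, 1, 1); reductions keep dims
    var_b = ((x - mu_b.expand_as(x)) ** 2).mean(0).mean(2).mean(3)
    sigma_b = torch.sqrt(var_b + eps)           # (1, C, 1, 1)
    # normalize, then apply the constant r/d correction from the paper
    x_hat = (x - mu_b.expand_as(x)) / sigma_b.expand_as(x)
    x_hat = x_hat * Variable(r).expand_as(x) + Variable(d).expand_as(x)
    # mu_b.data and sigma_b.data can then be used to update the running statistics
    return x_hat, mu_b, sigma_b

With this form, gradients flow to x through mu_b and sigma_b as well, which matches how the paper backpropagates through the minibatch statistics while keeping r and d fixed.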