Hi all!
I have an implementation of Virtual Batch Normalization (VBN) that I’m applying after each convolution of a discriminator network similar to DCGAN’s; more specifically, it is the discriminator network from the SEGAN paper.
Let’s assume disc is the state of the Discriminator after a forward pass and that d_loss is the loss computed from that same forward pass.
We then back-propagate the loss and take a peek at the gradient norms, only to find that p.grad is None.
The code works normally if I remove the VBN after each conv layer, or if, in VBN’s forward method, I set x.data = out and return x.
Any help is appreciated! Here’s the gradient check:
d_loss.backward()
d_grad_norm = 0
for k, p in disc.named_parameters():
    d_grad_norm += p.grad.data.norm()
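In case it’s useful, the same loop can be turned into a quick diagnostic that just prints which parameters come back without a gradient:

for k, p in disc.named_parameters():
    if p.grad is None:
        print(k, 'has no gradient')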
The discriminator (abridged):

class Discriminator(nn.Module):
    def __init__(self, ndf, kernel_size):
        super(Discriminator, self).__init__()
        self.encoder = nn.ModuleList([
            # 16
            nn.Conv1d(1, ndf, kernel_size, 2, 15, 1),
            VBN(ndf),
            # 32
            nn.Conv1d(ndf, ndf * 2, kernel_size, 2, 15, 1),
            VBN(ndf * 2),
            # more convolutions... and finally
            # Linear for output
            nn.Linear(8, 1)
        ])
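(The forward method and the remaining layers are elided; forward essentially just chains the encoder modules. A rough sketch, ignoring the activations and the flatten before the final nn.Linear:)

    def forward(self, x):
        # sketch only: activations and the reshape before the
        # Linear layer are omitted
        for layer in self.encoder:
            x = layer(x)
        return x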
And here is the VBN module:

class VBN(nn.Module):
    """
    Virtual Batch Normalization
    """
    def __init__(self, n_features, epsilon=1e-5):
        super(VBN, self).__init__()
        assert isinstance(epsilon, float)
        # batch statistics
        self.epsilon = epsilon
        self.mean = torch.zeros(1, n_features, 1)
        self.mean_sq = torch.zeros(1, n_features, 1)
        self.batch_size = None
        # reference output
        self.reference_output = None
    def initialize(self, x):
        # compute batch statistics
        # self.mean = torch.mean(x, [0, 2], keepdim=True)
        # self.mean_sq = torch.mean(x**2, [0, 2], keepdim=True)
        self.mean = x.data.mean(2).mean(0).resize_(1, x.size(1), 1)
        self.mean_sq = (x.data**2).mean(2).mean(0).resize_(1, x.size(1), 1)
        self.batch_size = x.size(0)
        assert x is not None
        assert self.mean is not None
        assert self.mean_sq is not None
        # compute reference output
        out = self._normalize(x, self.mean, self.mean_sq)
        self.reference_output = out
    def forward(self, x):
        if self.reference_output is None:
            self.initialize(x)
        new_coeff = 1. / (self.batch_size + 1.)
        old_coeff = 1. - new_coeff
        # new_mean = torch.mean(x, [0, 2], keepdim=True)
        # new_mean_sq = torch.mean(x**2, [0, 2], keepdim=True)
        new_mean = x.data.mean(2).mean(0).resize_as_(self.mean)
        new_mean_sq = (x.data**2).mean(2).mean(0).resize_as_(self.mean_sq)
        mean = new_coeff * new_mean + old_coeff * self.mean
        mean_sq = new_coeff * new_mean_sq + old_coeff * self.mean_sq
        out = self._normalize(x, mean, mean_sq)
        return Variable(out)
    def _normalize(self, x, mean, mean_sq):
        assert self.epsilon is not None
        assert mean_sq is not None
        assert mean is not None
        assert len(x.size()) == 3
        gamma = torch.normal(means=torch.ones(1, x.size(1), 1), std=0.02)
        gamma = gamma.float().cuda(async=True)
        beta = torch.cuda.FloatTensor(1, x.size(1), 1).fill_(0)
        std = torch.sqrt(self.epsilon + mean_sq - mean**2)
        # note: everything below operates on plain tensors (x.data),
        # not on the Variable x itself
        out = x.data - mean
        out = out / std
        out = out * gamma
        out = out + beta
        return out
    def __repr__(self):
        return ('{name}(eps={epsilon}, mean={mean}, mean_sq={mean_sq})'
                .format(name=self.__class__.__name__, **self.__dict__))
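For what it’s worth, here is a minimal sketch of the autograd-preserving _normalize I’ve been considering, assuming gamma and beta are created once in __init__ as learnable parameters (e.g. self.gamma = nn.Parameter(torch.normal(torch.ones(1, n_features, 1), 0.02)) and self.beta = nn.Parameter(torch.zeros(1, n_features, 1))) and that mean and mean_sq are Variables computed from x (e.g. x.mean(2, keepdim=True).mean(0, keepdim=True)) instead of tensors computed from x.data. Would something like this be the right direction?

    def _normalize(self, x, mean, mean_sq):
        # same math as above, but every op stays on Variables,
        # so autograd can trace conv -> VBN -> loss
        assert len(x.size()) == 3
        std = torch.sqrt(self.epsilon + mean_sq - mean ** 2)
        out = (x - mean) / std
        out = out * self.gamma.expand_as(x) + self.beta.expand_as(x)
        return out  # already a Variable; no Variable(out) re-wrap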