Hello, I am working on a text CNN autoencoder, taking reference from https://github.com/ymym3412/textcnn-conv-deconv-pytorch. I modified the embedding layer so both the encoder and decoder share the same space. But I keep on getting huge and mostly NaN loss when I am training. Normalisation is already included. Any idea why? The input data is a padded 1D tensor of indexes for the corresponding vector in the embedding layer
Thanks a lot
class ConvEncoder(nn.Module):
def __init__(self, embedDim, maxLength, filterSize, filterShape, latentSize):
super(ConvEncoder, self).__init__()
self.embedDim = embedDim
self.maxLength = maxLength
self.filterSize = filterSize
self. filterShape = filterShape
self.latentSize = latentSize
t1 = maxLength + 2 * (filterShape - 1)
t2 = int(math.floor((t1 - filterShape) / 2) + 1) # "2" means stride size
t3 = int(math.floor((t2 - filterShape) / 2) + 1) - 2
#self.embed = embedding
self.conv1 = nn.Conv2d(1, filterSize, kernel_size=(filterShape, embedDim), stride=2)
self.batchNorm1 = nn.BatchNorm2d(filterSize)
self.conv2 = nn.Conv2d(filterSize, filterSize*2, kernel_size=(filterShape, 1), stride=2)
self.batchNorm2 = nn.BatchNorm2d(filterSize*2)
self.conv3 = nn.Conv2d(filterSize*2, latentSize, kernel_size=(t3, 1), stride=2)
# weight initialize for conv layer
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
def forward(self, x):
# x.size() is (L, emb_dim) if batch_size is 1.
# So interpolate x's dimension if batch_size is 1.
if len(x.size()) < 3:
x = x.view(1, *x.size())
# reshape for convolution layer
if len(x.size()) < 4:
x = x.view(x.size()[0], 1, x.size()[1], x.size()[2])
conv1Output = self.conv1(x)
h1 = F.relu(self.batchNorm1(conv1Output))
conv2Output = self.conv2(h1)
h2 = F.relu(self.batchNorm2(conv2Output))
h3 = F.relu(self.conv3(h2))
return h3
class ConvDecoder(nn.Module):
def __init__(self, tau, embedDim, maxLength, filterSize, filterShape, latentSize):
super(ConvDecoder, self).__init__()
self.tau = tau
self.maxLength = maxLength
self.embedDim = embedDim
#self.embed = embedding
t1 = maxLength + 2 * (filterShape - 1)
t2 = int(math.floor((t1 - filterShape) / 2) + 1) # "2" means stride size
t3 = int(math.floor((t2 - filterShape) / 2) + 1) - 2
self.deconv1 = nn.ConvTranspose2d(latentSize, filterSize * 2, kernel_size=(t3, 1), stride=2)
self.batchNorm1 = nn.BatchNorm2d(filterSize * 2)
self.deconv2 = nn.ConvTranspose2d(filterSize * 2, filterSize, kernel_size=(filterShape, 1), stride=2)
self.batchNorm2 = nn.BatchNorm2d(filterSize)
self.deconv3 = nn.ConvTranspose2d(filterSize, 1, kernel_size=(filterShape, embedDim),output_padding=(1,0), stride=2)
# weight initialize for conv_transpose layer
for m in self.modules():
if isinstance(m, nn.ConvTranspose2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
def forward(self, h3):
h2 = F.relu(self.batchNorm1(self.deconv1(h3)))
h1 = F.relu(self.batchNorm2(self.deconv2(h2)))
xHat = F.relu(self.deconv3(h1))
xHat = xHat.squeeze()
# x.size() is (L, emb_dim) if batch_size is 1.
# So interpolate x's dimension if batch_size is 1.
if len(xHat.size()) < 3:
xHat = xHat.view(1, *xHat.size())
# normalize
normXHat = torch.norm(xHat, 2, dim=2, keepdim=True)
recXHat = xHat / normXHat
return recXHat
class ConvAutoencoder(nn.Module):
def __init__(self, embedding, tau, embedDim, maxLength, filterSize, filterShape, latentSize):
super(ConvAutoencoder, self).__init__()
self.embed = embedding
self.tau = tau
self.encoder = ConvEncoder(embedDim, maxLength, filterSize, filterShape, latentSize)
self.decoder = ConvDecoder(tau, embedDim, maxLength, filterSize, filterShape, latentSize)
def forward(self, x):
x = x.type(torch.long)
x = self.embed(x)
encodedInput = self.encoder(x)
recXHat = self.decoder(encodedInput)
# compute probability
#oEmbedSize = originalEmbeddings.size()
#normW = Variable(originalEmbeddings).permute(0, 2, 1)
normW = Variable(self.embed.weight.data).t()
#probLogits = torch.bmm(recXHat, normW) / self.tau
probLogits = torch.bmm(recXHat, normW.expand(recXHat.size(0), *normW.size())) / self.tau
logProb = F.log_softmax(probLogits, dim=2)
return logProb
#return decodedOutput
def computeCrossEntropy(log_prob, target):
# compute reconstruction loss using cross entropy
#Original loss calculation
#[torch.sum(l) for l in loss] - takes > 60% of the time
loss = [F.nll_loss(sentence_emb_matrix, word_ids, size_average=False) for sentence_emb_matrix, word_ids in zip(log_prob, target)]
average_loss = sum([torch.sum(l) for l in loss]) / log_prob.size()[0]
return average_loss
Update: apparently I have to set the learning rate to as low as 1e-8 to get a non-nan loss but the implementation I was referring to uses a decaying learning rate of 0.01. What’s wrong??? It becomes nan again when I use DataParallel instead