Text autoencoder NaN loss after first batch

Hello, I am working on a text CNN autoencoder, taking reference from https://github.com/ymym3412/textcnn-conv-deconv-pytorch. I modified the embedding layer so both the encoder and decoder share the same space. But I keep on getting huge and mostly NaN loss when I am training. Normalisation is already included. Any idea why? The input data is a padded 1D tensor of indexes for the corresponding vector in the embedding layer

Thanks a lot

class ConvEncoder(nn.Module):
    def __init__(self, embedDim, maxLength, filterSize, filterShape, latentSize):
        super(ConvEncoder, self).__init__()
        self.embedDim = embedDim
        self.maxLength = maxLength
        self.filterSize = filterSize
        self. filterShape = filterShape
        self.latentSize = latentSize
        
        t1 = maxLength + 2 * (filterShape - 1)
        t2 = int(math.floor((t1 - filterShape) / 2) + 1) # "2" means stride size
        t3 = int(math.floor((t2 - filterShape) / 2) + 1) - 2
        
        
        #self.embed = embedding
        self.conv1 = nn.Conv2d(1, filterSize, kernel_size=(filterShape, embedDim), stride=2)
        self.batchNorm1 = nn.BatchNorm2d(filterSize)
        self.conv2 = nn.Conv2d(filterSize, filterSize*2, kernel_size=(filterShape, 1), stride=2)
        self.batchNorm2 = nn.BatchNorm2d(filterSize*2)
        self.conv3 = nn.Conv2d(filterSize*2, latentSize, kernel_size=(t3, 1), stride=2)
        # weight initialize for conv layer
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
        
    def forward(self, x):
        # x.size() is (L, emb_dim) if batch_size is 1.
        # So interpolate x's dimension if batch_size is 1.
        if len(x.size()) < 3:
            x = x.view(1, *x.size())
        # reshape for convolution layer
        if len(x.size()) < 4:
            x = x.view(x.size()[0], 1, x.size()[1], x.size()[2])
        conv1Output = self.conv1(x)
        h1 = F.relu(self.batchNorm1(conv1Output))
        conv2Output = self.conv2(h1)
        h2 = F.relu(self.batchNorm2(conv2Output))
        h3 = F.relu(self.conv3(h2))

        return h3
    
class ConvDecoder(nn.Module):
    def __init__(self, tau, embedDim, maxLength, filterSize, filterShape, latentSize):
        super(ConvDecoder, self).__init__()
        self.tau = tau
        self.maxLength = maxLength
        self.embedDim = embedDim
        #self.embed = embedding
        

        
        t1 = maxLength + 2 * (filterShape - 1)
        t2 = int(math.floor((t1 - filterShape) / 2) + 1) # "2" means stride size
        t3 = int(math.floor((t2 - filterShape) / 2) + 1) - 2
        
        self.deconv1 = nn.ConvTranspose2d(latentSize, filterSize * 2, kernel_size=(t3, 1), stride=2)
        self.batchNorm1 = nn.BatchNorm2d(filterSize * 2)
        self.deconv2 = nn.ConvTranspose2d(filterSize * 2, filterSize, kernel_size=(filterShape, 1), stride=2)
        self.batchNorm2 = nn.BatchNorm2d(filterSize)
        self.deconv3 = nn.ConvTranspose2d(filterSize, 1, kernel_size=(filterShape, embedDim),output_padding=(1,0), stride=2)

        # weight initialize for conv_transpose layer
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
        
    def forward(self, h3):
        h2 = F.relu(self.batchNorm1(self.deconv1(h3)))
        h1 = F.relu(self.batchNorm2(self.deconv2(h2)))
        xHat = F.relu(self.deconv3(h1))
        xHat = xHat.squeeze()

        # x.size() is (L, emb_dim) if batch_size is 1.
        # So interpolate x's dimension if batch_size is 1.
        if len(xHat.size()) < 3:
            xHat = xHat.view(1, *xHat.size())
        # normalize
        normXHat = torch.norm(xHat, 2, dim=2, keepdim=True)
        recXHat = xHat / normXHat
        
        return recXHat
class ConvAutoencoder(nn.Module):
    def __init__(self, embedding, tau, embedDim, maxLength, filterSize, filterShape, latentSize):
        super(ConvAutoencoder, self).__init__()
        self.embed = embedding
        self.tau = tau
        self.encoder = ConvEncoder(embedDim, maxLength, filterSize, filterShape, latentSize)
        self.decoder = ConvDecoder(tau, embedDim, maxLength, filterSize, filterShape, latentSize)
    
    def forward(self, x):
        x = x.type(torch.long)
        x = self.embed(x)
        encodedInput = self.encoder(x)
        recXHat = self.decoder(encodedInput)
        # compute probability
        #oEmbedSize = originalEmbeddings.size()
        #normW = Variable(originalEmbeddings).permute(0, 2, 1)
        normW = Variable(self.embed.weight.data).t()
        
        #probLogits = torch.bmm(recXHat, normW) / self.tau
        probLogits = torch.bmm(recXHat, normW.expand(recXHat.size(0), *normW.size())) / self.tau
        
        logProb = F.log_softmax(probLogits, dim=2)
        return logProb
        #return decodedOutput
def computeCrossEntropy(log_prob, target):
    # compute reconstruction loss using cross entropy


    #Original loss calculation
    #[torch.sum(l) for l in loss] - takes > 60% of the time
    loss = [F.nll_loss(sentence_emb_matrix, word_ids, size_average=False) for sentence_emb_matrix, word_ids in zip(log_prob, target)]
    average_loss = sum([torch.sum(l) for l in loss]) / log_prob.size()[0]
    
    return average_loss

Update: apparently I have to set the learning rate to as low as 1e-8 to get a non-nan loss but the implementation I was referring to uses a decaying learning rate of 0.01. What’s wrong??? It becomes nan again when I use DataParallel instead