Normalized embeddings from LSTM

Hi all,

I am training a simple LSTM for language modeling on the Penn Treebank. I have tied the weights of the encoder and the decoder.

However, during training I would like the embeddings to have unit L2 norm. The following is my code, but the losses do not decrease during training:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.2):
        super(RNNModel, self).__init__()
        # Tying reuses the (ntoken, ninp) embedding matrix as the decoder weight,
        # so the LSTM hidden size must equal the embedding size.
        if nhid != ninp:
            raise ValueError("tied weights require nhid == ninp")
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)  # token ids -> embeddings
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken, bias=False)

        # Tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        self.decoder.weight = self.encoder.weight

        self.nhid = nhid
        self.nlayers = nlayers
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        # Encoder and decoder share one tensor, so a single init covers both.
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)

    def forward(self, input, hidden):
        # input: (seq_len, batch) token ids; nn.LSTM defaults to batch_first=False
        emb = self.drop(self.encoder(input))         # (seq_len, batch, ninp)
        emb = F.normalize(emb, p=2, dim=-1)          # unit L2 norm per embedding vector
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        output = output.reshape(-1, output.size(2))  # (seq_len * batch, nhid)
        output = F.normalize(output, p=2, dim=1)     # unit L2 norm per output vector
        # Only the outputs are normalized here; the decoder weights are not.
        decoded = self.decoder(output)  # want here the dot product of unit length vectors
        return decoded, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        # LSTM state is (h_0, c_0), each of shape (nlayers, batch, nhid)
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))
```
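With `nn.LSTM`'s default `batch_first=False`, the embedding output has shape `(seq_len, batch, ninp)`, so the per-vector L2 norm has to be taken over the last axis; `dim=1` would instead normalize each feature across the batch. A minimal sketch checking this, with arbitrary sizes chosen only for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes in nn.LSTM's default (seq_len, batch, features) layout.
seq_len, batch, ninp = 5, 3, 8
emb = torch.randn(seq_len, batch, ninp)

per_vector = F.normalize(emb, p=2, dim=-1)   # each embedding vector gets unit L2 norm
across_batch = F.normalize(emb, p=2, dim=1)  # each feature is normalized over the batch instead

print(per_vector.norm(dim=-1))    # all ones
print(across_batch.norm(dim=-1))  # generally not ones
```

Only the `dim=-1` version gives every embedding vector unit length, which is what the decoder's dot products assume.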

To explain my point further: when the decoder multiplies the LSTM output by its weight matrix, I want the weights (as well as the output) to have unit norm, so that each logit is the dot product of two unit-length vectors. With my implementation above, the losses do not go down. Any help will be highly appreciated.
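If both the output and the weight rows have unit norm, every logit is a cosine similarity in [-1, 1], which caps how peaked the softmax can get. Because cross-entropy decreases as the target logit rises and increases as the others rise, the best case is a target logit of +1 with all others at -1, giving a loss of log(1 + (V - 1)e^(-2)). Assuming the usual 10,000-word PTB vocabulary, that is roughly 7.2 nats, not far below the log(10000) ≈ 9.2 of a uniform prediction. A common remedy in normalized-softmax setups is to multiply the cosines by a scale (temperature) factor. The sketch below normalizes both the hidden states and the tied weight matrix and applies such a scale; `cosine_logits` and `best_case_loss` are hypothetical helpers, not PyTorch APIs, and any particular scale value is an assumption to tune.

```python
import math

import torch
import torch.nn.functional as F


def cosine_logits(hidden: torch.Tensor, weight: torch.Tensor, scale) -> torch.Tensor:
    """Hypothetical helper: scale * cosine similarity of hidden rows vs. weight rows."""
    h = F.normalize(hidden, p=2, dim=-1)  # (N, nhid), each row has unit L2 norm
    w = F.normalize(weight, p=2, dim=-1)  # (ntoken, nhid), each row has unit L2 norm
    return scale * F.linear(h, w)         # (N, ntoken); unscaled values lie in [-1, 1]


def best_case_loss(vocab_size: int, scale: float = 1.0) -> float:
    """Hypothetical helper: cross-entropy when the target logit is +scale and all
    others are -scale, i.e. the lowest loss reachable with logits in [-scale, scale]."""
    return math.log1p((vocab_size - 1) * math.exp(-2.0 * scale))


torch.manual_seed(0)
hidden = torch.randn(4, 8)   # stand-in for flattened LSTM outputs
weight = torch.randn(20, 8)  # stand-in for the tied embedding matrix
logits = cosine_logits(hidden, weight, 1.0)

print(logits.abs().max().item())    # never exceeds 1 without scaling
print(best_case_loss(10000, 1.0))   # lower bound on the loss with no scaling
print(best_case_loss(10000, 10.0))  # lower bound with a scale of 10
print(math.log(10000))              # loss of a uniform prediction
```

To wire this into `RNNModel`, one option (an assumption, not something from the original post) is to register a learnable `self.scale = nn.Parameter(torch.tensor(10.0))` in `__init__` and return `cosine_logits(output, self.encoder.weight, self.scale)` from `forward` in place of `self.decoder(output)`. The weights stay tied because the same embedding matrix is used on both sides, and gradients still reach the raw weights through `F.normalize`.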