Hi all,
I am training a simple LSTM language model on the Penn Treebank, with the weights of the encoder (embedding) and the decoder tied.
However, the catch is that during training I would like the embeddings to have unit L2 norm. Below is my code, but the loss does not decrease during training:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.2):
        super(RNNModel, self).__init__()
        # Weight tying requires the LSTM output size to equal the embedding size.
        assert nhid == ninp, "tied weights need nhid == ninp"
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)  # Token2Embeddings
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken, bias=False)
        # Tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        self.decoder.weight = self.encoder.weight
        self.init_weights()
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        # The decoder shares this tensor, so initializing the encoder is enough;
        # a second init of self.decoder.weight would simply overwrite it.
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)

    def forward(self, input, hidden):
        # Project every embedding row back onto the unit sphere (outside autograd).
        with torch.no_grad():
            self.encoder.weight.copy_(F.normalize(self.encoder.weight, p=2, dim=1))
        emb = self.drop(self.encoder(input))  # (seq_len, batch, ninp)
        emb = F.normalize(emb, p=2, dim=-1)  # per embedding vector, not across the batch
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        output = F.normalize(output.view(output.size(0) * output.size(1), output.size(2)), p=2, dim=1)
        decoded = self.decoder(output)  # want here the dot product of unit-length vectors
        return decoded, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))
```
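The `dim` argument is easy to get wrong here. `nn.LSTM` (with the default `batch_first=False`) expects input shaped `(seq_len, batch, ninp)`, so normalizing the embedding output with `dim=1` rescales across the batch rather than per vector. A minimal check, with illustrative sizes that are assumptions rather than the actual training configuration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, batch, ninp = 35, 20, 200  # illustrative sizes (assumed)
emb = torch.randn(seq_len, batch, ninp)

wrong = F.normalize(emb, p=2, dim=1)   # normalizes across the batch axis
right = F.normalize(emb, p=2, dim=-1)  # normalizes each embedding vector

# Norm of each individual embedding vector (over the last axis).
print(wrong.norm(dim=-1)[0, :3])  # generally not 1
print(right.norm(dim=-1)[0, :3])  # 1 up to float rounding
```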
To explain further: when the decoder multiplies the LSTM output by its weight matrix (the tied embeddings), I want both the weight rows and the output vectors to have unit norm, so each logit is the dot product of two unit-length vectors. With the implementation above, the loss does not go down. Any help would be highly appreciated.
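One property of this setup worth spelling out: if both the LSTM output and every decoder row are unit-length, each logit is a cosine similarity in [-1, 1]. Even in the best case (target logit +1, all others -1), a softmax over V tokens gives a cross-entropy of log(1 + (V - 1)e^(-2)). The sketch below evaluates that bound, assuming the usual 10,000-word PTB vocabulary, and cross-checks it against `F.cross_entropy`. The `scale` argument is a hypothetical multiplier on the logits (the kind of temperature often added to cosine-softmax classifiers), not something in the code above.

```python
import math

import torch
import torch.nn.functional as F


def best_case_cosine_loss(vocab_size: int, scale: float = 1.0) -> float:
    """Lowest cross-entropy reachable when every logit is scale * cos(theta).

    Best case: the target logit is +scale and all other logits are -scale.
    `scale` is a hypothetical knob; scale=1 matches plain unit-vector dot products.
    """
    return math.log1p((vocab_size - 1) * math.exp(-2.0 * scale))


vocab = 10000  # word-level PTB vocabulary size (assumed)
print(f"uniform prediction loss : {math.log(vocab):.2f}")
print(f"best loss, scale=1      : {best_case_cosine_loss(vocab):.2f}")
print(f"best loss, scale=16     : {best_case_cosine_loss(vocab, 16.0):.6f}")

# Cross-check the closed form against PyTorch on the extreme logits.
logits = -torch.ones(1, vocab, dtype=torch.float64)
logits[0, 0] = 1.0
target = torch.tensor([0])
ce = F.cross_entropy(logits, target).item()
print(f"F.cross_entropy         : {ce:.2f}")
```

For V = 10,000 the bound is about 7.21 nats (perplexity around 1354), against log(10000) ≈ 9.21 for a uniform prediction, leaving only about 2 nats between the two at scale 1.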