Hello,
I am building a language model with an Embedding, an LSTM and a Linear module.
I want to change the output computation slightly: the linear module will project from the “hidden” space into the embedding space. Then I want to compute the output probabilities from the (Euclidean) distances between the word embeddings and the output of the model.
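In symbols, using notation that is not in the original post ($h_t$ for the LSTM output at position $t$, $W$ and $b$ for the decoder's weight and bias, $E$ for the embedding matrix whose row $E_v$ is the embedding of vocabulary word $v$), the quantity the model should produce for each position and word is:

$$
\hat{e}_t = W h_t + b, \qquad d_{t,v} = \lVert \hat{e}_t - E_v \rVert_2, \qquad v = 1, \dots, V
$$

How the distances $d_{t,v}$ are turned into probabilities is not specified here; presumably a smaller distance means a more likely word.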
Let's consider the following (simplified) module:
import torch
import torch.nn as nn
from torch.autograd import Variable

class RNNLanguageModel(nn.Module):
    def __init__(self, voc_size, embed_size, hidden_size, initial_embedding):
        super(RNNLanguageModel, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.voc_size = voc_size
        self.embedder = nn.Embedding(voc_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, 2, batch_first=True)
        # Projects from the hidden space back into the embedding space
        self.decoder = nn.Linear(hidden_size, embed_size)
        self.pairdist = nn.PairwiseDistance(2)

    def forward(self, input_seq, hidden):
        embedded_seq = self.embedder(input_seq)
        encoded_seq, hidden = self.lstm(embedded_seq, hidden)
        decoded_seq = self.decoder(encoded_seq.contiguous().view(-1, self.hidden_size))
        # Problem here: decoded_seq is 1000*300 (batch*seq rows) and self.embedder.weight is 10000*300
        # First try
        probs = torch.stack([torch.norm(torch.add(self.embedder.weight, -dec), p=2, dim=1) for dec in decoded_seq])
        # Second try
        probs = torch.cat([self.pairdist(dec.expand_as(self.embedder.weight), self.embedder.weight) for dec in decoded_seq])
        # Third try
        probs = Variable(torch.cuda.FloatTensor(decoded_seq.size(0), self.voc_size))
        for i, dec in enumerate(decoded_seq):
            probs[i] = self.pairdist(dec.expand_as(self.embedder.weight), self.embedder.weight)
        return probs.view(input_seq.size(0), input_seq.size(1), -1), hidden
I tried to compute the distance to every word this way, but I get an out-of-memory (OOM) error. I also tried the PairwiseDistance module, but it is not designed for this use (it does not broadcast), and I get an OOM error with it too.
I used a batch size of 20, a sequence length of 50, a vocabulary size of 10000 and an embedding size of 300.
I think what I am doing is very memory-consuming, particularly the add, which creates a new 300*10000 tensor for each of the 20*50 words in the batch.
Any idea how I could do this efficiently with PyTorch?
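A rough count supports this, assuming float32 values (the FloatTensor used in the third try) and that every per-word difference tensor is kept alive for the backward pass:

$$
\begin{aligned}
\text{all per-word differences: } & (20 \cdot 50) \times (10000 \cdot 300) \times 4\ \text{B} = 1.2 \times 10^{10}\ \text{B} \approx 12\ \text{GB} \\
\text{final distance matrix only: } & (20 \cdot 50) \times 10000 \times 4\ \text{B} = 4 \times 10^{7}\ \text{B} \approx 40\ \text{MB}
\end{aligned}
$$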
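One way to avoid the large intermediate tensor is the identity $\lVert a - b \rVert^2 = \lVert a \rVert^2 + \lVert b \rVert^2 - 2\,a^\top b$, which turns the whole computation into one matrix multiplication plus broadcasting. This is a minimal sketch, not a tested drop-in: `euclidean_distances` is a hypothetical helper name, and the `eps` clamp is an assumption added to guard against small negative values caused by floating-point rounding. Recent PyTorch versions also provide `torch.cdist`, which computes the same distances.

```python
import torch


def euclidean_distances(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Pairwise Euclidean distances between rows of x (N x D) and rows of y (V x D).

    Uses ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y, so only an N x V matrix is
    materialized, never an N x V x D difference tensor.
    """
    x_sq = (x * x).sum(dim=1, keepdim=True)   # N x 1
    y_sq = (y * y).sum(dim=1).unsqueeze(0)    # 1 x V
    sq = x_sq + y_sq - 2.0 * (x @ y.t())      # N x V via broadcasting
    # Rounding can push tiny values below zero; clamping also keeps sqrt's gradient finite
    return sq.clamp(min=eps).sqrt()


# Shapes from the question: 20*50 decoded vectors against a 10000-word vocabulary
decoded = torch.randn(20 * 50, 300)
embedding_weight = torch.randn(10000, 300)
dist = euclidean_distances(decoded, embedding_weight)  # 1000 x 10000
```

Only the N × V result (1000 × 10000 here) and an intermediate of the same size are allocated, so memory grows with the vocabulary size rather than with vocabulary size times embedding size. The trade-off is some precision loss for nearly identical vectors, which is why the squared distances are clamped before the square root.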
Update:
It appears that the following works if I detach the variables decoded_seq
and self.embedder.weight
(which I obviously don't want). Why is this happening?
probs = Variable(torch.cuda.FloatTensor(decoded_seq.size(0), self.voc_size))
for i, dec in enumerate(decoded_seq):
    probs[i] = self.pairdist(dec.expand_as(self.embedder.weight), self.embedder.weight)
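A likely explanation: inside the loop, each `self.pairdist(...)` call builds a 10000 × 300 difference tensor, and autograd keeps it (along with the graph nodes for `expand_as` and the indexed assignment into `probs`) until backward runs. Over 1000 iterations that is roughly 1000 × 10000 × 300 floats, about 12 GB in float32. Detaching means no graph is recorded, so each temporary is freed immediately. The sketch below keeps the gradients and removes the loop instead, using `torch.cdist`; this assumes PyTorch 1.1 or later, whereas the post uses the older `Variable` API. It drops the unused `initial_embedding` argument and lets `hidden` default to `None`; both are simplifications made for this sketch only.

```python
import torch
import torch.nn as nn


class RNNLanguageModel(nn.Module):
    """Simplified model from the question, with distances computed without a Python loop."""

    def __init__(self, voc_size, embed_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedder = nn.Embedding(voc_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, 2, batch_first=True)
        self.decoder = nn.Linear(hidden_size, embed_size)

    def forward(self, input_seq, hidden=None):
        embedded_seq = self.embedder(input_seq)                   # B x T x E
        encoded_seq, hidden = self.lstm(embedded_seq, hidden)     # B x T x H
        decoded_seq = self.decoder(encoded_seq.reshape(-1, self.hidden_size))  # (B*T) x E
        # One (B*T) x V distance matrix; gradients reach decoded_seq and embedder.weight
        dists = torch.cdist(decoded_seq, self.embedder.weight, p=2)
        return dists.view(input_seq.size(0), input_seq.size(1), -1), hidden


torch.manual_seed(0)
model = RNNLanguageModel(voc_size=100, embed_size=16, hidden_size=32)
tokens = torch.randint(0, 100, (4, 7))   # batch of 4 sequences of length 7
dists, _ = model(tokens)                 # 4 x 7 x 100
dists.mean().backward()                  # placeholder loss, just to exercise backward
```

No per-word 10000 × 300 tensors are stored for backward here, so the embedding weights can stay attached to the graph.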