OOM when computing batched distances for a language model

Hello :slight_smile:,

I am building a language model from an Embedding, an LSTM, and a Linear module.
I want to change the output computation slightly: the linear module projects from the “hidden” space into the embedding space, and I then want to derive the output probabilities from the (Euclidean) distance between each word embedding and the output of the model.

Let’s consider the following (simplified) module:

    import torch
    import torch.nn as nn
    from torch.autograd import Variable


    class RNNLanguageModel(nn.Module):
        def __init__(self, voc_size, embed_size, hidden_size, initial_embedding):
            super(RNNLanguageModel, self).__init__()
            self.hidden_size = hidden_size
            self.embed_size = embed_size
            self.voc_size = voc_size

            self.embedder = nn.Embedding(voc_size, embed_size)
            self.lstm = nn.LSTM(embed_size, hidden_size, 2, batch_first=True)
            # Projects the LSTM outputs from the hidden space into the embedding space
            self.decoder = nn.Linear(hidden_size, embed_size)
            self.pairdist = nn.PairwiseDistance(2)

        def forward(self, input_seq, hidden):
            embedded_seq = self.embedder(input_seq)
            encoded_seq, hidden = self.lstm(embedded_seq, hidden)
            # Flatten (batch, seq, hidden) to (batch * seq, hidden) before projecting
            decoded_seq = self.decoder(encoded_seq.contiguous().view(-1, self.hidden_size))

            # Problem here: decoded_seq is 1000 x 300 and self.embedder.weight is 10000 x 300
            # The three tries below are alternatives; I ran them one at a time.
            # First try
            probs = torch.stack([torch.norm(torch.add(self.embedder.weight, -dec), p=2, dim=1) for dec in decoded_seq])
            # Second try
            probs = torch.cat([self.pairdist(dec.expand_as(self.embedder.weight), self.embedder.weight) for dec in decoded_seq])
            # Third try
            probs = Variable(torch.cuda.FloatTensor(decoded_seq.size(0), self.voc_size))
            for i, dec in enumerate(decoded_seq):
                probs[i] = self.pairdist(dec.expand_as(self.embedder.weight), self.embedder.weight)

            # Reshape back to (batch, seq, voc_size)
            return probs.view(input_seq.size(0), input_seq.size(1), -1), hidden

I tried to compute the probabilities for every word this way, but I get an out-of-memory error. I also tried nn.PairwiseDistance, but it is not designed for this use (it does not broadcast), and I get an OOM with it too.

I used a batch size of 20, a sequence length of 50, a vocabulary size of 10000 and an embedding size of 300.
I think what I am doing is very memory-consuming, in particular the ‘add’, which creates a new 300 × 10000 tensor for each of the 20 × 50 words in the batch…
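To put rough numbers on this, here is a back-of-the-envelope calculation. It assumes 32-bit floats (4 bytes each) and counts only the tensors themselves, not allocator overhead; the variable names are just for illustration.

```python
# Sizes quoted in the post
batch_size, seq_len = 20, 50
voc_size, embed_size = 10000, 300
bytes_per_float = 4  # assumption: float32

# Number of rows in decoded_seq after the view(-1, hidden_size)
n_positions = batch_size * seq_len

# One difference tensor (embedding matrix minus one decoded vector)
diff_bytes = voc_size * embed_size * bytes_per_float

# Memory needed if every per-position difference tensor stays alive
all_diffs_bytes = n_positions * diff_bytes

# The final distance matrix of shape (n_positions, voc_size)
output_bytes = n_positions * voc_size * bytes_per_float

print(f"one difference tensor: {diff_bytes / 1e6:.1f} MB")
print(f"all difference tensors: {all_diffs_bytes / 1e9:.1f} GB")
print(f"final distance matrix: {output_bytes / 1e6:.1f} MB")
```

So a single difference tensor is small, and so is the final result. The cost only becomes a problem if all 1000 intermediate tensors are held in memory at once.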

Any idea how I could do this efficiently with PyTorch?
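One direction that might work, sketched here under assumptions rather than tested on the real model: expand the squared distance as ||a - b||² = ||a||² + ||b||² - 2 a·b, so that all pairs are handled by a single matrix multiplication and no 10000 × 300 difference tensor is ever built. The function name `pairwise_euclidean` and the `eps` clamp are my own choices for illustration, not part of the module above.

```python
import torch


def pairwise_euclidean(queries, table, eps=1e-12):
    """Euclidean distance between every row of `queries` (N, D)
    and every row of `table` (V, D). Returns a tensor of shape (N, V)."""
    q_sq = (queries ** 2).sum(dim=1, keepdim=True)   # (N, 1)
    t_sq = (table ** 2).sum(dim=1).unsqueeze(0)      # (1, V)
    cross = queries @ table.t()                      # (N, V)
    sq_dist = q_sq + t_sq - 2.0 * cross              # broadcasts to (N, V)
    # Rounding can make tiny distances slightly negative; clamp before sqrt
    return sq_dist.clamp(min=eps).sqrt()


# Toy sizes; in the post these would be decoded_seq and self.embedder.weight
decoded = torch.randn(4, 6, requires_grad=True)
embeddings = torch.randn(7, 6, requires_grad=True)
distances = pairwise_euclidean(decoded, embeddings)
print(distances.shape)
```

The largest intermediates here have shape (N, V), the same as the output. Recent PyTorch versions also provide `torch.cdist`, which computes the same batched distances and may be simpler if it is available in your version.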


Update:
It appears that this works if I detach the variables decoded_seq and self.embedder.weight (which I obviously don’t want to do). Why is this happening?

        probs = Variable(torch.cuda.FloatTensor(decoded_seq.size(0), self.voc_size))
        for i, dec in enumerate(decoded_seq):
            probs[i] = self.pairdist(dec.expand_as(self.embedder.weight), self.embedder.weight)
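
A possible explanation, offered as a guess rather than a confirmed diagnosis: when the inputs require gradients, autograd records every operation and keeps the tensors it needs for the backward pass, so the per-word difference tensors cannot be freed between loop iterations. With detached inputs nothing is recorded and each intermediate can be released right away. The toy tensors below are stand-ins for `self.embedder.weight` and one row of `decoded_seq`; they only show whether a graph is recorded, not the memory use itself.

```python
import torch

weight = torch.randn(5, 3, requires_grad=True)  # stand-in for self.embedder.weight
dec = torch.randn(3, requires_grad=True)        # stand-in for one row of decoded_seq

# With gradients: the result carries a grad_fn that links back to its inputs
tracked = torch.norm(weight - dec, p=2, dim=1)

# Detached: same values, but no graph is recorded
untracked = torch.norm(weight.detach() - dec.detach(), p=2, dim=1)

print("tracked has grad_fn:", tracked.grad_fn is not None)
print("untracked has grad_fn:", untracked.grad_fn is not None)
```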