Location Model Memory Growth

Hi There!

Here is a simple code sample of location-based attention. It uses last_alignment to compute the next alignment.

import torch
from torch import nn

class LocationAttention(nn.Module):
    """ Simple mock of a location sensitive alignment similar to this paper:
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.project_alignment = nn.Linear(1, hidden_size)
        self.score = nn.Linear(hidden_size, 1)
        self.softmax = nn.Softmax(dim=1)

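    def forward(self, last_alignment):
        # [batch_size, num_tokens] -> [batch_size, num_tokens, 1]
        last_alignment = last_alignment.unsqueeze(2)
        # -> [batch_size, num_tokens, hidden_size]
        last_alignment = self.project_alignment(last_alignment)
        # -> [batch_size, num_tokens, 1] -> [batch_size, num_tokens]
        score = self.score(last_alignment)
        score = score.squeeze(2)
        last_alignment = self.softmax(score)
        # last_alignment = last_alignment.detach()
        return last_alignment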

# Parameters
batch_size = 64
num_tokens = 150
num_iter = 3000
hidden_size = 128

# Instantiate
attention = LocationAttention(hidden_size=hidden_size).cuda()

# Run
last_alignment = torch.zeros(batch_size, num_tokens).cuda()
for i in range(num_iter):  # WARNING: Uses 11 gig of memory in 1600 iterations
    last_alignment = attention(last_alignment)
    if i % 100 == 0:
        print('Frame: %d' % (i,))

However, despite its simplicity and small size, memory usage grows to 11 gigabytes in fewer than 1600 iterations.
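For a sense of scale, here is a rough back-of-the-envelope sketch of those numbers. It assumes float32 values (4 bytes each, as with torch.FloatTensor) and treats a gigabyte as 10**9 bytes; the variable names are only for illustration. It does not explain the growth; it only compares the observed per-iteration growth with the sizes of the tensors in the forward pass.

```python
# Parameters from the script above
batch_size = 64
num_tokens = 150
hidden_size = 128
bytes_per_float = 4  # float32 (assumption, matches torch.FloatTensor)

# Observed: roughly 11 gigabytes after roughly 1600 iterations
observed_bytes = 11e9  # assuming 1 gigabyte = 10**9 bytes
observed_iterations = 1600
growth_per_iteration = observed_bytes / observed_iterations

# Size of the alignment itself: [batch_size, num_tokens]
alignment_bytes = batch_size * num_tokens * bytes_per_float

# Size of the projected alignment: [batch_size, num_tokens, hidden_size]
projection_bytes = batch_size * num_tokens * hidden_size * bytes_per_float

print('Growth per iteration: %.2f MB' % (growth_per_iteration / 1e6))
print('Alignment size:       %.2f MB' % (alignment_bytes / 1e6))
print('Projection size:      %.2f MB' % (projection_bytes / 1e6))
```

Under these assumptions the growth is roughly 7 MB per iteration, while the alignment is only about 0.04 MB and the projected alignment is about 5 MB. So the per-iteration growth is on the order of the projected activation, not of the alignment itself.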

My investigation found that adding detach stops the growth, so the memory must be consumed by gradient information. It is not clear to me how an alignment of size [64, 150] grows to 11 gigabytes of memory. An LSTM, a much more complicated unit, does not use this much memory after so many iterations.
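One way to check whether history is accumulating is to count the autograd nodes reachable from last_alignment after a few iterations, with and without detach. The following is a minimal CPU sketch with tiny sizes, not the script above. count_graph_nodes and run are hypothetical helpers written for this illustration; they rely on the grad_fn and next_functions attributes that autograd exposes.

```python
import torch
from torch import nn


def count_graph_nodes(tensor):
    """Count the autograd nodes reachable from tensor.grad_fn."""
    seen = set()
    stack = [tensor.grad_fn] if tensor.grad_fn is not None else []
    while stack:
        fn = stack.pop()
        if fn in seen:
            continue
        seen.add(fn)
        # Inputs that do not require grad show up as None parents
        stack.extend(parent for parent, _ in fn.next_functions if parent is not None)
    return len(seen)


def run(num_iter, detach):
    """Feed the output back in num_iter times; return the final node count."""
    torch.manual_seed(0)
    project = nn.Linear(1, 4)
    score = nn.Linear(4, 1)
    alignment = torch.zeros(2, 3)
    for _ in range(num_iter):
        hidden = project(alignment.unsqueeze(2))
        alignment = torch.softmax(score(hidden).squeeze(2), dim=1)
        if detach:
            alignment = alignment.detach()
    return count_graph_nodes(alignment)


nodes_short = run(num_iter=2, detach=False)
nodes_long = run(num_iter=20, detach=False)
nodes_detached = run(num_iter=20, detach=True)
print(nodes_short, nodes_long, nodes_detached)
```

If the node count keeps growing with the number of iterations without detach and stays at zero with it, that would be consistent with each iteration's graph being kept alive through last_alignment.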

Any ideas?