Why is this seq2seq with attention model taking up so much memory?

I’m implementing a standard seq2seq model with attention, but I’m seeing very high memory usage during training. Anecdotally (I have run this model in TensorFlow), I know it should be able to run in under 8 GB, if not less, with the following hyperparameters: word dim = 100, hidden size = 256, attention size = 256, vocab size = 30000, batch size = 32. My network summary looks like this:
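
For concreteness, the settings above correspond to something like this in my script (the variable names here are just illustrative; `embd_dim` and `hidden_size` are the ones used in the training code further down):

# hyperparameter values from the description above (names are illustrative)
embd_dim = 100       # word embedding dimension
hidden_size = 256    # LSTM hidden size; the attention size is also 256
vocab_size = 30000
batch_size = 32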

Seq2Seq_Attn (
  (embeddings): Embedding(30000, 100)
  (encoder_lstm): EncoderLSTM (
    (lstm): LSTM(100, 256, batch_first=True, dropout=0.2)
  )
  (decoder_lstm): DecoderLSTM_Attn (
    (lstm): LSTM(100, 256, batch_first=True, dropout=0.2)
    (attn): Attn (
      (W_attn): Linear (512 -> 256)
      (v_attn): Linear (256 -> 1)
    )
    (output_attn): Linear (256 -> 256)
    (output): Linear (256 -> 30000)
  )
)

The code for my network modules:

class EncoderLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=1, dropout=0.0):
        super(EncoderLSTM, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, dropout=dropout)

    def forward(self, inpt, hidden):
        output = inpt
        for i in range(self.n_layers):
            output, hidden = self.lstm(output, hidden)
        return output, hidden

    def initHidden(self, batch_size):
        result = Variable(torch.zeros(1, batch_size, self.hidden_size))
        if USE_CUDA:
            return result.cuda(GPU)
        else:
            return result

class DecoderLSTM_Attn(nn.Module):
    def __init__(self, embedding_dim, hidden_size, output_size, n_layers=1, dropout=0.0):
        super(DecoderLSTM_Attn, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True, dropout=dropout)
        self.attn = Attn(hidden_size)
        self.output_attn = nn.Linear(hidden_size, hidden_size)
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, inpt, hidden, enc_outputs):
        enc_outputs = Variable(torch.stack(enc_outputs)).squeeze(2)
        if USE_CUDA:
            enc_outputs = enc_outputs.cuda(GPU)

        output = inpt
        for i in range(self.n_layers):
            output, hidden = self.lstm(output, hidden)
            output = output.squeeze(1)

            # attention
            attn_weights = self.attn(output, enc_outputs) # B x L
            attn_weights = attn_weights.view(attn_weights.size(1), attn_weights.size(0)) # L x B
            context_h = torch.sum(torch.mul(attn_weights.unsqueeze(2).expand(enc_outputs.size()),
                                            enc_outputs), 0) # 1 x B x H
            context_h = context_h.squeeze(0) # B x H
            output = self.output_attn(output + context_h) # B x H

        output = self.output(output)
        return output, hidden

    def initHidden(self, batch_size):
        result = Variable(torch.zeros(1, batch_size, self.hidden_size))
        if USE_CUDA:
            return result.cuda(GPU)
        else:
            return result

class Attn(nn.Module):
    def __init__(self, hidden_size):
        super(Attn, self).__init__()
        self.hidden_size = hidden_size

        self.W_attn = nn.Linear(2*hidden_size, hidden_size)
        self.v_attn = nn.Linear(hidden_size, 1)

    def forward(self, hidden, encoder_outputs):
        max_len = encoder_outputs.size(0)
        this_batch_size = encoder_outputs.size(1)

        # concat hidden vector to each encoder vector
        concat_states = torch.cat([hidden.expand(encoder_outputs.size()), encoder_outputs], 2) # L x B x 2*H
        concat_states = concat_states.view(-1, concat_states.size(-1)) # L*B x 2*H
        W_attn = self.W_attn(concat_states) # L*B x H
        attn_energies = self.v_attn(W_attn) # L*B x 1
        attn_energies = attn_energies.view(this_batch_size, max_len) # B x L
        return F.softmax(attn_energies) # B x L

class Seq2Seq_Attn(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, dropout=0.0, n_layers=1, pretrained_embeddings=None,
                 trainable_embeddings=True):
        super(Seq2Seq_Attn, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size

        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        if pretrained_embeddings is not None:
            self.embeddings.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
            self.embeddings.weight.requires_grad = trainable_embeddings

        self.encoder_lstm = EncoderLSTM(embedding_dim, hidden_size, n_layers=n_layers, dropout=dropout)
        self.decoder_lstm = DecoderLSTM_Attn(embedding_dim, hidden_size, vocab_size, n_layers=n_layers, dropout=dropout)

    def forward(self, enc_inpt, dec_inpt, teacher_forcing_ratio=1.0):
        enc_inpt_embd = self.embeddings(enc_inpt)
        dec_inpt_embd = self.embeddings(dec_inpt)

        enc_inpt_len = enc_inpt_embd.size()[1]
        dec_inpt_len = dec_inpt_embd.size()[1]
        batch_size = enc_inpt_embd.size()[0]

        # encoder LSTM
        enc_outputs = []
        enc_outputs_weighted = []
        hidden = (self.encoder_lstm.initHidden(batch_size), self.encoder_lstm.initHidden(batch_size))
        for i in range(enc_inpt_len):
            enc_inpt_t = enc_inpt_embd.narrow(1, i, 1)
            enc_output, hidden = self.encoder_lstm(enc_inpt_t, hidden)
            enc_outputs.append(enc_output.data)

        # decoder LSTM Attn
        dec_outputs = []
        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
        for i in range(dec_inpt_len):
            if i == 0:
                dec_inpt_t = dec_inpt_embd.narrow(1, i, 1)
            else:
                if use_teacher_forcing:
                    dec_inpt_t = dec_inpt_embd.narrow(1, i, 1)
                else:
                    softmax = F.softmax(dec_output)
                    _, prediction = torch.max(softmax, 1)
                    dec_inpt_t = self.embeddings(prediction)

            dec_output, hidden = self.decoder_lstm(dec_inpt_t, hidden, enc_outputs)
            dec_outputs.append(dec_output)
        return dec_outputs

    def deploy(self, enc_inpt, dec_len, go_token):
        enc_inpt_embd = self.embeddings(enc_inpt)
        enc_inpt_len = enc_inpt_embd.size()[1]
        batch_size = enc_inpt_embd.size()[0]
        go_var = Variable(go_token*torch.ones(batch_size, 1).long())
        if USE_CUDA:
            go_var = go_var.cuda(GPU)
        go_embd = self.embeddings(go_var)

        # encoder LSTM
        enc_outputs = []
        enc_outputs_weighted = []
        hidden = (self.encoder_lstm.initHidden(batch_size), self.encoder_lstm.initHidden(batch_size))
        for i in range(enc_inpt_len):
            enc_inpt_t = enc_inpt_embd.narrow(1, i, 1)
            enc_output, hidden = self.encoder_lstm(enc_inpt_t, hidden)
            enc_outputs.append(enc_output.data)

        # decoder LSTM Attn
        dec_outputs = []
        for i in range(dec_len):
            if i == 0:
                dec_inpt_t = go_embd
            else:
                softmax = F.softmax(dec_output)
                _, prediction = torch.max(softmax, 1)
                dec_inpt_t = self.embeddings(prediction)

            dec_output, hidden = self.decoder_lstm(dec_inpt_t, hidden, enc_outputs)
            dec_outputs.append(dec_output)
        return dec_outputs

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

The problem is that when I start training, my network is using 13 GB and my GPU quickly runs out of memory. I initialize my network with:

net = Seq2Seq_Attn(vocab_lookup.num_words, embd_dim, hidden_size, dropout=dropout,
                   pretrained_embeddings=pretrained_embeddings, trainable_embeddings=True)
criterion = nn.CrossEntropyLoss()
if USE_CUDA:
    net = net.cuda(GPU)
    criterion = criterion.cuda(GPU)

optimizer = optim.Adam(net.parameters(), lr=lr)

And all I am running is:

for itr in range(1, n_iters+1):
    if itr % epoch_len == 0:
        epoch += 1
    if itr in teacher_forcing_steps:
        teacher_forcing_ratio = teacher_forcing_ratios[teacher_forcing_steps.index(itr)]

    filenames, _ = train_batcher.next_batch()
    batch = Batch(filenames, d_pad_len, s_pad_len, vocab_lookup)

    enc_inpts = Variable(torch.LongTensor(batch.enc_inpts))
    dec_inpts = Variable(torch.LongTensor(batch.dec_inpts))
    targets = Variable(torch.LongTensor(batch.target_ids))

    if USE_CUDA:
        enc_inpts = enc_inpts.cuda(GPU)
        dec_inpts = dec_inpts.cuda(GPU)
        targets = targets.cuda(GPU)

    optimizer.zero_grad()

    net.train()
    outputs = net(enc_inpts, dec_inpts, teacher_forcing_ratio=teacher_forcing_ratio)

Is there anything that stands out that might be causing redundant memory allocation and resulting in such high memory usage?
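
For what it's worth, here is a rough sketch of how one could log the allocation between steps to see where it jumps (assuming a PyTorch build that exposes torch.cuda.memory_allocated() and torch.cuda.max_memory_allocated(); older builds may not have these):

import torch

def log_gpu_memory(tag, device=0):
    # print current and peak GPU allocation for this process, in GB
    alloc = torch.cuda.memory_allocated(device) / 1024**3
    peak = torch.cuda.max_memory_allocated(device) / 1024**3
    print('[{}] allocated: {:.2f} GB, peak: {:.2f} GB'.format(tag, alloc, peak))

# e.g. call log_gpu_memory('after forward') right after outputs = net(...)
# and log_gpu_memory('end of iteration') at the bottom of the training loop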


There are some patches that went into the master branch that reduce the memory usage of nn.LSTM, etc.

Can you try that?
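
If it helps, one quick way to confirm which build is actually being picked up after reinstalling (just a sanity check, not specific to this issue):

import torch
# a build from source/master typically reports a version string with a commit suffix
print(torch.__version__)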

@pgigioli Did you solve the memory problem? I encountered the same out-of-memory issue with seq2seq.
