Why this model and training routine would give CUDA memory error?

Why do you think the following code would give CUDA memory error? Is there any accumulation happening that will hinder the variables to go out of scope?

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.data.utils import get_tokenizer
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.nn import TransformerDecoder, TransformerDecoderLayer

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):

    def __init__(self, ntoken_source, ntoken_target, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.trg_mask = None
        self.src_key_padding_mask = None
        self.trg_key_padding_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken_source, ninp)
        self.decoder = nn.Embedding(ntoken_target, ninp)
        self.ninp = ninp
        decoder_layers = TransformerDecoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, nlayers)
        self.pred_decoder = nn.Linear(ninp, ntoken_target)


    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _generate_padding_mask(self, tsr):
        # padding index is 1
        msk = torch.tensor((tsr.data.cpu().numpy() == 1).astype(int), dtype = torch.bool)
        msk = msk.permute(1,0)
        return msk

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.pred_decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, trg=None):
        trg_device = trg.device
        self.trg_mask = self._generate_square_subsequent_mask(len(trg)).to(trg_device)

        src_device = src.device
        self.src_key_padding_mask = self._generate_padding_mask(src).to(src_device)
        self.trg_key_padding_mask = self._generate_padding_mask(trg).to(trg_device)

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        enc_output = self.transformer_encoder(src, src_key_padding_mask=self.src_key_padding_mask)#, self.src_mask)
        trg = self.decoder(trg) * math.sqrt(self.ninp)
        trg = self.pos_encoder(trg)
        dec_output = self.transformer_decoder(trg, enc_output, tgt_mask=self.trg_mask,\
        preds = self.pred_decoder(dec_output)
        return preds

Training Procedure:

model = TransformerModel(ntokens_source, ntokens_target, emsize, nhead, nhid, nlayers, dropout).to(device)
	total_params = sum(p.numel() for p in model.parameters())
	total_params_t = sum(p.numel() for p in model.parameters() if p.requires_grad)
	print("Total Parameters: {}, Trainable Parameters: {}".format(total_params, total_params_t))

	iterators = BucketIterator.splits(
					d, batch_size=bsz, shuffle=True)
	print("Len Iterators: {}, {}".format(len(iterators[0]), len(iterators[1])))

	criterion = nn.CrossEntropyLoss()
	optimizer = torch.optim.Adam(model.parameters())
	epochs = 50
	check_after = 20
	best_val_loss = float('inf')
	for epoch in range(epochs):
		avg_loss = 0.0
		for i, batch in enumerate(iterators[0]):
			src = batch.src.to(device)[1:,:]
			trg = batch.trg.to(device)
			print(src.shape, trg.shape)
			output = model(src,trg[:-1,:])
			loss = criterion(output.view(-1, ntokens_target), trg[1:,:].reshape(-1))
			del output
			del src		
			del trg

		avg_eval_loss = 0.0
		avg_bleu_score = 0.0
		avg_gleu_score = 0.0
		for i, batch_ in enumerate(iterators[1]):
			src_eval = batch_.src.to(device)[1:,:]
			trg_eval = batch_.trg.to(device)
			output_eval = model(src_eval, trg_eval[:-1,:])#.view(-1, ntokens_target)
			pairs, avg_bleu, avg_gleu = decode(output_eval.detach().cpu(), trg_eval[1:,:].cpu(), TRG.vocab.itos)
			if i%check_after == 0:
				log(i, pairs)
			loss_eval = criterion(output_eval.view(-1, ntokens_target), trg_eval[1:,:].view(-1))
			del src_eval
			del trg_eval
			del output_eval

It’s happening with as low as 8 batch size. The train iterator performs approx. 2500 iterations. And the CUDA memory error occurs during these iterations but if it manges to pass through one epoch, it certainly occurs at around the first or second iteration in the 2nd epoch. Please help with understanding what could be the problem here (if there is). Also, how can I use nn.DataParallel here?

If the out of memory error is raised in the second epoch, you might be storing unnecessary tensors (or the whole computation graph) in some variables.


might attach the whole computation graph, if avg_* is still attached to it (you can check if by printing the .grad_fn or to be safe just detach() these tensors).

Also, to save memory, you could wrap the validation loop into a with torch.no_grad() block.

To use nn.DataParallel you should wrap the model in nn.DataParallel and make sure the input tensors can be split in dim0, so that each chunk is passed to the corresponding device.
This tutorial will give you more information.

No, I am not accumulating the computation graph. But with torch.no_grad() helped. Thanks!