Running out of memory training two-layer biLSTM with batch size 32

Hi there!
I am attempting to train a biLSTM model, following a tutorial, to use as a binary classifier for textual data. Currently I am using GloVe embeddings of dim 300 with a custom embedding layer constructed from the vocab used in my corpus and the pretrained GloVe embeddings. The embedding layer seems to fit into memory without problems, and its weights are frozen so the embeddings are not trained further. I’ve set the max sequence length of each sequence to 256.

I’ll post my code below for those interested

My dataset class and dataloader

class TextDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df.index)


    def __getitem__(self, idx):
        row = self.df.iloc[idx].values
        tokens, length = row[1]
        label = row[3]

        return tokens, label, length
    
    def get_class_weights(self):
        total_texts = self.__len__()
        num_non_shooter_texts, num_shooter_texts = self.df["label"].value_counts()
        print(f"Value counts:\n{self.df['label'].value_counts()}")

        non_shooter_wt = total_texts / num_non_shooter_texts
        shooter_wt = total_texts / num_shooter_texts
        print(f"non_shooter: {non_shooter_wt}\nshooter: {shooter_wt}")

        return [non_shooter_wt, shooter_wt]

train_loader = DataLoader(train_set, batch_size=32, pin_memory=True)

The network is a two layer biLSTM

class LSTMTextClassifier(nn.Module):
    def __init__(self, embs, emb_dim: int = 300, dropout: float = 0.5, hidden_size: int = 128, num_layers: int = 2):
        super(LSTMTextClassifier, self).__init__()

        self.embedding = nn.Embedding.from_pretrained(torch.from_numpy(embs).float())
        self.embedding.weight.requires_grad = False
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(emb_dim, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2*hidden_size, 1)
        self.dropout = nn.Dropout(p=dropout)
        self.sig = nn.Sigmoid()

    def forward(self, x, length):
        """Perform a forward pass through the network.

        Args:
            x (torch.Tensor): A tensor of token ids with shape (batch_size, max_sent_length)
            length: the length of the sequence before padding
        """

        embs = self.embedding(x)

        packed_input = pack_padded_sequence(embs, length, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed_input)
        out, _ = pad_packed_sequence(packed_out, batch_first=True)

        out_forward = out[range(len(out)), length - 1, :self.hidden_size]
        out_backwards = out[:, 0, self.hidden_size:]

        out_reduced = torch.cat((out_forward, out_backwards), 1) # Concat for fc layer and final pred
        out_dropped = self.dropout(out_reduced) # Dropout layer
        out = self.fc(out_dropped)
        out = self.sig(out)

        return out

What confuses me is that my dataset is not particularly large. It consists of around 10000 posts averaging around 100 words in length, some a bit more. Seeing as the biLSTM model is not very large either, I struggle to see how I would run out of memory on a 32GB GPU after just a few batches or epochs, depending on the batch size. I am only running 10 epochs. An epoch looks like so:

def run_epoch():
        running_loss = 0.

        for _, data in enumerate(train_loader):
            inputs, labels, lengths = data
            labels = torch.tensor(np.array(labels), dtype=torch.float32).to(device)
            inputs = torch.from_numpy(np.array(inputs)).to(device)
            optimizer.zero_grad()
 
            # Weighting scheme to accommodate the small minority class
            # Weights are about 1.2 for the majority class and 6 for the minority
            weighting = []
            for l in labels:
                if l == 0:
                    weighting.append(class_wts[0])
                else:
                    weighting.append(class_wts[1])

            outputs = model(inputs, lengths)

            loss_fn = nn.BCELoss(weight=torch.tensor(weighting))
            loss = loss_fn(outputs.squeeze(dim=1), labels) # Squeeze the output to shape (batch_size,) so it matches the target
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
                        
        return running_loss/len(train_loader)
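
As a side note, the per-sample weighting loop above can probably be written without a Python loop. This is only a minimal sketch, assuming the labels are 0/1 floats and using the approximate weights from the comment (1.2 and 6) as illustrative values. Building the weights from `labels.device` keeps them on the same device as the outputs:

```python
import torch
import torch.nn.functional as F

class_wts = [1.2, 6.0]  # illustrative values, taken from the comment in run_epoch
labels = torch.tensor([0.0, 1.0, 0.0, 0.0])   # stand-in batch of targets
outputs = torch.tensor([0.2, 0.7, 0.1, 0.4])  # stand-in sigmoid outputs

# Look up one weight per sample by indexing with the integer class id.
wts = torch.tensor(class_wts, device=labels.device)
weighting = wts[labels.long()]

# Functional form of BCELoss with per-sample weights.
loss = F.binary_cross_entropy(outputs, labels, weight=weighting)
print(weighting, loss.item())
```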

Any help is appreciated :slight_smile:

Some additional info:
My mapped embeddings, used as weights in the embedding layer, should take up no more than 200 MB.

Reducing the GloVe embedding size to 50 still results in an OOM error.

Training a 3-layer CNN on the same embeddings and hyperparams works with no problem, so I suspect my implementation of the LSTM might be wrong in some way.

I don’t see any obvious issues in your code and the memory footprint looks reasonable for this config:

print(torch.cuda.memory_allocated()/1024**2)
# 0.0

model = LSTMTextClassifier(np.random.randn(100, 300)).cuda()
x = torch.randint(0, 100, (10000, 100)).cuda()
print(torch.cuda.memory_allocated()/1024**2)
# 10.93310546875

out = model(x, length=torch.tensor([100]))
print(torch.cuda.memory_allocated()/1024**2)
# 20.8681640625

Did you check if the memory usage is increasing or if it’s static?
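
One way to check this is to print the allocator counters at the start of every epoch. Below is a minimal sketch; `format_cuda_memory` is a hypothetical helper name, not something from the code above, and it simply returns a notice when no GPU is present:

```python
import torch

def format_cuda_memory(device=None):
    """Return the allocated / reserved / peak reserved CUDA memory in MB as text."""
    if not torch.cuda.is_available():
        return "CUDA not available"
    mb = 1024 ** 2
    return (
        f"torch.cuda.memory_allocated: {torch.cuda.memory_allocated(device) / mb:f}MB\n"
        f"torch.cuda.memory_reserved: {torch.cuda.memory_reserved(device) / mb:f}MB\n"
        f"torch.cuda.max_memory_reserved: {torch.cuda.max_memory_reserved(device) / mb:f}MB"
    )

# Call this once per epoch, before the training loop for that epoch starts.
report = format_cuda_memory()
print(report)
```

If `memory_allocated` keeps climbing from epoch to epoch, some tensor is most likely being kept alive between iterations.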

The memory usage seems to be increasing for each epoch. I ran the code on a different GPU now and the memory usage should be completely reasonable for the GPU I initially tried the task on. The output below shows the memory used for each epoch in the training process:

Start training...
EPOCH 0:
mem usage before epoch 0
torch.cuda.memory_allocated: 19.937500MB
torch.cuda.memory_reserved: 42.000000MB
torch.cuda.max_memory_reserved: 42.000000MB
validating
LOSS train 0.9641966915130615 valid 0.8248609900474548
{'tn': 2335, 'fp': 1036, 'fn': 48, 'tp': 541, 'accuracy': 0.7262626262626263, 'precision': 0.6614568794869387, 'recall': 0.8055893698322706, 'specificity': 0.6926727973894987, 'f1_score': 0.655573817370934, 'roc_auc': 0.8055893698322706, 'train_loss': 0.9641966915130615, 'val_loss': tensor(0.8249, device='cuda:0', grad_fn=<DivBackward0>)}
EPOCH 1:
mem usage before epoch 1
torch.cuda.memory_allocated: 407.688477MB
torch.cuda.memory_reserved: 664.000000MB
torch.cuda.max_memory_reserved: 664.000000MB
validating
LOSS train 0.6999987041950226 valid 0.7488411068916321
{'tn': 2470, 'fp': 901, 'fn': 55, 'tp': 534, 'accuracy': 0.7585858585858586, 'precision': 0.6751716286611239, 'recall': 0.8196708266201431, 'specificity': 0.7327202610501334, 'f1_score': 0.6827620789333961, 'roc_auc': 0.8196708266201431, 'train_loss': 0.6999987041950226, 'val_loss': tensor(0.7488, device='cuda:0', grad_fn=<DivBackward0>)}
EPOCH 2:
mem usage before epoch 2
torch.cuda.memory_allocated: 769.620605MB
torch.cuda.memory_reserved: 1026.000000MB
torch.cuda.max_memory_reserved: 1026.000000MB
validating
LOSS train 0.5922324299812317 valid 0.6362012028694153
{'tn': 2889, 'fp': 482, 'fn': 77, 'tp': 512, 'accuracy': 0.8588383838383838, 'precision': 0.7445648265859486, 'recall': 0.8631428357018995, 'specificity': 0.8570157223375853, 'f1_score': 0.779330469080756, 'roc_auc': 0.8631428357018996, 'train_loss': 0.5922324299812317, 'val_loss': tensor(0.6362, device='cuda:0', grad_fn=<DivBackward0>)}
EPOCH 3:
mem usage before epoch 3
torch.cuda.memory_allocated: 1131.552734MB
torch.cuda.memory_reserved: 1388.000000MB
torch.cuda.max_memory_reserved: 1388.000000MB
validating
LOSS train 0.45384600430727007 valid 0.7183728218078613
{'tn': 2946, 'fp': 425, 'fn': 109, 'tp': 480, 'accuracy': 0.8651515151515151, 'precision': 0.7473537629644365, 'recall': 0.844432614344159, 'specificity': 0.8739246514387422, 'f1_score': 0.7797351872475897, 'roc_auc': 0.844432614344159, 'train_loss': 0.45384600430727007, 'val_loss': tensor(0.7184, device='cuda:0', grad_fn=<DivBackward0>)}
EPOCH 4:
mem usage before epoch 4
torch.cuda.memory_allocated: 1493.484863MB
torch.cuda.memory_reserved: 1750.000000MB
torch.cuda.max_memory_reserved: 1750.000000MB
validating
LOSS train 0.44859391540288923 valid 1.0342718362808228
{'tn': 3179, 'fp': 192, 'fn': 167, 'tp': 422, 'accuracy': 0.9093434343434343, 'precision': 0.8186930381163955, 'recall': 0.8297560990350634, 'specificity': 0.9430436072382082, 'f1_score': 0.8240664528941157, 'roc_auc': 0.8297560990350633, 'train_loss': 0.44859391540288923, 'val_loss': tensor(1.0343, device='cuda:0', grad_fn=<DivBackward0>)}
EPOCH 5:
mem usage before epoch 5
torch.cuda.memory_allocated: 1855.416992MB
torch.cuda.memory_reserved: 2112.000000MB
torch.cuda.max_memory_reserved: 2112.000000MB
validating
LOSS train 0.36383001729846 valid 0.6951749324798584
{'tn': 2888, 'fp': 483, 'fn': 78, 'tp': 511, 'accuracy': 0.8583333333333333, 'precision': 0.7438932312689353, 'recall': 0.8621456153277808, 'specificity': 0.8567190744586176, 'f1_score': 0.7785409537644081, 'roc_auc': 0.8621456153277808, 'train_loss': 0.36383001729846, 'val_loss': tensor(0.6952, device='cuda:0', grad_fn=<DivBackward0>)}
EPOCH 6:
mem usage before epoch 6
torch.cuda.memory_allocated: 2217.349121MB
torch.cuda.memory_reserved: 2474.000000MB
torch.cuda.max_memory_reserved: 2474.000000MB
validating
LOSS train 0.2757247617095709 valid 1.0132341384887695
{'tn': 3105, 'fp': 266, 'fn': 138, 'tp': 451, 'accuracy': 0.897979797979798, 'precision': 0.7932282857058073, 'recall': 0.843398124117674, 'specificity': 0.921091664194601, 'f1_score': 0.8147879735361114, 'roc_auc': 0.843398124117674, 'train_loss': 0.2757247617095709, 'val_loss': tensor(1.0132, device='cuda:0', grad_fn=<DivBackward0>)}
EPOCH 7:
mem usage before epoch 7
torch.cuda.memory_allocated: 2579.281250MB
torch.cuda.memory_reserved: 2836.000000MB
torch.cuda.max_memory_reserved: 2836.000000MB
validating
LOSS train 0.17272374629974366 valid 0.9461060762405396
{'tn': 3103, 'fp': 268, 'fn': 115, 'tp': 474, 'accuracy': 0.9032828282828282, 'precision': 0.8015387669426859, 'recall': 0.8626260942353108, 'specificity': 0.9204983684366657, 'f1_score': 0.8270596247941411, 'roc_auc': 0.8626260942353108, 'train_loss': 0.17272374629974366, 'val_loss': tensor(0.9461, device='cuda:0', grad_fn=<DivBackward0>)}
EPOCH 8:
torch.cuda.memory_allocated: 2941.213379MB
torch.cuda.memory_reserved: 3198.000000MB
torch.cuda.max_memory_reserved: 3198.000000MB
validating
LOSS train 0.2081744890101254 valid 1.0820714235305786
{'tn': 3159, 'fp': 212, 'fn': 157, 'tp': 432, 'accuracy': 0.9068181818181819, 'precision': 0.8117306265874473, 'recall': 0.8352785845917365, 'specificity': 0.9371106496588549, 'f1_score': 0.8227741155897869, 'roc_auc': 0.8352785845917365, 'train_loss': 0.2081744890101254, 'val_loss': tensor(1.0821, device='cuda:0', grad_fn=<DivBackward0>)}
EPOCH 9:
mem usage before epoch 9
torch.cuda.memory_allocated: 3303.145508MB
torch.cuda.memory_reserved: 3560.000000MB
torch.cuda.max_memory_reserved: 3560.000000MB

Not sure why the usage is increasing, though. There might be a tensor somewhere that I keep a reference to, so it is never freed. It might be a result of saving the loss in my metrics table: I see that my val loss is stored as the loss tensor (with its grad_fn) and not as the loss.item() scalar. This could accumulate memory, since I append to the metrics table every epoch and display it at the end of training. This tensor should be relatively small compared to the training set, though. The validation set is approx. 3k entries with the same dimensions as the training set.

Hey, could you please post the code where this is happening?

If it’s being done under the no_grad context manager, it might not be a problem. Otherwise, this could very well be the reason why memory usage is increasing with epochs.
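
To illustrate what I mean, here is a small CPU-only sketch (the variable names are made up for the example). Adding a loss tensor to a running total makes the total itself a tensor that is attached to the autograd graph of every step, while `.item()` gives a plain Python float that holds on to nothing:

```python
import torch

w = torch.ones(3, requires_grad=True)

running_tensor = 0.0  # will silently turn into a tensor
running_float = 0.0   # stays a Python float

for _ in range(3):
    loss = (w * 2).sum()          # scalar tensor with a grad_fn
    running_tensor += loss        # keeps a reference to each step's graph
    running_float += loss.item()  # detached Python number

print(type(running_tensor), running_tensor.grad_fn is not None)
print(type(running_float), running_float)
```

Outside of no_grad, the first pattern keeps every iteration's graph (and its intermediate activations) alive for as long as the running total is referenced.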

As @ptrblck stated, there doesn’t seem to be any obvious reason why the memory would increase over time in the code you’ve posted so far.

I fixed the loss metrics method, now appending the vloss.item() instead of the vloss itself. Somehow I forgot to do .item() on the validation step, but remembered to do it in the epoch itself. When changing from vloss to vloss.item() the memory usage stays more or less the same. Thank you for your help guys. I’ll have to troubleshoot some more before I post next time :slight_smile:

If you’re curious, this was the line that made my memory consumption increase for each epoch:

vloss = loss_fn(voutputs, vlabels.to(torch.float32).unsqueeze(1))
running_vloss += vloss

Changing it to the code below fixed the issue:

vloss = loss_fn(voutputs, vlabels.to(torch.float32).unsqueeze(1))
running_vloss += vloss.item()

The output is now more reasonable!

Start training...
EPOCH 0:
mem usage before epoch 0
torch.cuda.memory_allocated: 19.993652MB
torch.cuda.memory_reserved: 42.000000MB
torch.cuda.max_memory_reserved: 42.000000MB
validating
LOSS train 1.0116420435905455 valid 0.9235664361354979
{'tn': 2428, 'fp': 450, 'fn': 158, 'tp': 371, 'accuracy': 0.8215438802465512, 'precision': 0.6953948601718426, 'recall': 0.772482334534458, 'specificity': 0.8436414176511466, 'f1_score': 0.7191779187679628, 'roc_auc': 0.772482334534458, 'train_loss': 1.0116420435905455, 'val_loss': 0.9235664361354979}
EPOCH 1:
mem usage before epoch 1
torch.cuda.memory_allocated: 46.227539MB
torch.cuda.memory_reserved: 332.000000MB
torch.cuda.max_memory_reserved: 332.000000MB
validating
LOSS train 0.6313787889480591 valid 0.8605835854599385
{'tn': 2642, 'fp': 236, 'fn': 169, 'tp': 360, 'accuracy': 0.8811270912826533, 'precision': 0.7719529461201082, 'recall': 0.7992639553565212, 'specificity': 0.9179986101459346, 'f1_score': 0.7844049920899983, 'roc_auc': 0.7992639553565212, 'train_loss': 0.6313787889480591, 'val_loss': 0.8605835854599385}
EPOCH 2:
mem usage before epoch 2
torch.cuda.memory_allocated: 46.584473MB
torch.cuda.memory_reserved: 332.000000MB
torch.cuda.max_memory_reserved: 332.000000MB
validating
LOSS train 0.6704263162612915 valid 0.7507183478225559
{'tn': 2460, 'fp': 418, 'fn': 107, 'tp': 422, 'accuracy': 0.8459054886997358, 'precision': 0.7303490270280297, 'recall': 0.8262459095859207, 'specificity': 0.8547602501737318, 'f1_score': 0.7600448337549075, 'roc_auc': 0.8262459095859207, 'train_loss': 0.6704263162612915, 'val_loss': 0.7507183478225559}
EPOCH 3:
mem usage before epoch 3
torch.cuda.memory_allocated: 46.227539MB
torch.cuda.memory_reserved: 390.000000MB
torch.cuda.max_memory_reserved: 390.000000MB
validating
LOSS train 0.4898209497332573 valid 0.7501729907758043
{'tn': 2510, 'fp': 368, 'fn': 116, 'tp': 413, 'accuracy': 0.8579395362488993, 'precision': 0.74231778540801, 'recall': 0.8264258812371015, 'specificity': 0.872133425990271, 'f1_score': 0.7712991523167052, 'roc_auc': 0.8264258812371016, 'train_loss': 0.4898209497332573, 'val_loss': 0.7501729907758043}
EPOCH 4:
mem usage before epoch 4
torch.cuda.memory_allocated: 46.584473MB
torch.cuda.memory_reserved: 390.000000MB
torch.cuda.max_memory_reserved: 390.000000MB
validating
LOSS train 0.5079629309475422 valid 0.727817810095925
{'tn': 2547, 'fp': 331, 'fn': 95, 'tp': 434, 'accuracy': 0.8749633108306428, 'precision': 0.7656813267825424, 'recall': 0.8527027275557617, 'specificity': 0.8849895760945101, 'f1_score': 0.7968071702170552, 'roc_auc': 0.8527027275557616, 'train_loss': 0.5079629309475422, 'val_loss': 0.727817810095925}
EPOCH 5:
mem usage before epoch 5
torch.cuda.memory_allocated: 46.267090MB
torch.cuda.memory_reserved: 390.000000MB
torch.cuda.max_memory_reserved: 390.000000MB
validating
LOSS train 0.36113569617271424 valid 0.8384792497388612
{'tn': 2425, 'fp': 453, 'fn': 97, 'tp': 432, 'accuracy': 0.8385676548282947, 'precision': 0.7248370273794003, 'recall': 0.8296170938913419, 'specificity': 0.8425990271021543, 'f1_score': 0.7545903399863796, 'roc_auc': 0.8296170938913418, 'train_loss': 0.36113569617271424, 'val_loss': 0.8384792497388612}
EPOCH 6:
mem usage before epoch 6
torch.cuda.memory_allocated: 46.227539MB
torch.cuda.memory_reserved: 390.000000MB
torch.cuda.max_memory_reserved: 390.000000MB
validating
LOSS train 0.2776815439760685 valid 1.1387173141670421
{'tn': 2649, 'fp': 229, 'fn': 142, 'tp': 387, 'accuracy': 0.8911065453478133, 'precision': 0.7886844658387117, 'recall': 0.8259999264349456, 'specificity': 0.9204308547602502, 'f1_score': 0.8052694459486629, 'roc_auc': 0.8259999264349456, 'train_loss': 0.2776815439760685, 'val_loss': 1.1387173141670421}
EPOCH 7:
mem usage before epoch 7
torch.cuda.memory_allocated: 46.584473MB
torch.cuda.memory_reserved: 390.000000MB
torch.cuda.max_memory_reserved: 390.000000MB
validating
LOSS train 0.17835736103355884 valid 1.0326833028540423
{'tn': 2438, 'fp': 440, 'fn': 76, 'tp': 453, 'accuracy': 0.8485471088934546, 'precision': 0.7385240636756671, 'recall': 0.8517243780140324, 'specificity': 0.8471160528144545, 'f1_score': 0.7707167361554546, 'roc_auc': 0.8517243780140324, 'train_loss': 0.17835736103355884, 'val_loss': 1.0326833028540423}
EPOCH 8:
mem usage before epoch 8
torch.cuda.memory_allocated: 46.227539MB
torch.cuda.memory_reserved: 390.000000MB
torch.cuda.max_memory_reserved: 390.000000MB
validating
LOSS train 0.2255521084740758 valid 1.2945302003195946
{'tn': 2644, 'fp': 234, 'fn': 150, 'tp': 379, 'accuracy': 0.8872908717346639, 'precision': 0.7822921641690829, 'recall': 0.8175698309711507, 'specificity': 0.9186935371785963, 'f1_score': 0.7980234117760394, 'roc_auc': 0.8175698309711507, 'train_loss': 0.2255521084740758, 'val_loss': 1.2945302003195946}
EPOCH 9:
mem usage before epoch 9
torch.cuda.memory_allocated: 46.227539MB
torch.cuda.memory_reserved: 390.000000MB
torch.cuda.max_memory_reserved: 390.000000MB
validating
LOSS train 0.1795539532136172 valid 1.4612185525321832
{'tn': 2633, 'fp': 245, 'fn': 158, 'tp': 371, 'accuracy': 0.8817141179923687, 'precision': 0.7728310967069476, 'recall': 0.8080973449583635, 'specificity': 0.9148714384989576, 'f1_score': 0.7884732795614855, 'roc_auc': 0.8080973449583635, 'train_loss': 0.1795539532136172, 'val_loss': 1.4612185525321832}

Glad it’s solved.

Also, you might want to include the validation step under the torch.no_grad() context manager - this will save memory and speed things up.
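
Something along these lines, as a rough sketch only. The `validate` function, the toy model and the toy loader are stand-ins I made up so the snippet runs on its own; your model also takes the lengths, so the forward call would need adapting:

```python
import torch
import torch.nn as nn

def validate(model, loader, loss_fn):
    """Average validation loss, computed without building autograd graphs."""
    model.eval()  # disables dropout during validation
    running_vloss = 0.0
    with torch.no_grad():
        for vinputs, vlabels in loader:
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels.to(torch.float32).unsqueeze(1))
            running_vloss += vloss.item()  # accumulate a float, not a tensor
    return running_vloss / len(loader)

# Toy stand-ins for the real model and validation DataLoader.
model = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
loader = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(3)]

avg_vloss = validate(model, loader, nn.BCELoss())
print(avg_vloss)
```

Remember to call model.train() again before the next training epoch.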

You are very right. I did this as well. Thanks for all your help!