CUDA out of memory error during forward pass

I am trying to implement a ‘one step gradient descent’ approach, wherein I accumulate the loss over the whole dataset, sum it, and then do a single backpropagation. I have set my batch size to 8. The issue I am facing is that after a few forward passes I get an OOM error. I think it is because PyTorch is saving the forward computation graph for each batch. Is there a way to save the forward computation graph somewhere and access it when performing the backward pass? Or is there any other workaround?
I have also tried deleting en_input, en_masks, de_output, de_masks after accumulating the loss, but to no avail.

#reproduce error
from transformers import BertModel, BertForMaskedLM, BertConfig, EncoderDecoderModel
import torch
import torch.nn.functional as F
model1 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)
model1.to(device) # the model must live on the same device as the inputs

optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001)

def train1(batch_size):
  acc = torch.empty(0, device='cuda') # holds the accumulated per-batch losses
  for i in range(20):

    #dummy inputs similar to my dataset
    en_input=torch.tensor([[i for i in range(50)] for i in range (batch_size)])
    en_masks=torch.tensor([[0 for i in range(50)] for i in range (batch_size)])
    de_output=torch.tensor([[i for i in range(50)] for i in range (batch_size)])
    de_masks=torch.tensor([[0 for i in range(50)] for i in range (batch_size)])
    lm_labels=torch.tensor([[i for i in range(50)] for i in range (batch_size)])
    en_input = en_input.to('cuda')
    de_output = de_output.to('cuda')
    en_masks = en_masks.to('cuda')
    de_masks = de_masks.to('cuda')
    lm_labels = de_output.clone().to('cuda') 
    out = model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output, 
                        decoder_attention_mask=de_masks, labels=lm_labels)

    prediction_scores = out[1]
    predictions = F.log_softmax(prediction_scores, dim=2)
    p=((predictions.sum() - de_output.sum())).sum() #some loss
    p=torch.unsqueeze(p, dim=0)
    acc = torch.cat((p, acc)) # accumulating the loss



Yes, Autograd keeps the computation graph of every forward pass alive if you sum the losses (or store references to those graphs in any other way) until a backward operation is performed.
To accumulate gradients you could take a look at this post, which explains different approaches and their computation as well as memory usage.
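
As a rough illustration of one such approach (a minimal sketch, not necessarily the one from the linked post): calling backward() on each batch's loss frees that batch's graph right away, while the gradients keep adding up in .grad, so a single optimizer step at the end should match backpropagating the summed loss. The tiny nn.Linear model, the random batches and the helper name accumulate_gradients are all made up for the example and stand in for the Bert2Bert setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real model and dataset.
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]

def accumulate_gradients(model, batches):
    """Run backward once per batch so only one graph is alive at a time."""
    model.zero_grad()
    total_loss = 0.0
    for x, y in batches:
        loss = ((model(x) - y) ** 2).sum()  # some per-batch loss
        loss.backward()                     # adds into .grad and frees this graph
        total_loss += loss.item()           # plain float, keeps no graph reference
    return total_loss

total_loss = accumulate_gradients(model, batches)
optimizer.step()       # one update using the gradient of the whole dataset
optimizer.zero_grad()
print("summed loss:", total_loss)
```

Note that the running total is stored with .item(); adding the loss tensors themselves would again keep every graph alive.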

Thank you so much! It helped me a lot!