Hello,
I am trying to implement a ‘one step gradient descent’ aproach wherein I accumulate the loss for the whole dataset, sum it, and then do a backpropagation. I have set my batch size to 8. The issue that I am facing is that after a few forward passes I obtain an OOM error. I think it is because pytorch is saving the forward computation graph for each instance. Is there any work around, where I can save the forward computation graph somewhere and access it when performing a backward pass? Or is there any other workaround?
I have also tried deleting en_input, en_masks, de_output, de_masks after accumulating the loss but no avail.
#reproduce error
from transformers import BertModel, BertForMaskedLM, BertConfig, EncoderDecoderModel
import torch
import torch.nn.functional as F
model1 = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
model1.cuda()
optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)
def train1(batch_size):
acc=torch.zeros(1)
acc=acc.to('cuda')
for i in range(20):
optimizer1.zero_grad()
#dummy inputs similar to my dataset
en_input=torch.tensor([[i for i in range(50)] for i in range (batch_size)])
en_masks=torch.tensor([[0 for i in range(50)] for i in range (batch_size)])
de_output=torch.tensor([[i for i in range(50)] for i in range (batch_size)])
de_masks=torch.tensor([[0 for i in range(50)] for i in range (batch_size)])
lm_labels=torch.tensor([[i for i in range(50)] for i in range (batch_size)])
en_input = en_input.to('cuda')
de_output = de_output.to('cuda')
en_masks = en_masks.to('cuda')
de_masks = de_masks.to('cuda')
lm_labels = de_output.clone().to('cuda')
out = model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output,
decoder_attention_mask=de_masks, labels=lm_labels)
prediction_scores = out[1]
predictions = F.log_softmax(prediction_scores, dim=2)
p=((predictions.sum() - de_output.sum())).sum() #some loss
p=torch.unsqueeze(p, dim=0)
acc = torch.cat((p,acc)) # accumulating the loss
loss=acc.sum()
loss.backward(retain_graph=True)
optimizer1.step()
train1(batch_size=8)