Hello, I am facing this CUDA out of memory error while executing the 3rd step (val_model2) of my pipeline. I am training a BERT-BERT seq to seq model. I even tried setting my batch size to 1.
In the first step I am training a seq to seq model (say M1) on a dataset D1
In the second step I am training another similar model using the output from M1 on an unlabeled dataset D2
In the third step I am learning the weighting matrix A by reducing model2's loss on the validation set of D1.
Here is my training code.
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.92 GiB total capacity; 9.48 GiB already allocated; 19.38 MiB free; 10.15 GiB reserved in total by PyTorch)
'''
MT model class to load pretrained models. Training occurs in 3 steps
1) Train first MT model using matrix A to calculate loss
2) Train second MT model on a dataset created by first MT model on the unlabeled dataset
3) Estimate A by reducing the validation loss of second MT model on validation set of MT dataset
'''
class TranslationModel:
    """Wrapper around two pretrained encoder-decoder MT models.

    Training proceeds in three steps:
      1) train model1, weighting each example's loss by matrix A,
      2) train model2 on pseudo-labels produced by model1 on unlabeled data,
      3) estimate A by reducing model2's loss on the validation set.
    """

    def __init__(self, device, batch_size, logging, model1_path, model2_path, config):
        self.model1 = EncoderDecoderModel.from_pretrained(model1_path)
        self.model2 = EncoderDecoderModel.from_pretrained(model2_path)
        self.device = device
        self.batch_size = batch_size
        # Respect the caller-supplied device instead of hard-coding .cuda(),
        # so a CPU run or a non-default GPU works too; the rest of the class
        # already moves all tensors with .to(self.device).
        self.model1 = self.model1.to(device)
        self.model2 = self.model2.to(device)
        self.logger = logging
        self.config = config
def train_model1(self, A_batch, train_dataloader, optimizer1, tokenizer, criterion, scheduler1):
    """Run one training epoch for model1.

    Each batch's loss is weighted by the corresponding entries `a` of the
    example-weighting matrix A (via compute_loss1).

    Returns the mean epoch loss over all batches.
    """
    self.model1.train()
    epoch_loss = 0.0
    num_train_batches = len(train_dataloader)
    for i, ((en_input, en_masks, de_output, de_masks), a) in enumerate(zip(train_dataloader, A_batch)):
        optimizer1.zero_grad()
        en_input = en_input.to(self.device)
        de_output = de_output.to(self.device)
        en_masks = en_masks.to(self.device)
        de_masks = de_masks.to(self.device)
        lm_labels = de_output.clone()  # already on self.device
        out = self.model1(input_ids=en_input, attention_mask=en_masks,
                          decoder_input_ids=de_output,
                          decoder_attention_mask=de_masks, labels=lm_labels)
        predictions = F.log_softmax(out[1], dim=2)
        loss1 = compute_loss1(predictions, de_output, a, self.device, criterion)
        epoch_loss += loss1.item()
        # No retain_graph=True here: a fresh graph is built every iteration,
        # and retaining old graphs keeps all their activations in GPU memory
        # (a classic cause of CUDA OOM over the course of an epoch).
        loss1.backward(inputs=list(self.model1.parameters()))
        torch.nn.utils.clip_grad_norm_(self.model1.parameters(), self.config["model1"]['grad_clip'])
        optimizer1.step()
        scheduler1.step()
        if ((i + 1) * self.batch_size) % self.config['report_freq'] == 0:
            # %f (not %d): loss and BLEU are floats; %d would truncate them.
            self.logger.info('loss after %d instances: %f', (i + 1) * self.batch_size, epoch_loss)
            self.logger.info('bleu score after %d instances: %f', (i + 1) * self.batch_size,
                             calc_bleu(en_input, lm_labels, self.model1, tokenizer))
    self.logger.info('Mean epoch loss for step 1: %f', epoch_loss / num_train_batches)
    return epoch_loss / num_train_batches
def train_model2(self, unlabeled_dataloader, optimizer2, tokenizer, criterion, scheduler2):
    """Run one training epoch for model2 on pseudo-labels from model1.

    model1 is used only as a frozen teacher here: its argmax predictions
    become model2's labels and its last decoder hidden state becomes
    model2's decoder input embeddings.

    Returns the mean epoch loss over all batches.
    """
    epoch_loss = 0.0
    self.model2.train()
    num_train_batches = len(unlabeled_dataloader)
    for i, (en_input, en_masks, de_output, de_masks) in enumerate(unlabeled_dataloader):
        # zero_grad must run every iteration; the original called it only
        # once before the loop, so gradients accumulated across batches.
        optimizer2.zero_grad()
        en_input = en_input.to(self.device)
        # model1 only generates pseudo-labels in this step and is never
        # updated here, so run it without autograd: otherwise its entire
        # activation graph is kept alive through decoder_hidden_states
        # and leaks GPU memory every batch.
        with torch.no_grad():
            outputs = self.model1(input_ids=en_input, decoder_input_ids=en_input,
                                  output_hidden_states=True, return_dict=True)
        teacher_log_probs = F.log_softmax(outputs.logits, dim=2)
        _, new_labels = torch.max(teacher_log_probs, 2)
        out = self.model2(input_ids=en_input,
                          decoder_inputs_embeds=outputs.decoder_hidden_states[-1],
                          labels=new_labels)
        predictions = F.log_softmax(out[1], dim=2)
        loss2 = compute_loss2(predictions, new_labels, self.device, criterion)
        epoch_loss += loss2.item()
        # No retain_graph=True: the graph is rebuilt each iteration, and
        # retaining it accumulates activations until CUDA OOM.
        loss2.backward(inputs=list(self.model2.parameters()))
        torch.nn.utils.clip_grad_norm_(self.model2.parameters(), self.config["model2"]['grad_clip'])
        optimizer2.step()
        scheduler2.step()
        if ((i + 1) * self.batch_size) % self.config['report_freq'] == 0:
            # %f (not %d): loss and BLEU are floats; %d would truncate them.
            self.logger.info('loss after %d instances: %f', (i + 1) * self.batch_size, epoch_loss)
            self.logger.info('bleu score after %d instances: %f', (i + 1) * self.batch_size,
                             calc_bleu(en_input, new_labels, self.model2, tokenizer))
    self.logger.info('Mean epoch loss for step 2: %f', epoch_loss / num_train_batches)
    return epoch_loss / num_train_batches
def val_model2(self, valid_dataloader, optimizer3, A, A_batch, tokenizer, criterion, scheduler3, a_ind):
"""Estimate the example-weight matrix A via a finite-difference hypergradient.

DARTS-style bilevel step: model2's validation gradient is used as the
perturbation direction v; model2's weights are nudged by +/-R*v, model1
regenerates pseudo-labels at each perturbed point, and the difference of
d(loss2)/d(model1 params) approximates the Hessian-vector product. The
same +/-R trick on model1's weights then yields d(loss1)/d(a), which is
written into A.grad for optimizer3 to step on.

Returns (mean epoch loss, updated a_ind).
"""
epoch_loss=0
self.model2.eval()
optimizer3.zero_grad()
# NOTE(review): A.grad starts on CPU here but is re-created on self.device
# at the end of each iteration (below) — confirm which device optimizer3
# expects; mixing the two can raise device-mismatch errors.
A.grad=torch.zeros(len(A), device='cpu')
for i, ((en_input, en_masks, de_output, de_masks), a) in enumerate(zip(valid_dataloader, A_batch)):
en_input = en_input.to(self.device)
de_output = de_output.to(self.device)
en_masks = en_masks.to(self.device)
de_masks = de_masks.to(self.device)
lm_labels = de_output.clone().to(self.device)
# Validation loss of model2 on D1's validation set.
out=self.model2(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output,
decoder_attention_mask=de_masks, labels=de_output.clone())
predictions = F.log_softmax(out[1], dim=2)
loss3 = compute_loss2(predictions, de_output, self.device, criterion)
epoch_loss+=loss3.item()
# NOTE(review): retain_graph=True keeps this batch's whole graph in GPU
# memory even after the backward — likely a contributor to the OOM, since
# the graph is rebuilt next iteration anyway. The backward is used only to
# populate model2's .grad as the perturbation direction below.
loss3.backward(inputs=list(self.model2.parameters()), retain_graph=True)
del loss3
# compute hessian vector product
r=1e-2
vector=[]
# Collect model2's validation gradients as the direction vector v;
# params with no grad get a placeholder of ones (shape 1).
for param in self.model2.parameters():
if param.grad!=None:
vector.append(param.grad.data.to(self.device))
else:
vector.append(torch.ones(1).to(self.device))
# Finite-difference step size, scaled by 1/||v|| (DARTS epsilon).
R = r / _concat(vector, self.device).norm()
print(R)
# Perturb model2 weights: w <- w + R*v.
for p, v in zip(self.model2.parameters(), vector):
# NOTE(review): p.data.to(self.device) is a no-op — .to() returns a new
# tensor that is discarded here; remove or assign it.
p.data.to(self.device)
p.data.add_(alpha=R, other=v)
#calculate loss
# loss2 at w+ : model1 makes pseudo-labels, model2 (perturbed) scores them.
outputs=self.model1(input_ids=en_input, decoder_input_ids=en_input, output_hidden_states=True, return_dict=True)
predictions = F.log_softmax(outputs.logits, dim=2)
values, new_labels = torch.max(predictions, 2)
out=self.model2(input_ids=en_input, decoder_inputs_embeds=outputs.decoder_hidden_states[-1], labels=new_labels)
predictions = F.log_softmax(out[1], dim=2)
loss2=compute_loss2(predictions, new_labels, self.device, criterion)
# Gradient of loss2 wrt model1's weights at the positive perturbation.
# NOTE(review): retain_graph=True here keeps the full graph alive after
# the grad is taken; since loss2 is deleted right below, it may be
# droppable — verify and free the graph to reduce peak memory.
grads_p=torch.autograd.grad(loss2, self.model1.parameters(), allow_unused=True, retain_graph=True)
del loss2
del predictions
del out
del outputs
del new_labels
torch.cuda.empty_cache()
# Move to the negative perturbation: w <- w - 2*R*v  (net: w - R*v).
for p, v in zip(self.model2.parameters(), vector):
p.data.sub_(alpha=2 * R, other=v)
#calculate loss
# loss2 at w- : same forward as above at the negative perturbation.
outputs=self.model1(input_ids=en_input, decoder_input_ids=en_input, output_hidden_states=True, return_dict=True)
predictions = F.log_softmax(outputs.logits, dim=2)
values, new_labels = torch.max(predictions, 2)
out=self.model2(input_ids=en_input, decoder_inputs_embeds=outputs.decoder_hidden_states[-1], labels=new_labels)
predictions = F.log_softmax(out[1], dim=2)
loss2=compute_loss2(predictions, new_labels, self.device, criterion)
grads_n = torch.autograd.grad(loss2, self.model1.parameters(), allow_unused=True, retain_graph=True)
del loss2
del predictions
del out
del outputs
del new_labels
# Restore model2's original weights: w <- w + R*v.
for p, v in zip(self.model2.parameters(), vector):
# NOTE(review): positional add_(R, v) is the deprecated Tensor.add_
# signature — prefer add_(v, alpha=R) in current PyTorch.
p.data.add_(R, v)
del vector
torch.cuda.empty_cache()
# OOM error occurs here
# Central difference (grads_p - grads_n) / (2R) approximates the
# Hessian-vector product wrt model1's weights.
vector=[]
for x,y in zip(grads_p, grads_n):
if x!=None and y!=None:
vector.append(((x - y).div_(2 * R)))
else:
# NOTE(review): placeholder torch.ones(1) is on CPU while real grads
# are on GPU — confirm downstream add_ never mixes devices here.
vector.append(torch.ones(1))
del grads_n
del grads_p
torch.cuda.empty_cache()
# Now perturb model1 by the HVP direction and finite-difference d(loss1)/da.
for p, v in zip(self.model1.parameters(), vector):
p.data.add_(alpha=R, other=v)
#calculate loss
out = self.model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output,
decoder_attention_mask=de_masks, labels=lm_labels.clone())
predictions = F.log_softmax(out[1], dim=2)
loss1=compute_loss1(predictions, de_output, a, self.device, criterion)
# d(loss1)/da at the positive perturbation of model1.
grads_p=torch.autograd.grad(loss1, a, allow_unused=True, retain_graph=True)
for p, v in zip(self.model1.parameters(), vector):
p.data.sub_(2 * R, v)
del out
del predictions
del loss1
torch.cuda.empty_cache()
#calculate loss
out = self.model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output,
decoder_attention_mask=de_masks, labels=lm_labels.clone())
predictions = F.log_softmax(out[1], dim=2)
loss1=compute_loss1(predictions, de_output, a, self.device, criterion)
grads_n=torch.autograd.grad(loss1, a, allow_unused=True, retain_graph=True)
del out
del predictions
del loss1
torch.cuda.empty_cache()
# Restore model1's original weights.
for p, v in zip(self.model1.parameters(), vector):
p.data.add_(R, v)
# Write the batch's hypergradient into A.grad for optimizer3.
A.grad[a_ind:a_ind+self.batch_size]=[(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)][0]
print(A.grad)
del grads_p
del grads_n
torch.cuda.empty_cache()
optimizer3.step()
scheduler3.step()
a_ind=a_ind+self.batch_size
# Reset A.grad for the next batch (on self.device — see device note above).
A.grad=torch.zeros(len(A), device=self.device)
if ((i+1)*self.batch_size)% self.config['report_freq'] == 0:
# NOTE(review): %d truncates float loss/BLEU values — %f would be accurate.
self.logger.info('loss after %d instances: %d', (i+1)*self.batch_size, epoch_loss)
self.logger.info('bleu score after %d instances: %d', (i+1)*self.batch_size, calc_bleu(en_input, lm_labels, self.model2, tokenizer))
self.logger.info('Mean epoch loss for step 3: %d', (epoch_loss / len(valid_dataloader)))
return (epoch_loss / len(valid_dataloader), a_ind)
train_model1(...)
train_model2(...)
val_model2(...)
Any help would be appreciated. Thanks.