CUDA OOM error while training a seq-to-seq model

Hello, I am running into a CUDA out-of-memory error during the third step (val_model2) of my pipeline. I am training a BERT-to-BERT sequence-to-sequence model, and the error occurs even with a batch size of 1.
In the first step I train a seq-to-seq model (say M1) on a dataset D1.
In the second step I train another, similar model (M2) on an unlabeled dataset D2, using the outputs of M1 as labels.
In the third step I learn A by reducing M2's loss on the validation set of D1, as sketched below.
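For context, step 3 approximates the gradient of the validation loss with respect to A by chaining two finite-difference (Hessian-vector product) estimates; this is what the R, grads_p, and grads_n logic in val_model2 below implements. The approximation used at both levels is roughly

$$ \nabla_{\theta}\big(\nabla_{w} L \cdot v\big) \;\approx\; \frac{\nabla_{\theta} L(w + R\,v) - \nabla_{\theta} L(w - R\,v)}{2R}, $$

so each validation batch runs several extra forward and backward passes through both models.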
Here is the error:

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.92 GiB total capacity; 9.48 GiB already allocated; 19.38 MiB free; 10.15 GiB reserved in total by PyTorch)
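To see where the memory goes before the crash, the allocator stats can be printed between the steps with a small helper like this (a minimal sketch; these calls are not part of the training code below):

import torch

def log_gpu_memory(tag):
    # memory actually held by tensors vs. memory reserved by PyTorch's caching allocator
    allocated = torch.cuda.memory_allocated() / 1024 ** 2
    reserved = torch.cuda.memory_reserved() / 1024 ** 2
    print(f"[{tag}] allocated: {allocated:.1f} MiB | reserved: {reserved:.1f} MiB")

# e.g. log_gpu_memory("before val_model2"), log_gpu_memory("after grads_p"), ...

And here is my training code: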
'''
MT (machine translation) model class that loads two pretrained encoder-decoder models. Training occurs in 3 steps:
1) Train the first MT model, using the matrix A to calculate its loss
2) Train the second MT model on a dataset created by the first MT model from the unlabeled dataset
3) Estimate A by reducing the validation loss of the second MT model on the validation set of the MT dataset
'''
class TranslationModel:
    def __init__(self, device, batch_size, logging, model1_path, model2_path, config):
        self.model1=EncoderDecoderModel.from_pretrained(model1_path)
        self.model2=EncoderDecoderModel.from_pretrained(model2_path)
        self.device=device
        self.batch_size=batch_size
        self.model1 = self.model1.cuda()
        self.model2 = self.model2.cuda()
        self.logger=logging
        self.config=config
        

    def train_model1(self, A_batch, train_dataloader, optimizer1, tokenizer, criterion, scheduler1):
        self.model1.train()
        epoch_loss = 0
        optimizer1.zero_grad()
        num_train_batches = len(train_dataloader)

        for i, ((en_input, en_masks, de_output, de_masks), a) in enumerate(zip(train_dataloader, A_batch)):
            
            optimizer1.zero_grad()
            en_input = en_input.to(self.device) 
            de_output = de_output.to(self.device)
            en_masks = en_masks.to(self.device)
            de_masks = de_masks.to(self.device)
            lm_labels = de_output.clone().to(self.device)

            out = self.model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output, 
                                decoder_attention_mask=de_masks, labels=lm_labels)
                
            predictions = F.log_softmax(out[1], dim=2)
            loss1=compute_loss1(predictions, de_output, a, self.device, criterion)
            epoch_loss+=loss1.item()
            loss1.backward(inputs=list(self.model1.parameters()), retain_graph=True) 
            torch.nn.utils.clip_grad_norm_(self.model1.parameters(), self.config["model1"]['grad_clip'])
            optimizer1.step()  # weight update
            scheduler1.step() 

            if ((i+1)*self.batch_size)% self.config['report_freq'] == 0:
                self.logger.info('loss after %d instances: %f', (i+1)*self.batch_size, epoch_loss)
                self.logger.info('bleu score after %d instances: %f', (i+1)*self.batch_size, calc_bleu(en_input, lm_labels, self.model1, tokenizer))
        
        self.logger.info('Mean epoch loss for step 1: %f', (epoch_loss / num_train_batches))
        return ((epoch_loss / num_train_batches))

    def train_model2(self, unlabeled_dataloader, optimizer2, tokenizer, criterion, scheduler2):
        epoch_loss=0
        optimizer2.zero_grad()
        self.model2.train()
        num_train_batches = len(unlabeled_dataloader)
        for i, (en_input, en_masks, de_output, de_masks) in enumerate(unlabeled_dataloader):
            en_input = en_input.to(self.device)
            outputs=self.model1(input_ids=en_input, decoder_input_ids=en_input, output_hidden_states=True, return_dict=True)
            predictions = F.log_softmax(outputs.logits, dim=2)
            values, new_labels = torch.max(predictions, 2)
            
            out=self.model2(input_ids=en_input, decoder_inputs_embeds=outputs.decoder_hidden_states[-1], labels=new_labels)
            predictions = F.log_softmax(out[1], dim=2)
            loss2=compute_loss2(predictions, new_labels, self.device, criterion)

            epoch_loss += loss2.item()
            loss2.backward(inputs=list(self.model2.parameters()), retain_graph=True)
            torch.nn.utils.clip_grad_norm_(self.model2.parameters(), self.config["model2"]['grad_clip'])
            optimizer2.step()
            scheduler2.step()
            
            if ((i+1)*self.batch_size)% self.config['report_freq'] == 0:
                self.logger.info('loss after %d instances: %f', (i+1)*self.batch_size, epoch_loss)
                self.logger.info('bleu score after %d instances: %f', (i+1)*self.batch_size, calc_bleu(en_input, new_labels, self.model2, tokenizer))

        self.logger.info('Mean epoch loss for step 2: %f', (epoch_loss / num_train_batches))
        
        return ((epoch_loss / num_train_batches))
        
    def val_model2(self, valid_dataloader, optimizer3, A, A_batch, tokenizer, criterion, scheduler3, a_ind):
        epoch_loss=0
        self.model2.eval()
        optimizer3.zero_grad()
        A.grad=torch.zeros(len(A), device='cpu')

        for i, ((en_input, en_masks, de_output, de_masks), a) in enumerate(zip(valid_dataloader, A_batch)):
            en_input = en_input.to(self.device)
            de_output = de_output.to(self.device)
            en_masks = en_masks.to(self.device)
            de_masks = de_masks.to(self.device)
            lm_labels = de_output.clone().to(self.device)
            
            out=self.model2(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output, 
                            decoder_attention_mask=de_masks, labels=de_output.clone())
            predictions = F.log_softmax(out[1], dim=2)
            loss3 = compute_loss2(predictions, de_output, self.device, criterion)
            epoch_loss+=loss3.item()

            loss3.backward(inputs=list(self.model2.parameters()), retain_graph=True)

            del loss3
      
            # compute Hessian-vector product via finite differences:
            # the perturbation vector is the gradient of the validation loss w.r.t. model2's parameters
            r = 1e-2
            vector = []
            for param in self.model2.parameters():
                if param.grad is not None:
                    vector.append(param.grad.data.to(self.device))
                else:
                    vector.append(torch.ones(1).to(self.device))
            
            R = r / _concat(vector, self.device).norm()

            print(R)
            for p, v in zip(self.model2.parameters(), vector):
                p.data.to(self.device)
                p.data.add_(alpha=R, other=v)
                        
            #calculate loss
            outputs=self.model1(input_ids=en_input, decoder_input_ids=en_input, output_hidden_states=True, return_dict=True)
            predictions = F.log_softmax(outputs.logits, dim=2)
            values, new_labels = torch.max(predictions, 2)
            
            out=self.model2(input_ids=en_input, decoder_inputs_embeds=outputs.decoder_hidden_states[-1], labels=new_labels)
            predictions = F.log_softmax(out[1], dim=2)
            loss2=compute_loss2(predictions, new_labels, self.device, criterion)
            
            grads_p=torch.autograd.grad(loss2, self.model1.parameters(), allow_unused=True, retain_graph=True)

            del loss2
            del predictions
            del out 
            del outputs
            del new_labels
            torch.cuda.empty_cache()
            for p, v in zip(self.model2.parameters(), vector):
                p.data.sub_(alpha=2 * R, other=v)
               
            
            #calculate loss
            outputs=self.model1(input_ids=en_input, decoder_input_ids=en_input, output_hidden_states=True, return_dict=True)
            predictions = F.log_softmax(outputs.logits, dim=2)
            values, new_labels = torch.max(predictions, 2)
            
            out=self.model2(input_ids=en_input, decoder_inputs_embeds=outputs.decoder_hidden_states[-1], labels=new_labels)
            predictions = F.log_softmax(out[1], dim=2)
            loss2=compute_loss2(predictions, new_labels, self.device, criterion)
        
            grads_n = torch.autograd.grad(loss2, self.model1.parameters(), allow_unused=True, retain_graph=True)

            del loss2
            del predictions
            del out 
            del outputs
            del new_labels
            for p, v in zip(self.model2.parameters(), vector):
                p.data.add_(v, alpha=R)  # restore model2's original parameters
            
            del vector
            torch.cuda.empty_cache()
            
            # OOM error occurs here: building (grads_p - grads_n) / (2R) allocates new tensors on the GPU
            vector = []
            for x, y in zip(grads_p, grads_n):
                if x is not None and y is not None:
                    vector.append((x - y).div_(2 * R))
                else:
                    vector.append(torch.ones(1))
            
            del grads_n
            del grads_p
            torch.cuda.empty_cache()
            for p, v in zip(self.model1.parameters(), vector):
                p.data.add_(alpha=R, other=v)
                
            #calculate loss
            out = self.model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output, 
                                decoder_attention_mask=de_masks, labels=lm_labels.clone())
                
            predictions = F.log_softmax(out[1], dim=2)
            loss1=compute_loss1(predictions, de_output, a, self.device, criterion)    

            grads_p=torch.autograd.grad(loss1, a, allow_unused=True, retain_graph=True)

            for p, v in zip(self.model1.parameters(), vector):
                p.data.sub_(v, alpha=2 * R)

            del out
            del predictions
            del loss1
            torch.cuda.empty_cache()
            #calculate loss
            out = self.model1(input_ids=en_input, attention_mask=en_masks, decoder_input_ids=de_output, 
                                decoder_attention_mask=de_masks, labels=lm_labels.clone())
                
            predictions = F.log_softmax(out[1], dim=2)
            loss1=compute_loss1(predictions, de_output, a, self.device, criterion)    

            grads_n=torch.autograd.grad(loss1, a, allow_unused=True, retain_graph=True)

            del out
            del predictions
            del loss1
            torch.cuda.empty_cache()
            for p, v in zip(self.model1.parameters(), vector):
                p.data.add_(v, alpha=R)  # restore model1's original parameters

            A.grad[a_ind:a_ind+self.batch_size]=[(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)][0]
            print(A.grad)
            del grads_p
            del grads_n
            torch.cuda.empty_cache()
            optimizer3.step()
            scheduler3.step()
            a_ind=a_ind+self.batch_size
            A.grad=torch.zeros(len(A), device=self.device)
            if ((i+1)*self.batch_size)% self.config['report_freq'] == 0:
                self.logger.info('loss after %d instances: %f', (i+1)*self.batch_size, epoch_loss)
                self.logger.info('bleu score after %d instances: %f', (i+1)*self.batch_size, calc_bleu(en_input, lm_labels, self.model2, tokenizer))

        self.logger.info('Mean epoch loss for step 3: %f', (epoch_loss / len(valid_dataloader)))
            
        return (epoch_loss / len(valid_dataloader), a_ind)
The three methods are called in this order (a rough sketch of the driver loop follows):

train_model1(...)
train_model2(...)
val_model2(...)
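This is roughly how the class is driven; the setup objects (dataloaders, optimizers, schedulers, A, A_batch, etc.) stand in for the real ones in my script, and num_epochs is just a placeholder:

mt = TranslationModel(device, batch_size, logging, model1_path, model2_path, config)
a_ind = 0
for epoch in range(num_epochs):
    # step 1: train M1 on D1, with its loss weighted via A
    mt.train_model1(A_batch, train_dataloader, optimizer1, tokenizer, criterion, scheduler1)
    # step 2: train M2 on labels that M1 produces for the unlabeled set D2
    mt.train_model2(unlabeled_dataloader, optimizer2, tokenizer, criterion, scheduler2)
    # step 3: update A from M2's loss on the validation split of D1 -- the OOM happens in here
    val_loss, a_ind = mt.val_model2(valid_dataloader, optimizer3, A, A_batch, tokenizer, criterion, scheduler3, a_ind)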

Any help would be appreciated. Thanks.