How to save some memory due to Cuda out of memory

I try to train two DNN jointly, The model is trained and goes to the validation phase after every 5 epochs, the problem is after the 5 epochs it is okay and no problem with memory, but after 10 epochs the model complains about Cuda memory. Any help to solve the memory issue.

class Trainer(BaseTrainer):

def __init__(self, config, resume: bool, model, loss_function, optimizer, train_dataloader, validation_dataloader):
    super(Trainer, self).__init__(config, resume, model, loss_function, optimizer)
    self.train_data_loader = train_dataloader
    self.validation_data_loader = validation_dataloader
    self.model = self.model.double()
    
def _train_epoch(self, epoch):
  
    for i, (mixture, clean, name, label) in enumerate(self.train_data_loader):
        mixture = mixture.to(self.device, dtype=torch.double)
        clean = clean.to(self.device, dtype=torch.double)
        enhanced = self.model(mixture).to(self.device)
        front_loss = self.loss_function(clean, enhanced)
        
        front_loss.backward(retain_graph=True)
        
        torch.cuda.empty_cache()
        model_back.train()
        y = model_back(enhanced.double().to(device2))
        back_loss = backend_loss(y[0], label[0].to(device2))
        print("Iteration %d in epoch%d--> loss = %f" % (i, epoch, back_loss.item()), end='\r')     
        back_loss.backward(retain_graph=True)
        self.optimizer.zero_grad()
        self.optimizer.step()
        torch.cuda.empty_cache()
     dl_len = len(self.train_data_loader)
 

@torch.no_grad()
def _validation_epoch(self, epoch):

        sample_length = self.validation_custom_config["sample_length"]

        stoi_c_n = []  # clean and noisy
        stoi_c_e = []  # clean and enhanced
        stoi_e_n = []
        pesq_c_n = []
        pesq_c_e = []
        pesq_e_n = []
        correct = []

        for i, (mixture, clean, name, label) in enumerate(self.validation_data_loader):
            #assert len(name) == 1, "Only support batch size is 1 in enhancement stage."
            name = name[0]
            padded_length = 0

            mixture = mixture.to(self.device)

            if mixture.size(-1) % sample_length != 0:
                padded_length = sample_length - (mixture.size(-1) % sample_length)
                mixture = torch.cat([mixture, torch.zeros(1, 1, padded_length, device=self.device)], dim=-1)

            assert mixture.size(-1) % sample_length == 0 and mixture.dim() == 3
            mixture_chunks = list(torch.split(mixture, sample_length, dim=-1))

            enhanced_chunks = []
            for chunk in mixture_chunks:
                enhanced_chunks.append(self.model(chunk.double()).detach().cpu())

            enhanced = torch.cat(enhanced_chunks, dim=-1)  # [1, 1, T]
            enhanced = enhanced.to(self.device)
            #print(enhanced)
            if padded_length != 0:
                enhanced = enhanced[:, :, :-padded_length]
                mixture = mixture[:, :, :-padded_length]

            torch.cuda.empty_cache()
            model_back.eval()
            
            y_pred = model_back(enhanced.double().to(self.device))
            pred = torch.argmax(y_pred[0].detach().cpu(), dim=1)
            intent_pred = pred
            correct.append((intent_pred == label[0]).float())
            torch.cuda.empty_cache()
        acc = np.mean(np.hstack(correct))
        intent_acc = acc
        iter_acc = '\n iteration %d epoch %d -->' %(i, epoch)
        print(iter_acc, acc, best_accuracy)
        if intent_acc > best_accuracy:
            improved_accuracy = 'Current accuracy {}, {}'.format(intent_acc, best_accuracy)
            print(improved_accuracy)
            torch.save(model_back.state_dict(), '/home/mnabih/jt/best_model.pkl')
1 Like

@albanD @ptrblck
Do you have any suggestions?