I try to train two DNN jointly, The model is trained and goes to the validation phase after every 5 epochs, the problem is after the 5 epochs it is okay and no problem with memory, but after 10 epochs the model complains about Cuda memory. Any help to solve the memory issue.
class Trainer(BaseTrainer):
def __init__(self, config, resume: bool, model, loss_function, optimizer, train_dataloader, validation_dataloader):
super(Trainer, self).__init__(config, resume, model, loss_function, optimizer)
self.train_data_loader = train_dataloader
self.validation_data_loader = validation_dataloader
self.model = self.model.double()
def _train_epoch(self, epoch):
for i, (mixture, clean, name, label) in enumerate(self.train_data_loader):
mixture = mixture.to(self.device, dtype=torch.double)
clean = clean.to(self.device, dtype=torch.double)
enhanced = self.model(mixture).to(self.device)
front_loss = self.loss_function(clean, enhanced)
front_loss.backward(retain_graph=True)
torch.cuda.empty_cache()
model_back.train()
y = model_back(enhanced.double().to(device2))
back_loss = backend_loss(y[0], label[0].to(device2))
print("Iteration %d in epoch%d--> loss = %f" % (i, epoch, back_loss.item()), end='\r')
back_loss.backward(retain_graph=True)
self.optimizer.zero_grad()
self.optimizer.step()
torch.cuda.empty_cache()
dl_len = len(self.train_data_loader)
@torch.no_grad()
def _validation_epoch(self, epoch):
sample_length = self.validation_custom_config["sample_length"]
stoi_c_n = [] # clean and noisy
stoi_c_e = [] # clean and enhanced
stoi_e_n = []
pesq_c_n = []
pesq_c_e = []
pesq_e_n = []
correct = []
for i, (mixture, clean, name, label) in enumerate(self.validation_data_loader):
#assert len(name) == 1, "Only support batch size is 1 in enhancement stage."
name = name[0]
padded_length = 0
mixture = mixture.to(self.device)
if mixture.size(-1) % sample_length != 0:
padded_length = sample_length - (mixture.size(-1) % sample_length)
mixture = torch.cat([mixture, torch.zeros(1, 1, padded_length, device=self.device)], dim=-1)
assert mixture.size(-1) % sample_length == 0 and mixture.dim() == 3
mixture_chunks = list(torch.split(mixture, sample_length, dim=-1))
enhanced_chunks = []
for chunk in mixture_chunks:
enhanced_chunks.append(self.model(chunk.double()).detach().cpu())
enhanced = torch.cat(enhanced_chunks, dim=-1) # [1, 1, T]
enhanced = enhanced.to(self.device)
#print(enhanced)
if padded_length != 0:
enhanced = enhanced[:, :, :-padded_length]
mixture = mixture[:, :, :-padded_length]
torch.cuda.empty_cache()
model_back.eval()
y_pred = model_back(enhanced.double().to(self.device))
pred = torch.argmax(y_pred[0].detach().cpu(), dim=1)
intent_pred = pred
correct.append((intent_pred == label[0]).float())
torch.cuda.empty_cache()
acc = np.mean(np.hstack(correct))
intent_acc = acc
iter_acc = '\n iteration %d epoch %d -->' %(i, epoch)
print(iter_acc, acc, best_accuracy)
if intent_acc > best_accuracy:
improved_accuracy = 'Current accuracy {}, {}'.format(intent_acc, best_accuracy)
print(improved_accuracy)
torch.save(model_back.state_dict(), '/home/mnabih/jt/best_model.pkl')