So basically, I wrapped all the code for model training in one class and call it from the main file, which is the same file where I have `from numpy.ma import default_fill_value`:
import argparse
from numpy.ma import default_fill_value # <----- this one
from src.main.generator import TranGen
import os
from utils.utils import str2bool

parser = argparse.ArgumentParser()
# ... argument definitions omitted ...
args = parser.parse_args()

def main():
    # create the checkpoint dir if it does not exist
    if not os.path.exists(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)

    model = TranGen(args=args)
    model.start_training()

if __name__ == "__main__":
    main()
My wrapper class looks something like this:
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel
from torch import optim
import numpy as np
from hparams import hparams
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import time
from src.dataset.generator import Dataset
from torch.utils import data as data_utils
from utils.front import frontalize_landmarks
from src.models.generator import LstmGen as Generator
#from src.models.cus_gen import Generator
#from src.models.fl_syncnet import SyncNet_fl
from src.models.cus_sync import ModSyncNet as SyncNet
from torch.utils.tensorboard import SummaryWriter
from utils.func_utils import BatchLipNorm
from utils.plot import plot_compareLip, plot_visLip, plot_comp, plot_seqlip_comp
from utils.wav2lip import load_checkpoint, save_checkpoint
from utils.utils import save_logs, load_logs, norm_lip2d
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)
class TranGen():
    """Wrapper that builds the datasets, SyncNet, and lip generator,
    and runs the training loop."""
    def __init__(self, args):
        # arguments and hyperparameters
        self.save_name = args.save_name
        self.checkpoint_dir = args.checkpoint_dir
        self.checkpoint_path = args.checkpoint_path
        self.ckpt_syncnet_path = args.checkpoint_syncnet_path
        self.batch_size = hparams.batch_size
        self.apply_disc = args.apply_disc
        self.global_epoch = 0
        self.nepochs = hparams.nepochs
        self.lr = 0.0001

        # create the checkpoint dir if it does not exist
        if not os.path.exists(self.checkpoint_dir):
            os.mkdir(self.checkpoint_dir)

        # TensorBoard
        self.writer = SummaryWriter("../tensorboard/{}".format(self.save_name))
"""<---------------------------Dataset -------------------------------------->"""
self.train_dataset = Dataset(split='train', args=args)
self.vali_dataset = Dataset(split='val', args=args)
self.train_loader = data_utils.DataLoader(self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=hparams.num_workers)
self.vali_loader = data_utils.DataLoader(self.vali_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=hparams.num_workers)
""" <------------------------------SyncNet ------------------------------------->"""
print("Loading SyncNet .......")
# load Syncnet
self.syncnet = SyncNet().to(device=device)
# load Syncnet checkpoint
self.syncnet = load_checkpoint(path=self.ckpt_syncnet_path,
model=self.syncnet,
optimizer=None,
use_cuda=use_cuda,
reset_optimizer=True,
pretrain=True)
self.syncnet.to(device=device)
self.syncnet.eval()
# frontalize weight
self.front_weight = np.load('./checkpoints/front/frontalization_weights.npy')
print("Finish loading Syncnet !!")
"""<----------------------------Generator------------------------------------------->"""
# load lip generator
self.generator = Generator().to(device=device)
self.optimizer = optim.Adam([params for params in self.generator.parameters() if params.requires_grad], lr=self.lr)
# load checkpoint if the path is given
self.continue_ckpt = False
if self.checkpoint_path is not None:
self.continue_ckpt =True
self.generator, self.optimizer, self.global_epoch = load_checkpoint(path = self.checkpoint_path,
model = self.generator,
optimizer = self.optimizer,
use_cuda = use_cuda,
reset_optimizer = False,
pretain=False
)
print("Load generator checkpoint")
if self.continue_ckpt:
self.train_loss , self.vali_loss = load_logs(model_name="generator", savename="{}.csv".format(self.save_name),epoch=self.global_epoch, type_model='generator')
self.global_epoch +=1
else:
print("Not continue form Checkpoint")
self.train_loss = np.array([])
self.vali_loss = np.array([])
"""<-----------------------Parallel Trainining-------------------------------->"""
# If GPU detect more that one then train model in parallel
if torch.cuda.device_count() > 1:
self.generator = DataParallel(self.generator)
self.batch_size = self.batch_size * torch.cuda.device_count()
print("Training or Testing model with {} GPU " .format(torch.cuda.device_count()))
self.generator.to(device)
"""<----------List of reconstruction loss-------------------------------------->"""
# Binary cross entrophy loss
self.bce_loss = nn.BCELoss()
# Mean Square Error loss
self.mse_loss = nn.MSELoss()
# L1 loss
self.l1_loss = nn.L1Loss()
# L1 smooth loss
self.l1_smooth = nn.SmoothL1Loss()
# chosen reconstruction loss
self.recon_loss = self.l1_smooth
# normalize lip
self.blnorm =BatchLipNorm()
self.recon_coeff = 0.5
self.sync_coeff = 0.5
    def __train_model__(self):
        running_sync_loss = 0.
        running_recon_loss = 0.
        running_lapla_loss = 0.
        running_loss = 0.
        iter_inbatch = 0

        prog_bar = tqdm(self.train_loader)
        for (con_fl, seq_mels, mel, gt_fl) in prog_bar:
            self.optimizer.zero_grad()
            self.generator.train()

            # split the landmarks into lip and face parts
            con_lip = con_fl[:, :, 48:, :].to(device)
            con_face = con_fl[:, :, :48, :].to(device)
            gt_lip = gt_fl[:, :, 48:, :].to(device)
            gt_face = gt_fl[:, :, :48, :].to(device)
            seq_mels = seq_mels.to(device)
            mel = mel.to(device)

            gen_lip, _ = self.generator(seq_mels, con_lip)
            gt_lip = gt_lip.reshape(gt_lip.size(0), -1)
            gen_lip = gen_lip.reshape(gen_lip.size(0), -1)

            recon_loss = self.recon_loss(gen_lip, gt_lip)
            # the sync loss only kicks in after epoch self.apply_disc
            sync_loss = self.__get_sync_loss__(mel, gen_lip, gt_face) if self.global_epoch >= self.apply_disc else torch.zeros(1)
            loss = (self.recon_coeff * recon_loss) + (self.sync_coeff * sync_loss) if self.global_epoch >= self.apply_disc else recon_loss

            loss.backward()
            self.optimizer.step()

            # track the losses before they are multiplied by their coefficients
            running_recon_loss += recon_loss.item()
            running_sync_loss += sync_loss.item()
            running_loss += loss.item()
            iter_inbatch += 1

            prog_bar.set_description("TRAIN Epochs: {} , Loss : {:.3f} , Recon : {:.3f}, Sync : {:.3f}".format(self.global_epoch,
                                                                                                               running_loss / iter_inbatch,
                                                                                                               running_recon_loss / iter_inbatch,
                                                                                                               running_sync_loss / iter_inbatch))

        avg_loss = running_loss / iter_inbatch
        avg_recon_loss = running_recon_loss / iter_inbatch
        avg_sync_loss = running_sync_loss / iter_inbatch
        return avg_loss, avg_recon_loss, avg_sync_loss
    def __training_stage__(self):
        while self.global_epoch < self.nepochs:
            cur_train_loss, cur_train_recon_loss, cur_train_sync_loss = self.__train_model__()
            cur_vali_loss, cur_vali_recon_loss, cur_vali_sync_loss = self.__eval_model__()

            com_fig = self.__compare_lip__()
            com_seq_fig = self.__vis_seq_result__()

            self.__update_logs__(cur_train_loss, cur_vali_loss, cur_train_recon_loss, cur_vali_recon_loss, cur_train_sync_loss, cur_vali_sync_loss, com_fig, com_seq_fig)

            # save a checkpoint
            save_checkpoint(self.generator, self.optimizer, self.checkpoint_dir, self.global_epoch)
            self.global_epoch += 1
    def start_training(self):
        print("Save name : {}".format(self.save_name))
        print("Using CUDA : {}".format(use_cuda))
        if use_cuda:
            print("Using {} GPU(s)".format(torch.cuda.device_count()))
        print("Training dataset : {}".format(len(self.train_dataset)))
        print("Validation dataset : {}".format(len(self.vali_dataset)))

        self.__vis_comp_graph__()
        print("Start training SyncNet")
        self.__training_stage__()
        print("Finished training SyncNet")
        self.writer.close()
So I tried to measure the time of one batch iteration of training, both with the numpy import that causes the problem and without it. I timed from optimizer.zero_grad() to optimizer.step():
for (con_fl, seq_mels, mel, gt_fl) in prog_bar:
    start = time.time()

    self.optimizer.zero_grad()
    self.generator.train()

    con_lip = con_fl[:, :, 48:, :].to(device)
    con_face = con_fl[:, :, :48, :].to(device)
    gt_lip = gt_fl[:, :, 48:, :].to(device)
    gt_face = gt_fl[:, :, :48, :].to(device)
    seq_mels = seq_mels.to(device)
    mel = mel.to(device)

    gen_lip, _ = self.generator(seq_mels, con_lip)
    gt_lip = gt_lip.reshape(gt_lip.size(0), -1)
    gen_lip = gen_lip.reshape(gen_lip.size(0), -1)

    recon_loss = self.recon_loss(gen_lip, gt_lip)
    sync_loss = self.__get_sync_loss__(mel, gen_lip, gt_face) if self.global_epoch >= self.apply_disc else torch.zeros(1)
    loss = (self.recon_coeff * recon_loss) + (self.sync_coeff * sync_loss) if self.global_epoch >= self.apply_disc else recon_loss

    loss.backward()
    self.optimizer.step()

    end = time.time()
    print(end - start)
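One caveat about my measurement: since CUDA kernels run asynchronously, time.time() around optimizer.step() may not capture the full cost of a step. A minimal sketch of a synchronized timer (the timed_step helper here is just for illustration, not part of my actual code):

import time
import torch

def timed_step(step_fn, use_cuda):
    """Run one training step and return its wall-clock duration.
    Synchronizing before and after ensures that asynchronous CUDA
    kernels are fully included in the measurement."""
    if use_cuda:
        torch.cuda.synchronize()
    start = time.time()
    step_fn()  # e.g. zero_grad -> forward -> backward -> optimizer.step()
    if use_cuda:
        torch.cuda.synchronize()
    return time.time() - start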
With the numpy import, the console shows that one batch iteration takes 0.2-4 seconds, and the tqdm progress bar estimates the whole epoch will finish in about 10 minutes.
But if I remove that numpy import from the first file, each batch takes only 0.03-0.57 sec, which is almost ten times faster than before, and one epoch takes only about 1 minute, which is the normal time it is supposed to take.
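To check whether the import itself changes any process-level threading state (only a guess on my part about the cause, but a common source of this kind of slowdown), one could snapshot the relevant settings before and after importing:

import os
import torch

def snapshot():
    # number of CPU cores this process may run on (Linux only),
    # plus the size of torch's intra-op thread pool
    return {"affinity": len(os.sched_getaffinity(0)),
            "torch_threads": torch.get_num_threads()}

before = snapshot()
from numpy.ma import default_fill_value  # the suspect import
after = snapshot()
print(before, after)  # any difference implicates the import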