I am trying to resume training of a StackGAN-v2 model by saving and reloading the neural-network and optimizer parameters. However, after I reload the saved parameters, the generator no longer produces good-quality images.
# ################## Shared functions ###################
def weights_init(m):
    """Initialize a module's parameters in place, dispatched by class name.

    Intended for use with ``model.apply(weights_init)``:
      * Conv* layers: orthogonal weight init (gain 1.0); bias untouched.
      * BatchNorm* layers: weight ~ N(1.0, 0.02), bias zeroed.
      * Linear layers: orthogonal weight init (gain 1.0), bias zeroed if present.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # nn.init.orthogonal is deprecated (and removed in recent PyTorch);
        # the supported in-place variant is orthogonal_.
        nn.init.orthogonal_(m.weight.data, 1.0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        nn.init.orthogonal_(m.weight.data, 1.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
def load_params(model, new_param):
    """Overwrite *model*'s parameters, in order, with the tensors in *new_param*."""
    for param, source in zip(model.parameters(), new_param):
        param.data.copy_(source)
def copy_G_params(model):
    """Snapshot the model's parameters as a list of independent tensor copies.

    Later mutation of the model (or of the snapshot) does not affect the other;
    used to maintain the EMA copy of the generator's weights.
    """
    return [param.data.clone() for param in model.parameters()]
def compute_inception_score(predictions, num_splits=1):
    """Inception Score of an (N, C) array of softmax predictions.

    Splits the rows into *num_splits* contiguous chunks, computes
    exp(mean KL(p(y|x) || p(y))) per chunk, and returns the (mean, std)
    of those per-chunk scores.
    """
    total = predictions.shape[0]
    split_scores = []
    for split in range(num_splits):
        lo = split * total // num_splits
        hi = (split + 1) * total // num_splits
        chunk = predictions[lo:hi, :]
        # Marginal class distribution p(y) over this chunk.
        marginal = np.expand_dims(np.mean(chunk, 0), 0)
        kl_per_row = np.sum(chunk * (np.log(chunk) - np.log(marginal)), 1)
        split_scores.append(np.exp(np.mean(kl_per_row)))
    return np.mean(split_scores), np.std(split_scores)
def negative_log_posterior_probability(predictions, num_splits=1):
    """Mean and std, over *num_splits* chunks, of -log(max class probability).

    *predictions* is an (N, C) array of softmax outputs; lower values mean
    the classifier is more confident on the generated samples.
    """
    total = predictions.shape[0]
    split_scores = []
    for split in range(num_splits):
        lo = split * total // num_splits
        hi = (split + 1) * total // num_splits
        chunk = predictions[lo:hi, :]
        split_scores.append(np.mean(-1. * np.log(np.max(chunk, 1))))
    return np.mean(split_scores), np.std(split_scores)
def load_network(gpus):
    """Build the generator, both discriminators and the Inception model.

    All networks are freshly initialised with weights_init, wrapped in
    DataParallel over *gpus*, and moved to the GPU. Returns
    (netG, netsD, number of Ds, inception_model, step count = 0).
    """
    netG = G_NET()
    netG.apply(weights_init)
    netG = torch.nn.DataParallel(netG, device_ids=gpus)
    print(netG)

    netsD = [D_NET64(), D_NET128()]
    for idx, netD in enumerate(netsD):
        netD.apply(weights_init)
        netsD[idx] = torch.nn.DataParallel(netD, device_ids=gpus)
    print('# of netsD', len(netsD))

    count = 0
    inception_model = INCEPTION_V3()

    # CUDA is assumed (the cfg.CUDA guard was commented out upstream).
    netG.cuda()
    for netD in netsD:
        netD.cuda()
    inception_model = inception_model.cuda()
    inception_model.eval()
    return netG, netsD, len(netsD), inception_model, count
def define_optimizers(netG, netsD):
    """Create Adam optimizers (lr=LEARNING_RATE, betas=(0.5, 0.999)).

    Returns (optimizer for netG, list of one optimizer per discriminator).
    """
    optimizersD = [
        optim.Adam(netD.parameters(),
                   lr=LEARNING_RATE,
                   betas=(0.5, 0.999))
        for netD in netsD
    ]
    optimizerG = optim.Adam(netG.parameters(),
                            lr=LEARNING_RATE,
                            betas=(0.5, 0.999))
    return optimizerG, optimizersD
def save_model(netG, avg_param_G, netsD, epoch, model_dir):
    """Checkpoint the EMA generator weights and each discriminator.

    The generator checkpoint ('netG_<epoch>.pth') holds the *averaged*
    (EMA) weights from avg_param_G, which is what should be used for
    sampling. Each discriminator is saved to 'netD<i>.pth' (overwritten
    each call, as upstream does).

    Bug fix: the original loaded avg_param_G into netG and never restored
    the live training weights, so after every snapshot training silently
    continued from the EMA weights — and any state_dict saved from netG
    elsewhere no longer matched the generator actually being trained.
    We now back up the live weights and restore them after saving.
    """
    backup_para = copy_G_params(netG)
    load_params(netG, avg_param_G)
    torch.save(
        netG.state_dict(),
        '%s/netG_%d.pth' % (model_dir, epoch))
    # Restore the raw (non-averaged) weights so training is unaffected.
    load_params(netG, backup_para)
    for i in range(len(netsD)):
        netD = netsD[i]
        torch.save(
            netD.state_dict(),
            '%s/netD%d.pth' % (model_dir, i))
    print('Save G/Ds models.')
def save_img_results(imgs_tcpu, fake_imgs, num_imgs,
                     count, image_dir):
    """Write a grid of real images and one grid per generator stage to disk.

    Saves the highest-resolution real batch once ('real_samples.png') and
    the first VIS_COUNT fakes of each of the *num_imgs* stages, tagged with
    the global step *count*.
    """
    num = VIS_COUNT
    # vutils.save_image itself rescales the real images into [0, 1].
    real_img = imgs_tcpu[-1][0:num]
    vutils.save_image(
        real_img, '%s/real_samples.png' % (image_dir),
        normalize=True)
    real_img_set = np.transpose(vutils.make_grid(real_img).numpy(), (1, 2, 0))
    real_img_set = (real_img_set * 255).astype(np.uint8)

    for stage in range(num_imgs):
        fake_img = fake_imgs[stage][0:num]
        # Generator output is still in [-1, 1] at this point.
        vutils.save_image(
            fake_img.data, '%s/count_%09d_fake_samples%d.png' %
            (image_dir, count, stage), normalize=True)
        fake_img_set = np.transpose(
            vutils.make_grid(fake_img.data).cpu().numpy(), (1, 2, 0))
        fake_img_set = ((fake_img_set + 1) * 255 / 2).astype(np.uint8)
# ################# Text to image task############################ #
class condGANTrainer(object):
def __init__(self, output_dir, data_loader, imsize):
    """Prepare output directories, select the GPU and cache hyper-parameters.

    *imsize* is accepted for interface compatibility but not used here.
    """
    self.model_dir = os.path.join(output_dir, 'Model')
    self.image_dir = os.path.join(output_dir, 'Image')
    self.log_dir = os.path.join(output_dir, 'Log')
    for directory in (self.model_dir, self.image_dir, self.log_dir):
        mkdir_p(directory)

    # Single-GPU setup (the cfg.GPU_ID parsing was commented out upstream).
    self.gpus = [0]
    self.num_gpus = len(self.gpus)
    torch.cuda.set_device(self.gpus[0])
    cudnn.benchmark = True

    self.batch_size = BATCH_SIZE
    self.max_epoch = EPOCHS
    self.snapshot_interval = SNAPSHOT_INTERVAL
    self.data_loader = data_loader
    self.num_batches = len(self.data_loader)
def prepare_data(self, data):
    """Unpack one loader batch and move its tensors to the GPU.

    Returns (cpu images, real images per scale, mismatched images per
    scale, text embedding); the last three are CUDA Variables.
    """
    imgs, w_imgs, t_embedding, _ = data
    vembedding = Variable(t_embedding).cuda()
    real_vimgs = [Variable(imgs[i]).cuda() for i in range(self.num_Ds)]
    wrong_vimgs = [Variable(w_imgs[i]).cuda() for i in range(self.num_Ds)]
    return imgs, real_vimgs, wrong_vimgs, vembedding
def train_Dnet(self, idx, count):
    """Run one optimisation step for discriminator *idx* and return its loss.

    Uses the batch cached on self (real_imgs / wrong_imgs / fake_imgs / mu)
    by the caller. The conditioning vector mu is detached so the D step
    sends no gradient into the conditioning branch.
    """
    flag = count % 100  # kept for parity with upstream logging code
    batch_size = self.real_imgs[0].size(0)
    criterion, mu = self.criterion, self.mu
    netD, optD = self.netsD[idx], self.optimizersD[idx]
    real_imgs = self.real_imgs[idx]
    wrong_imgs = self.wrong_imgs[idx]
    fake_imgs = self.fake_imgs[idx]

    netD.zero_grad()
    real_labels = self.real_labels[:batch_size]
    fake_labels = self.fake_labels[:batch_size]

    # Score real, mismatched-caption and generated images.
    real_logits = netD(real_imgs, mu.detach())
    wrong_logits = netD(wrong_imgs, mu.detach())
    fake_logits = netD(fake_imgs.detach(), mu.detach())

    # logits[0] is the text-conditional head.
    errD_real = criterion(real_logits[0], real_labels)
    errD_wrong = criterion(wrong_logits[0], fake_labels)
    errD_fake = criterion(fake_logits[0], fake_labels)
    if len(real_logits) > 1 and UNCOND_LOSS_COEFF > 0:
        # logits[1] is the unconditional (image-only) head; note that
        # "wrong" images are real images, so they count as real here.
        errD_real = errD_real + UNCOND_LOSS_COEFF * \
            criterion(real_logits[1], real_labels)
        errD_wrong = errD_wrong + UNCOND_LOSS_COEFF * \
            criterion(wrong_logits[1], real_labels)
        errD_fake = errD_fake + UNCOND_LOSS_COEFF * \
            criterion(fake_logits[1], fake_labels)
        errD = errD_real + errD_wrong + errD_fake
    else:
        errD = errD_real + 0.5 * (errD_wrong + errD_fake)

    errD.backward()
    optD.step()
    return errD
def train_Gnet(self, count):
    """Run one generator step against all discriminators.

    Accumulates the (conditional + optional unconditional) GAN losses over
    every scale plus the KL regulariser on the conditioning distribution.
    Returns (kl_loss, total generator loss).
    """
    self.netG.zero_grad()
    errG_total = 0
    flag = count % 100  # kept for parity with upstream logging code
    batch_size = self.real_imgs[0].size(0)
    criterion, mu, logvar = self.criterion, self.mu, self.logvar
    real_labels = self.real_labels[:batch_size]

    for stage in range(self.num_Ds):
        outputs = self.netsD[stage](self.fake_imgs[stage], mu)
        # G wants every fake judged real.
        errG = criterion(outputs[0], real_labels)
        if len(outputs) > 1 and UNCOND_LOSS_COEFF > 0:
            errG = errG + UNCOND_LOSS_COEFF * \
                criterion(outputs[1], real_labels)
        errG_total = errG_total + errG

    kl_loss = KL_loss(mu, logvar) * KL_LOSS_COEFF
    errG_total = errG_total + kl_loss
    errG_total.backward()
    self.optimizerG.step()
    return kl_loss, errG_total
def train(self):
    """Main training loop: build networks, restore a checkpoint, then
    alternate D and G updates while tracking an EMA copy of G's weights.

    NOTE(review): this method appears truncated in the visible source
    (start_t is set but never used), so per-epoch logging presumably
    follows below.
    """
    self.netG, self.netsD, self.num_Ds,\
        self.inception_model, start_count = load_network(self.gpus)
    avg_param_G = copy_G_params(self.netG)
    ############## LOADING MODEL ################
    # NOTE(review): avg_param_G is NOT checkpointed. After this reload it
    # is re-seeded from the raw weights in netG.pt, so the EMA history is
    # lost — and since sampling quality comes from the EMA weights, this
    # is the likely cause of degraded images after resuming. Save and
    # restore avg_param_G alongside the state dicts.
    self.netG.load_state_dict(torch.load('netG.pt'))
    self.netsD[0].load_state_dict(torch.load('netsD[0].pt'))
    self.netsD[1].load_state_dict(torch.load('netsD[1].pt'))
    self.optimizerG, self.optimizersD = \
        define_optimizers(self.netG, self.netsD)
    ############ LOADING OPTIMIZER ##############
    self.optimizerG.load_state_dict(torch.load('netG_optim.pt'))
    self.optimizersD[0].load_state_dict(torch.load('netsD[0]_optim.pt'))
    self.optimizersD[1].load_state_dict(torch.load('netsD[1]_optim.pt'))
    self.criterion = nn.BCELoss()
    # Label tensors are sized for a full batch and sliced per step.
    self.real_labels = \
        Variable(torch.FloatTensor(self.batch_size).fill_(1))
    self.fake_labels = \
        Variable(torch.FloatTensor(self.batch_size).fill_(0))
    self.gradient_one = torch.FloatTensor([1.0])
    self.gradient_half = torch.FloatTensor([0.5])
    nz = Z_DIM
    noise = Variable(torch.FloatTensor(self.batch_size, nz))
    # Fixed noise gives comparable sample grids across snapshots.
    fixed_noise = \
        Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))
    #if cfg.CUDA:
    self.criterion.cuda()
    self.real_labels = self.real_labels.cuda()
    self.fake_labels = self.fake_labels.cuda()
    self.gradient_one = self.gradient_one.cuda()
    self.gradient_half = self.gradient_half.cuda()
    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
    predictions = []
    count = start_count
    # NOTE(review): start_count comes from load_network and is always 0
    # here, so resuming restarts the epoch/step numbering — confirm this
    # is intended when retraining.
    start_epoch = start_count // (self.num_batches)
    for epoch in range(start_epoch, self.max_epoch):
        start_t = time.time()
        for step, data in enumerate(self.data_loader, 0):
            self.imgs_tcpu, self.real_imgs, self.wrong_imgs, \
                self.txt_embedding = self.prepare_data(data)
            # Fresh noise every step; generator caches fake_imgs/mu/logvar
            # on self for train_Dnet / train_Gnet.
            noise.data.normal_(0, 1)
            self.fake_imgs, self.mu, self.logvar = \
                self.netG(noise, self.txt_embedding)
            errD_total = 0
            for i in range(self.num_Ds):
                errD = self.train_Dnet(i, count)
                errD_total += errD
            kl_loss, errG_total = self.train_Gnet(count)
            # Exponential moving average of G's parameters (decay 0.999).
            # NOTE(review): Tensor.add_(scalar, tensor) is the deprecated
            # positional form; newer torch expects add_(tensor, alpha=...).
            for p, avg_p in zip(self.netG.parameters(), avg_param_G):
                avg_p.mul_(0.999).add_(0.001, p.data)
            # for inception score
            pred = self.inception_model(self.fake_imgs[-1].detach())
            predictions.append(pred.data.cpu().numpy())
            count = count + 1
            if count % SNAPSHOT_INTERVAL == 0:
                ################# SAVING ####################
                # NOTE(review): netG.pt stores the *raw* (non-EMA) weights.
                torch.save(self.netG.state_dict(),'netG.pt')
                torch.save(self.optimizerG.state_dict(),'netG_optim.pt')
                torch.save(self.netsD[0].state_dict(),'netsD[0].pt')
                torch.save(self.optimizersD[0].state_dict(),'netsD[0]_optim.pt')
                torch.save(self.netsD[1].state_dict(),'netsD[1].pt')
                torch.save(self.optimizersD[1].state_dict(),'netsD[1]_optim.pt')
                # NOTE(review): save_model copies avg_param_G INTO netG and
                # (as written above at L77-87) does not restore the raw
                # weights, so training continues from the EMA weights from
                # here on and netG.pt saved above no longer matches the
                # live generator. Restore a backup after this call (or fix
                # save_model to restore internally).
                save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
                # Save images
                # Temporarily swap in the EMA weights for sampling.
                backup_para = copy_G_params(self.netG)
                load_params(self.netG, avg_param_G)
                #
                self.fake_imgs, _, _ = \
                    self.netG(fixed_noise, self.txt_embedding)
                save_img_results(self.imgs_tcpu, self.fake_imgs, self.num_Ds,
                                 count, self.image_dir)
                #
                load_params(self.netG, backup_para)
                # Compute inception score
                if len(predictions) > 500:
                    predictions = np.concatenate(predictions, 0)
                    mean, std = compute_inception_score(predictions, 10)
                    mean_nlpp, std_nlpp = \
                        negative_log_posterior_probability(predictions, 10)
                    predictions = []