I implemented a GAN and, because I need to train it for 500 epochs, I save a checkpoint of both models every 10 epochs:
import os
import torch

# Save both networks and both optimizers so training can be resumed later
torch.save({
    'epoch': epoch + 1,
    'gen_state_dict': gen.state_dict(),
    'disc_state_dict': disc.state_dict(),
    'gen_optim': opt_gen.state_dict(),
    'disc_optim': opt_disc.state_dict(),
}, os.path.join("", 'gan_epoch-{}.pt'.format(epoch + 1)))
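The save snippet writes a file every time it runs; to write a checkpoint only every 10 epochs, a guard on the epoch counter is one option. This is a minimal sketch, assuming a hypothetical helper `save_checkpoint` and tiny `nn.Linear` stand-ins for the real `Generator`/`Discriminator` (not shown in this post); the learning rate is also a placeholder.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(epoch, gen, disc, opt_gen, opt_disc, ckpt_dir, every=10):
    """Hypothetical helper: save models + optimizers every `every` epochs.

    Returns the written path, or None when this epoch is skipped.
    """
    if (epoch + 1) % every != 0:
        return None
    path = os.path.join(ckpt_dir, "gan_epoch-{}.pt".format(epoch + 1))
    torch.save({
        "epoch": epoch + 1,
        "gen_state_dict": gen.state_dict(),
        "disc_state_dict": disc.state_dict(),
        "gen_optim": opt_gen.state_dict(),
        "disc_optim": opt_disc.state_dict(),
    }, path)
    return path


# Tiny stand-in models (assumption: the real architectures are not shown).
gen = nn.Linear(4, 4)
disc = nn.Linear(4, 1)
opt_gen = torch.optim.Adam(gen.parameters(), lr=2e-4)   # placeholder lr
opt_disc = torch.optim.Adam(disc.parameters(), lr=2e-4)  # placeholder lr

ckpt_dir = tempfile.mkdtemp()
# Simulate 25 epochs; only epochs 10 and 20 (1-based) produce a file.
saved = [save_checkpoint(e, gen, disc, opt_gen, opt_disc, ckpt_dir) for e in range(25)]
saved_paths = [p for p in saved if p is not None]
```

With 25 simulated epochs this writes `gan_epoch-10.pt` and `gan_epoch-20.pt` into a temporary directory.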
To resume training, I load a checkpoint like this:
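disc = Discriminator(in_channels=3).to(device)
gen = Generator(in_channels=3).to(device)
# opt_disc and opt_gen must be created from these new disc/gen parameters
# (e.g. optim.Adam(disc.parameters(), ...)) before their state is loaded below
checkpoint = torch.load("/content/drive/MyDrive/Epochs/gan_epoch-20.pt", map_location=device)
gen.load_state_dict(checkpoint['gen_state_dict'])
disc.load_state_dict(checkpoint['disc_state_dict'])
opt_disc.load_state_dict(checkpoint['disc_optim'])
opt_gen.load_state_dict(checkpoint['gen_optim'])
start_epoch = checkpoint['epoch']  # continue the training loop from this epoch
disc.train()
gen.train()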
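One thing worth checking when resuming: an optimizer only updates the tensors it was constructed with. If `opt_gen`/`opt_disc` were created from an earlier `gen`/`disc` instance (for example in a previous notebook cell), `load_state_dict` still succeeds, but `step()` updates the old models rather than the freshly loaded ones. The sketch below round-trips a checkpoint and checks that the weights match and that each optimizer points at the current model's parameters. It assumes tiny `nn.Linear` stand-ins, Adam with placeholder hyperparameters, and a hypothetical helper `optimizer_tracks`.

```python
import os
import tempfile

import torch
import torch.nn as nn

device = "cpu"


def make_models_and_optims():
    # Hypothetical stand-ins for Generator/Discriminator (real ones not shown).
    gen = nn.Linear(4, 4).to(device)
    disc = nn.Linear(4, 1).to(device)
    # Optimizers are built from *these* instances' parameters.
    opt_gen = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_disc = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
    return gen, disc, opt_gen, opt_disc


def optimizer_tracks(model, optimizer):
    """Hypothetical check: does the optimizer hold exactly this model's tensors?"""
    opt_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return opt_ids == {id(p) for p in model.parameters()}


# --- first run: one training step so Adam has state, then save ---
gen, disc, opt_gen, opt_disc = make_models_and_optims()
loss = disc(gen(torch.randn(8, 4))).mean()
opt_gen.zero_grad()
opt_disc.zero_grad()
loss.backward()
opt_gen.step()
opt_disc.step()

path = os.path.join(tempfile.mkdtemp(), "gan_epoch-20.pt")
torch.save({
    "epoch": 20,
    "gen_state_dict": gen.state_dict(),
    "disc_state_dict": disc.state_dict(),
    "gen_optim": opt_gen.state_dict(),
    "disc_optim": opt_disc.state_dict(),
}, path)

# --- resumed run: new models, optimizers built from the NEW models ---
gen2, disc2, opt_gen2, opt_disc2 = make_models_and_optims()
checkpoint = torch.load(path, map_location=device)
gen2.load_state_dict(checkpoint["gen_state_dict"])
disc2.load_state_dict(checkpoint["disc_state_dict"])
opt_gen2.load_state_dict(checkpoint["gen_optim"])
opt_disc2.load_state_dict(checkpoint["disc_optim"])
start_epoch = checkpoint["epoch"]

# Loaded weights should equal the saved ones, tensor by tensor.
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(gen.state_dict().values(), gen2.state_dict().values())
)
```

Here `optimizer_tracks(gen2, opt_gen2)` is true, while `optimizer_tracks(gen2, opt_gen)` is false: the stale `opt_gen` would keep updating the old `gen`.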
The code runs without errors, but I am not sure the results are correct. I have noticed that training became faster after resuming: before saving the models, one epoch took 20 minutes, and now it takes only 8 minutes. The discriminator loss also jumps from 0.xxx to 7.xxx. Is this normal?