Hey everyone!
I’m training a GAN so I can run some experiments on the discriminator. To that end, I save a checkpoint of the generator and the discriminator every epoch using
# do checkpointing
torch.save(generator.state_dict(), '{}/generator_epoch_{}.pth'.format(checkpointdir, log_epoch))
torch.save(discriminator.state_dict(), '{}/discriminator_epoch_{}.pth'.format(checkpointdir, log_epoch))
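(For context, a pattern I’ve also seen recommended is one checkpoint dict per epoch holding both state_dicts plus the epoch number, so everything stays in sync. Just a sketch with stand-in models and my own file name, not my actual nets:)

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in models just for the sketch; substitute the real generator/discriminator.
generator = nn.Linear(10, 10)
discriminator = nn.Linear(10, 1)

checkpointdir = tempfile.mkdtemp()
log_epoch = 3

# One file per epoch holding everything needed to resume or analyze later.
path = os.path.join(checkpointdir, 'checkpoint_epoch_{}.pth'.format(log_epoch))
torch.save({
    'epoch': log_epoch,
    'generator': generator.state_dict(),
    'discriminator': discriminator.state_dict(),
}, path)

# Loading side: both nets come back from the same file.
ckpt = torch.load(path)
generator.load_state_dict(ckpt['generator'])
discriminator.load_state_dict(ckpt['discriminator'])
print(ckpt['epoch'])  # -> 3
```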
as recommended. Now, after training, I load the discriminator to run my experiments with it, but no matter which epoch’s checkpoint I use, it fails to recognize any image I feed it as real. I use the same network definition as the one I trained with (otherwise load_state_dict would raise an error anyway, wouldn’t it?).
I do the loading as
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
discriminator = dcgm.DiscriminatorNet(nc=nc, alpha=opt.alpha, ndf=128, ngpu=ngpu)
dict = torch.load(opt.loadD, map_location=device)
discriminator.load_state_dict(dict)
discriminator.to(device)
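For reference, this is the kind of round-trip I’d expect to work. A minimal sketch with a hypothetical stand-in net (not my DiscriminatorNet); note the eval() call, which matters because BatchNorm in train mode normalizes with per-batch statistics instead of the learned running statistics:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical tiny discriminator-like net with BatchNorm, standing in for DiscriminatorNet.
def make_net():
    return nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),
        nn.AdaptiveAvgPool2d(1),
    )

net = make_net()
path = os.path.join(tempfile.mkdtemp(), 'disc.pth')
torch.save(net.state_dict(), path)

# Fresh instance of the same architecture, loaded from the checkpoint.
net2 = make_net()
net2.load_state_dict(torch.load(path))

# eval() switches BatchNorm to its running statistics; without it, single
# images or unshuffled batches can get very different normalization.
net.eval()
net2.eval()

x = torch.randn(4, 3, 16, 16)
with torch.no_grad():
    same = torch.allclose(net(x), net2(x))  # loaded net reproduces the original
```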
I logged the discriminator’s output during the training run, and there it discriminates just as it should. The generator also learns as it’s supposed to, and I can generate images using its loaded checkpoint.
I actually had the same issue with my generator earlier, and I never found the logical cause; it just started working with a different run.
I’m starting to get pretty desperate at this point!
Additional information:
I’m training on PyTorch 0.4.1 and loading on PyTorch 0.4.0.
The num_batches_tracked batch norm buffers (which 0.4.0 doesn’t have) are removed while loading, like this:
if torch.__version__ == '0.4.0':
    # 0.4.0 doesn't know these buffers, so drop every one of them
    for key in [k for k in dict if k.endswith('num_batches_tracked')]:
        del dict[key]
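(To show what that filtering does end to end, here is a self-contained sketch with a stand-in net; the dict comprehension keeps everything except the 0.4.1-only tracking buffers:)

```python
import torch
import torch.nn as nn

# Stand-in net with two BatchNorm layers, so its state dict
# contains num_batches_tracked buffers on PyTorch >= 0.4.1.
net = nn.Sequential(nn.BatchNorm2d(4), nn.ReLU(), nn.BatchNorm2d(4))
state = net.state_dict()

# Keep everything except the tracking buffers that 0.4.0 doesn't know.
state = {k: v for k, v in state.items() if not k.endswith('num_batches_tracked')}

# The learned/running statistics survive; only the counters are gone.
assert '0.running_mean' in state
assert not any(k.endswith('num_batches_tracked') for k in state)
```

On 0.4.0 the remaining keys then match the model exactly; on newer versions one would instead pass strict=False to load_state_dict.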
Thanks!