Hey everyone,
I’ve been looking into this for quite some time now and haven’t found an explanation yet. I’ve trained a few GANs with different hyperparameters and saved a checkpoint every epoch. For some runs I can load the checkpoints and generate data without a problem, but others just generate rubbish, as if the generator had never been trained or collapsed at some point.
I changed the network structure at one point, doubling the number of filters in the first layer of the generator, but I reverted to the matching structure when loading those checkpoints. I also don’t get any error messages when loading the state dict, which should mean that everything is loaded correctly.
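(In case it helps with reproducing: a quick sanity check to rule out an architecture mismatch; check_shapes is just a helper I sketched for this post, not part of my training code.)

import torch

def check_shapes(model, state_dict):
    # Compare the model's parameter/buffer shapes with those in the checkpoint
    model_state = model.state_dict()
    for key, tensor in state_dict.items():
        if key not in model_state:
            print('unexpected key in checkpoint:', key)
        elif model_state[key].shape != tensor.shape:
            print('shape mismatch for', key, ':',
                  tuple(tensor.shape), 'vs', tuple(model_state[key].shape))
    for key in model_state:
        if key not in state_dict:
            print('missing key in checkpoint:', key)

# e.g. check_shapes(generator, torch.load(opt.loadG, map_location='cpu'))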
I trained the networks with PyTorch 0.4.1 and am evaluating with PyTorch 0.4.0, so I delete the num_batches_tracked entries from the batch norm layers accordingly. This also works fine for some of the checkpoints.
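(Side note: instead of deleting each key by hand, the buffers can also be filtered generically; just a sketch, assuming the loaded state dict is called state_dict as in the snippet further down.)

# Drop every num_batches_tracked buffer regardless of layer index;
# these buffers only exist in checkpoints saved with PyTorch >= 0.4.1
state_dict = {k: v for k, v in state_dict.items()
              if not k.endswith('num_batches_tracked')}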
I’m utterly confused about why this happens. Has anyone experienced something like this?
I load my generator checkpoint like this:
import torch

# Rebuild the generator with the same architecture that was used during training
# (dcgm is my model-definition module, opt holds the parsed command-line options)
generator = dcgm.GeneratorNe(nc, ngf=128, ngpu=ngpu)

# Load the checkpoint onto the GPU if one is available, otherwise onto the CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
state_dict = torch.load(opt.loadG, map_location=device)

# PyTorch 0.4.0 batch norm layers have no num_batches_tracked buffer,
# so those entries are dropped when evaluating with the older version
if torch.__version__ == '0.4.0':
    for key in ['net.1.num_batches_tracked', 'net.4.num_batches_tracked',
                'net.7.num_batches_tracked', 'net.10.num_batches_tracked',
                'net.13.num_batches_tracked']:
        del state_dict[key]

generator.load_state_dict(state_dict)
generator.to(device)
generator.eval()
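For reference, a minimal sampling sketch that builds on the snippet above and uses a fixed noise batch, so different checkpoints can be compared on exactly the same input (nz = 100 and the (N, nz, 1, 1) input shape are assumptions following the usual DCGAN convention):

# Sample the loaded generator with a fixed noise batch so that
# different checkpoints can be compared on exactly the same input
nz = 100  # assumed latent dimension; use whatever was set during training
torch.manual_seed(0)
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
with torch.no_grad():
    samples = generator(fixed_noise)
print(samples.shape)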