Loading GAN generator

Hey everyone,

I’ve been looking into this for quite some time now and haven’t found an explanation yet. I’ve trained a few GANs with different hyperparameters and saved a checkpoint every epoch. Some of the runs have checkpoints I can load and generate data from without a problem, while others just generate rubbish, as if the generator had never been trained or had collapsed at some point.
At one point I changed the network structure, doubling the number of filters in the first layer of the generator, but I revert to the matching structure when I load those checkpoints. I also don’t get any error messages when loading the state dict, which should mean that everything is loaded correctly.
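To double-check that, here is a small sanity check I can run before load_state_dict: it compares the keys and tensor shapes of the checkpoint against a freshly constructed generator (just a sketch; compare_state_dicts is a throwaway helper, not part of my training code):

    import torch

    def compare_state_dicts(model, checkpoint_path):
        # load the checkpoint onto the CPU and diff it against the model's own state dict
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        model_sd = model.state_dict()
        missing = set(model_sd.keys()) - set(ckpt.keys())      # in model, not in checkpoint
        unexpected = set(ckpt.keys()) - set(model_sd.keys())   # in checkpoint, not in model
        mismatched = [k for k in set(ckpt.keys()) & set(model_sd.keys())
                      if ckpt[k].shape != model_sd[k].shape]   # same key, different shape
        print('missing keys:    ', sorted(missing))
        print('unexpected keys: ', sorted(unexpected))
        print('shape mismatches:', mismatched)

If all three lists come back empty, the rubbish output isn’t coming from a key or shape mismatch.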

I trained the network using PyTorch 0.4.1 and am evaluating with PyTorch 0.4.0, so I delete the num_batches_tracked entries of the batch norm layers accordingly. This works fine for some of the checkpoints.
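As an aside, instead of deleting each num_batches_tracked entry by name (as in the loading code below), the same thing can be done generically over all batch norm layers (a small sketch using the same opt.loadG checkpoint path):

    state_dict = torch.load(opt.loadG, map_location='cpu')
    # drop every num_batches_tracked buffer, whatever the layer index
    state_dict = {k: v for k, v in state_dict.items()
                  if not k.endswith('num_batches_tracked')}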

I’m utterly confused as to why this would happen. Has anyone experienced something like this?

I load my generator checkpoint like this:

    generator = dcgm.GeneratorNe(nc, ngf=128, ngpu=ngpu)
    # load the checkpoint onto the GPU if available, otherwise onto the CPU
    state_dict = torch.load(opt.loadG, map_location='cuda:0' if torch.cuda.is_available() else 'cpu')
    # batch norm in PyTorch 0.4.0 has no num_batches_tracked buffer, so drop those entries
    if torch.__version__ == '0.4.0':
        del state_dict['net.1.num_batches_tracked']
        del state_dict['net.4.num_batches_tracked']
        del state_dict['net.7.num_batches_tracked']
        del state_dict['net.10.num_batches_tracked']
        del state_dict['net.13.num_batches_tracked']
    generator.load_state_dict(state_dict)
    generator.to(gpu)
    generator.eval()
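
To rule out the checkpoint silently not being applied at all, I can also snapshot the freshly initialized parameters and compare them after loading (again just a sketch, reusing the state_dict from above):

    generator = dcgm.GeneratorNe(nc, ngf=128, ngpu=ngpu)
    # keep a copy of the random initialization
    before = {k: v.clone() for k, v in generator.state_dict().items()}
    generator.load_state_dict(state_dict)
    # almost every entry should have changed after loading a trained checkpoint
    unchanged = [k for k in before
                 if torch.equal(before[k], generator.state_dict()[k])]
    print('entries unchanged after loading:', unchanged)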

Hi sami,

I have the same issue. I loaded the weights file and generated images from it, but the output looks as if it came from random noise rather than the learned weights. Have you figured out why that is? I haven’t successfully generated a single image so far.

I’m using torch 0.4.1.

The weights file is saved via torch.save(generator.state_dict(), path) during the training phase.
During the testing phase, I do:

    model = generator()
    checkpoint = torch.load('path/001_G.pth', map_location=str(device))
    # note: strict=False silently skips missing or unexpected keys,
    # so a key mismatch would not raise an error here
    model.load_state_dict(checkpoint, strict=False)
    model.to(device)
    model.float()
    model.eval()
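
One thing I still want to try is loading once with the default strict=True, so that any missing or unexpected keys raise an error instead of being silently skipped (sketch):

    model = generator()
    checkpoint = torch.load('path/001_G.pth', map_location=str(device))
    # strict=True (the default) raises a RuntimeError if the keys don't match
    model.load_state_dict(checkpoint)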

    def label_sampel():
        # draw random class indices and build their one-hot encodings
        label = torch.LongTensor(batch_size, 1).random_() % n_class
        one_hot = torch.zeros(batch_size, n_class).scatter_(1, label, 1)
        print(device)
        return label.squeeze(1).to(device), one_hot.to(device)

    z = torch.randn(batch_size, z_dim).to(device)
    z_class, z_class_one_hot = label_sampel()

    # generate a batch of fake images and save them to disk
    fake_images = model(z, z_class_one_hot)
    save_image(denorm(fake_images.data), os.path.join(path, '1_generated.png'))
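
For completeness, denorm isn’t shown above; for a generator that ends in tanh it is usually something like this (sketch, assuming outputs in [-1, 1]):

    def denorm(x):
        # map tanh outputs from [-1, 1] back to [0, 1] for saving, clamping for safety
        return ((x + 1) / 2).clamp(0, 1)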