I am trying to experiment with styleGAN-2 and styleGAN-XL, and I was trying to interchange these two in some applications to see how the results changed. So, I tried to substitute the styleGAN-2(sg2) model with styleGAN-XL(sgxl) in a repository by changing the classes that called the sg2 model to call the sgxl in the following way:
def define_G(opt):
# COMMENTED PART IS HOW THE REPO ORIGINALLY CALLED GENERATOR FOR SG2
# w_shift = opt.optim_param_g == 'w_shift'
# generator = stylegan2.Generator(
# opt.size, opt.z_dim, opt.n_mlp, lr_mlp=opt.lr_mlp, channel_multiplier=opt.channel_multiplier, w_shift=w_shift)
G_kwargs = dnnlib.EasyDict(class_name=None, z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict())
G_kwargs.channel_base = 32768
G_kwargs.channel_max = 512
G_kwargs.class_name = 'training.networks.stylegan3_resetting.Generator'
G_kwargs.magnitude_ema_beta = 0.5 ** (opt.batch / (20 * 1e3))
G_kwargs.channel_base *= 2 # increase for StyleGAN-XL
G_kwargs.channel_max *= 2 # increase for StyleGAN-XL
G_kwargs.conv_kernel = 3
G_kwargs.use_radial_filters = False
G_kwargs.w_dim = 512
G_kwargs.z_dim = 64
G_kwargs.mapping_kwargs.rand_embedding = False
G_kwargs.num_layers = 14
G_kwargs.mapping_kwargs.num_layers = 2
G_kwargs.c_dim = 0
G_kwargs.img_channels = 3
G_kwargs.img_resolution = 512
common_kwargs = dict()
generator = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs)
return generator
def define_D(opt):
# COMMENTED PART IS HOW THE REPO ORIGINALLY CALLED DISCRIMINATOR FOR SG2
# discriminator = stylegan2.Discriminator(opt.size, channel_multiplier=opt.channel_multiplier)
D_kwargs = dnnlib.EasyDict(
class_name='pg_modules.discriminator.ProjectedDiscriminator',
backbones=['deit_base_distilled_patch16_224', 'tf_efficientnet_lite0'],
diffaug=True,
interp224=True,
backbone_kwargs=dnnlib.EasyDict(),
)
D_kwargs.backbone_kwargs.cout = 64
D_kwargs.backbone_kwargs.expand = True
D_kwargs.backbone_kwargs.proj_type = 2 # CCM only works better on very low resolutions
D_kwargs.backbone_kwargs.num_discs = 4
D_kwargs.backbone_kwargs.cond = False
common_kwargs = dict(c_dim=0, img_resolution=512, img_channels=3)
discriminator = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs) # subclass of torch.nn.Module
return discriminator
And this is how the model is loaded:
def initialize_networks(self, opt):
netG = networks.define_G(opt)
netD_sketch = networks.define_D(opt) if opt.isTrain else None
if opt.g_pretrained != '':
print("---------------GAN Model ----------------")
print(netG)
weights = torch.load(opt.g_pretrained, map_location=lambda storage, loc: storage)
netG.load_state_dict(weights, strict=False)
if netD_sketch is not None and opt.d_pretrained != '' and not opt.dsketch_no_pretrain:
print("Using pretrained weight for D1...")
weights = torch.load(opt.d_pretrained, map_location=lambda storage, loc: storage)
netD_sketch.load_state_dict(weights, strict=False)
But, this approach results in the following error:
Traceback (most recent call last):
File "train.py", line 106, in <module>
training_loop()
File "train.py", line 34, in training_loop
trainer = GANTrainer(opt)
File "/mnt/Data1/vmisra/GANSketching/training/gan_trainer.py", line 19, in __init__
self.gan_model = GANModel(opt).to(self.device)
File "/mnt/Data1/vmisra/GANSketching/training/gan_model.py", line 19, in __init__
self.netG, self.netD = self.initialize_networks(opt)
File "/mnt/Data1/vmisra/GANSketching/training/gan_model.py", line 148, in initialize_networks
netG.load_state_dict(weights, strict=False)
File "/home/vmisra/anaconda3/envs/gansketching/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
size mismatch for synthesis.L0_36_1024.weight: copying a param with shape torch.Size([1024, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
size mismatch for synthesis.L0_36_1024.down_filter: copying a param with shape torch.Size([12, 12]) from checkpoint, the shape in current model is torch.Size([12]).
size mismatch for synthesis.L1_36_1024.weight: copying a param with shape torch.Size([1024, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
size mismatch for synthesis.L1_36_1024.down_filter: copying a param with shape torch.Size([12, 12]) from checkpoint, the shape in current model is torch.Size([12]).
size mismatch for mapping.w_avg: copying a param with shape torch.Size([1000, 512]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for mapping.fc0.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([512, 64])
I tried to solve it by playing with the dimensions, but that had no effect, I even tried to print the model structure to see if that could help me pinpoint the error, but I couldn’t find that out. Here’s what the model structure looks like:
---------------GAN Model ----------------
Generator(
(synthesis): SynthesisNetwork(
w_dim=512, num_ws=16,
img_resolution=512, img_channels=3,
num_layers=14, num_critical=2,
margin_size=10, num_fp16_res=4
(input): SynthesisInput(
w_dim=512, channels=1024, size=[36, 36],
sampling_rate=16, bandwidth=2
(affine): FullyConnectedLayer(in_features=512, out_features=4, activation=linear)
)
(L0_36_1024): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=False,
in_sampling_rate=16, out_sampling_rate=16,
in_cutoff=2, out_cutoff=2,
in_half_width=6, out_half_width=6,
in_size=[36, 36], out_size=[36, 36],
in_channels=1024, out_channels=1024
(affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
)
(L1_36_1024): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=False,
in_sampling_rate=16, out_sampling_rate=16,
in_cutoff=2, out_cutoff=2.99661,
in_half_width=6, out_half_width=5.00339,
in_size=[36, 36], out_size=[36, 36],
in_channels=1024, out_channels=1024
(affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
)
(L2_52_1024): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=False,
in_sampling_rate=16, out_sampling_rate=32,
in_cutoff=2.99661, out_cutoff=4.48985,
in_half_width=5.00339, out_half_width=11.5102,
in_size=[36, 36], out_size=[52, 52],
in_channels=1024, out_channels=1024
(affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
)
(L3_52_1024): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=False,
in_sampling_rate=32, out_sampling_rate=32,
in_cutoff=4.48985, out_cutoff=6.72717,
in_half_width=11.5102, out_half_width=9.27283,
in_size=[52, 52], out_size=[52, 52],
in_channels=1024, out_channels=1024
(affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
)
(L4_84_1024): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=True,
in_sampling_rate=32, out_sampling_rate=64,
in_cutoff=6.72717, out_cutoff=10.0794,
in_half_width=9.27283, out_half_width=21.9206,
in_size=[52, 52], out_size=[84, 84],
in_channels=1024, out_channels=1024
(affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
)
(L5_84_1024): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=True,
in_sampling_rate=64, out_sampling_rate=64,
in_cutoff=10.0794, out_cutoff=15.102,
in_half_width=21.9206, out_half_width=16.898,
in_size=[84, 84], out_size=[84, 84],
in_channels=1024, out_channels=1024
(affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
)
(L6_148_1024): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=True,
in_sampling_rate=64, out_sampling_rate=128,
in_cutoff=15.102, out_cutoff=22.6274,
in_half_width=16.898, out_half_width=41.3726,
in_size=[84, 84], out_size=[148, 148],
in_channels=1024, out_channels=1024
(affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
)
(L7_148_967): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=True,
in_sampling_rate=128, out_sampling_rate=128,
in_cutoff=22.6274, out_cutoff=33.9028,
in_half_width=41.3726, out_half_width=30.0972,
in_size=[148, 148], out_size=[148, 148],
in_channels=1024, out_channels=967
(affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
)
(L8_276_645): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=True,
in_sampling_rate=128, out_sampling_rate=256,
in_cutoff=33.9028, out_cutoff=50.7968,
in_half_width=30.0972, out_half_width=77.2032,
in_size=[148, 148], out_size=[276, 276],
in_channels=967, out_channels=645
(affine): FullyConnectedLayer(in_features=512, out_features=967, activation=linear)
)
(L9_276_431): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=True,
in_sampling_rate=256, out_sampling_rate=256,
in_cutoff=50.7968, out_cutoff=76.1093,
in_half_width=77.2032, out_half_width=51.8907,
in_size=[276, 276], out_size=[276, 276],
in_channels=645, out_channels=431
(affine): FullyConnectedLayer(in_features=512, out_features=645, activation=linear)
)
(L10_532_287): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=True,
in_sampling_rate=256, out_sampling_rate=512,
in_cutoff=76.1093, out_cutoff=114.035,
in_half_width=51.8907, out_half_width=141.965,
in_size=[276, 276], out_size=[532, 532],
in_channels=431, out_channels=287
(affine): FullyConnectedLayer(in_features=512, out_features=431, activation=linear)
)
(L11_532_192): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=False, use_fp16=True,
in_sampling_rate=512, out_sampling_rate=512,
in_cutoff=114.035, out_cutoff=170.86,
in_half_width=141.965, out_half_width=85.1405,
in_size=[532, 532], out_size=[532, 532],
in_channels=287, out_channels=192
(affine): FullyConnectedLayer(in_features=512, out_features=287, activation=linear)
)
(L12_532_128): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=True, use_fp16=True,
in_sampling_rate=512, out_sampling_rate=512,
in_cutoff=170.86, out_cutoff=256,
in_half_width=85.1405, out_half_width=59.173,
in_size=[532, 532], out_size=[532, 532],
in_channels=192, out_channels=128
(affine): FullyConnectedLayer(in_features=512, out_features=192, activation=linear)
)
(L13_512_128): SynthesisLayer(
w_dim=512, is_torgb=False,
is_critically_sampled=True, use_fp16=True,
in_sampling_rate=512, out_sampling_rate=512,
in_cutoff=256, out_cutoff=256,
in_half_width=59.173, out_half_width=59.173,
in_size=[532, 532], out_size=[512, 512],
in_channels=128, out_channels=128
(affine): FullyConnectedLayer(in_features=512, out_features=128, activation=linear)
)
(L14_512_3): SynthesisLayer(
w_dim=512, is_torgb=True,
is_critically_sampled=True, use_fp16=True,
in_sampling_rate=512, out_sampling_rate=512,
in_cutoff=256, out_cutoff=256,
in_half_width=59.173, out_half_width=59.173,
in_size=[512, 512], out_size=[512, 512],
in_channels=128, out_channels=3
(affine): FullyConnectedLayer(in_features=512, out_features=128, activation=linear)
)
)
(mapping): MappingNetwork(
z_dim=64, c_dim=0, w_dim=512, num_ws=16
(embed): Embedding(1000, 320)
(fc0): FullyConnectedLayer(in_features=64, out_features=512, activation=lrelu)
(fc1): FullyConnectedLayer(in_features=512, out_features=512, activation=lrelu)
)
)