Encountering a size-mismatch error while trying to load a PyTorch model.

I am experimenting with StyleGAN2 and StyleGAN-XL, and I am trying to swap one for the other in some applications to see how the results change. To do this, I tried to replace the StyleGAN2 (sg2) model with StyleGAN-XL (sgxl) in a repository by changing the functions that construct the sg2 model so that they construct sgxl instead, as follows:

def define_G(opt):
    """Build the StyleGAN-XL generator through dnnlib's construct-by-name factory.

    Only ``opt.batch`` is read from ``opt`` (it sets ``magnitude_ema_beta``);
    every other hyper-parameter is fixed in this function.

    Returns whatever ``dnnlib.util.construct_class_by_name`` builds for the
    configured ``class_name``.
    """
    # How the repo originally built the StyleGAN2 generator, kept for reference:
    # w_shift = opt.optim_param_g == 'w_shift'
    # generator = stylegan2.Generator(
    #     opt.size, opt.z_dim, opt.n_mlp, lr_mlp=opt.lr_mlp, channel_multiplier=opt.channel_multiplier, w_shift=w_shift)
    mapping_cfg = dnnlib.EasyDict(rand_embedding=False, num_layers=2)

    gen_cfg = dnnlib.EasyDict(
        class_name='training.networks.stylegan3_resetting.Generator',
        z_dim=64,
        w_dim=512,
        mapping_kwargs=mapping_cfg,
        channel_base=32768 * 2,  # doubled for StyleGAN-XL
        channel_max=512 * 2,     # doubled for StyleGAN-XL
        magnitude_ema_beta=0.5 ** (opt.batch / (20 * 1e3)),
        # NOTE(review): these three settings determine parameter shapes and must
        # agree with the checkpoint being loaded. A checkpoint holding
        # [C, C, 1, 1] conv weights, 2-D down filters and a [1000, 512] w_avg was
        # presumably trained with conv_kernel=1, use_radial_filters=True and
        # c_dim=1000 -- confirm against the checkpoint's training config.
        conv_kernel=3,
        use_radial_filters=False,
        num_layers=14,
        c_dim=0,
        img_channels=3,
        img_resolution=512,
    )

    return dnnlib.util.construct_class_by_name(**gen_cfg)


def define_D(opt):
    """Build the projected discriminator through dnnlib's construct-by-name factory.

    ``opt`` is accepted for signature compatibility with the StyleGAN2 version
    but is not read; all settings are fixed in this function.

    Returns whatever ``dnnlib.util.construct_class_by_name`` builds for the
    configured ``class_name``.
    """
    # How the repo originally built the StyleGAN2 discriminator, kept for reference:
    # discriminator = stylegan2.Discriminator(opt.size, channel_multiplier=opt.channel_multiplier)
    backbone_cfg = dnnlib.EasyDict(
        cout=64,
        expand=True,
        proj_type=2,  # CCM only works better on very low resolutions
        num_discs=4,
        cond=False,
    )

    disc_cfg = dnnlib.EasyDict(
        class_name='pg_modules.discriminator.ProjectedDiscriminator',
        backbones=['deit_base_distilled_patch16_224', 'tf_efficientnet_lite0'],
        diffaug=True,
        interp224=True,
        backbone_kwargs=backbone_cfg,
    )

    # Settings shared with the rest of the pipeline: unconditional, 512x512 RGB.
    shared_cfg = {'c_dim': 0, 'img_resolution': 512, 'img_channels': 3}

    return dnnlib.util.construct_class_by_name(**disc_cfg, **shared_cfg)

And this is how the model is loaded:

def initialize_networks(self, opt):
        """Build the generator (and, when training, the sketch discriminator) and load pretrained weights.

        Reads ``opt.isTrain``, ``opt.g_pretrained``, ``opt.d_pretrained`` and
        ``opt.dsketch_no_pretrain``. Checkpoints are loaded onto the CPU.

        ``load_state_dict(..., strict=False)`` ignores parameter names that exist
        on only one side, but still raises ``RuntimeError`` when a shared name has
        a different shape. The ignored names are printed here so that a partially
        loaded network does not go unnoticed.

        NOTE(review): no ``return`` is visible in this excerpt -- confirm how the
        caller receives ``netG`` / ``netD_sketch``.
        """
        netG = networks.define_G(opt)
        netD_sketch = networks.define_D(opt) if opt.isTrain else None

        if opt.g_pretrained != '':
            print("---------------GAN Model ----------------")
            print(netG)
            weights = torch.load(opt.g_pretrained, map_location=lambda storage, loc: storage)
            g_result = netG.load_state_dict(weights, strict=False)
            # Report what strict=False skipped instead of discarding it.
            if g_result.missing_keys:
                print("G: %d parameter(s) not found in checkpoint: %s"
                      % (len(g_result.missing_keys), g_result.missing_keys))
            if g_result.unexpected_keys:
                print("G: %d checkpoint entry(ies) not used by the model: %s"
                      % (len(g_result.unexpected_keys), g_result.unexpected_keys))

        if netD_sketch is not None and opt.d_pretrained != '' and not opt.dsketch_no_pretrain:
            print("Using pretrained weight for D1...")
            weights = torch.load(opt.d_pretrained, map_location=lambda storage, loc: storage)
            d_result = netD_sketch.load_state_dict(weights, strict=False)
            # Same reporting for the discriminator.
            if d_result.missing_keys:
                print("D1: %d parameter(s) not found in checkpoint: %s"
                      % (len(d_result.missing_keys), d_result.missing_keys))
            if d_result.unexpected_keys:
                print("D1: %d checkpoint entry(ies) not used by the model: %s"
                      % (len(d_result.unexpected_keys), d_result.unexpected_keys))

But, this approach results in the following error:

Traceback (most recent call last):
  File "train.py", line 106, in <module>
    training_loop()
  File "train.py", line 34, in training_loop
    trainer = GANTrainer(opt)
  File "/mnt/Data1/vmisra/GANSketching/training/gan_trainer.py", line 19, in __init__
    self.gan_model = GANModel(opt).to(self.device)
  File "/mnt/Data1/vmisra/GANSketching/training/gan_model.py", line 19, in __init__
    self.netG, self.netD = self.initialize_networks(opt)
  File "/mnt/Data1/vmisra/GANSketching/training/gan_model.py", line 148, in initialize_networks
    netG.load_state_dict(weights, strict=False)
  File "/home/vmisra/anaconda3/envs/gansketching/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Generator:
        size mismatch for synthesis.L0_36_1024.weight: copying a param with shape torch.Size([1024, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
        size mismatch for synthesis.L0_36_1024.down_filter: copying a param with shape torch.Size([12, 12]) from checkpoint, the shape in current model is torch.Size([12]).
        size mismatch for synthesis.L1_36_1024.weight: copying a param with shape torch.Size([1024, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
        size mismatch for synthesis.L1_36_1024.down_filter: copying a param with shape torch.Size([12, 12]) from checkpoint, the shape in current model is torch.Size([12]).
        size mismatch for mapping.w_avg: copying a param with shape torch.Size([1000, 512]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for mapping.fc0.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([512, 64])

I tried to solve it by changing the dimensions in the generator configuration, but that had no effect. I also printed the model structure to see whether it would help me pinpoint the source of the mismatch, but I couldn’t find it. Here’s what the model structure looks like:

---------------GAN Model ----------------
Generator(
  (synthesis): SynthesisNetwork(
    w_dim=512, num_ws=16,
    img_resolution=512, img_channels=3,
    num_layers=14, num_critical=2,
    margin_size=10, num_fp16_res=4
    (input): SynthesisInput(
      w_dim=512, channels=1024, size=[36, 36],
      sampling_rate=16, bandwidth=2
      (affine): FullyConnectedLayer(in_features=512, out_features=4, activation=linear)
    )
    (L0_36_1024): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=False,
      in_sampling_rate=16, out_sampling_rate=16,
      in_cutoff=2, out_cutoff=2,
      in_half_width=6, out_half_width=6,
      in_size=[36, 36], out_size=[36, 36],
      in_channels=1024, out_channels=1024
      (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
    )
    (L1_36_1024): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=False,
      in_sampling_rate=16, out_sampling_rate=16,
      in_cutoff=2, out_cutoff=2.99661,
      in_half_width=6, out_half_width=5.00339,
      in_size=[36, 36], out_size=[36, 36],
      in_channels=1024, out_channels=1024
      (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
    )
    (L2_52_1024): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=False,
      in_sampling_rate=16, out_sampling_rate=32,
      in_cutoff=2.99661, out_cutoff=4.48985,
      in_half_width=5.00339, out_half_width=11.5102,
      in_size=[36, 36], out_size=[52, 52],
      in_channels=1024, out_channels=1024
      (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
    )
    (L3_52_1024): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=False,
      in_sampling_rate=32, out_sampling_rate=32,
      in_cutoff=4.48985, out_cutoff=6.72717,
      in_half_width=11.5102, out_half_width=9.27283,
      in_size=[52, 52], out_size=[52, 52],
      in_channels=1024, out_channels=1024
      (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
    )
    (L4_84_1024): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=True,
      in_sampling_rate=32, out_sampling_rate=64,
      in_cutoff=6.72717, out_cutoff=10.0794,
      in_half_width=9.27283, out_half_width=21.9206,
      in_size=[52, 52], out_size=[84, 84],
      in_channels=1024, out_channels=1024
      (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
    )
    (L5_84_1024): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=True,
      in_sampling_rate=64, out_sampling_rate=64,
      in_cutoff=10.0794, out_cutoff=15.102,
      in_half_width=21.9206, out_half_width=16.898,
      in_size=[84, 84], out_size=[84, 84],
      in_channels=1024, out_channels=1024
      (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
    )
    (L6_148_1024): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=True,
      in_sampling_rate=64, out_sampling_rate=128,
      in_cutoff=15.102, out_cutoff=22.6274,
      in_half_width=16.898, out_half_width=41.3726,
      in_size=[84, 84], out_size=[148, 148],
      in_channels=1024, out_channels=1024
      (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
    )
    (L7_148_967): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=True,
      in_sampling_rate=128, out_sampling_rate=128,
      in_cutoff=22.6274, out_cutoff=33.9028,
      in_half_width=41.3726, out_half_width=30.0972,
      in_size=[148, 148], out_size=[148, 148],
      in_channels=1024, out_channels=967
      (affine): FullyConnectedLayer(in_features=512, out_features=1024, activation=linear)
    )
    (L8_276_645): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=True,
      in_sampling_rate=128, out_sampling_rate=256,
      in_cutoff=33.9028, out_cutoff=50.7968,
      in_half_width=30.0972, out_half_width=77.2032,
      in_size=[148, 148], out_size=[276, 276],
      in_channels=967, out_channels=645
      (affine): FullyConnectedLayer(in_features=512, out_features=967, activation=linear)
    )
    (L9_276_431): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=True,
      in_sampling_rate=256, out_sampling_rate=256,
      in_cutoff=50.7968, out_cutoff=76.1093,
      in_half_width=77.2032, out_half_width=51.8907,
      in_size=[276, 276], out_size=[276, 276],
      in_channels=645, out_channels=431
      (affine): FullyConnectedLayer(in_features=512, out_features=645, activation=linear)
    )
    (L10_532_287): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=True,
      in_sampling_rate=256, out_sampling_rate=512,
      in_cutoff=76.1093, out_cutoff=114.035,
      in_half_width=51.8907, out_half_width=141.965,
      in_size=[276, 276], out_size=[532, 532],
      in_channels=431, out_channels=287
      (affine): FullyConnectedLayer(in_features=512, out_features=431, activation=linear)
    )
    (L11_532_192): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=False, use_fp16=True,
      in_sampling_rate=512, out_sampling_rate=512,
      in_cutoff=114.035, out_cutoff=170.86,
      in_half_width=141.965, out_half_width=85.1405,
      in_size=[532, 532], out_size=[532, 532],
      in_channels=287, out_channels=192
      (affine): FullyConnectedLayer(in_features=512, out_features=287, activation=linear)
    )
    (L12_532_128): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=True, use_fp16=True,
      in_sampling_rate=512, out_sampling_rate=512,
      in_cutoff=170.86, out_cutoff=256,
      in_half_width=85.1405, out_half_width=59.173,
      in_size=[532, 532], out_size=[532, 532],
      in_channels=192, out_channels=128
      (affine): FullyConnectedLayer(in_features=512, out_features=192, activation=linear)
    )
    (L13_512_128): SynthesisLayer(
      w_dim=512, is_torgb=False,
      is_critically_sampled=True, use_fp16=True,
      in_sampling_rate=512, out_sampling_rate=512,
      in_cutoff=256, out_cutoff=256,
      in_half_width=59.173, out_half_width=59.173,
      in_size=[532, 532], out_size=[512, 512],
      in_channels=128, out_channels=128
      (affine): FullyConnectedLayer(in_features=512, out_features=128, activation=linear)
    )
    (L14_512_3): SynthesisLayer(
      w_dim=512, is_torgb=True,
      is_critically_sampled=True, use_fp16=True,
      in_sampling_rate=512, out_sampling_rate=512,
      in_cutoff=256, out_cutoff=256,
      in_half_width=59.173, out_half_width=59.173,
      in_size=[512, 512], out_size=[512, 512],
      in_channels=128, out_channels=3
      (affine): FullyConnectedLayer(in_features=512, out_features=128, activation=linear)
    )
  )
  (mapping): MappingNetwork(
    z_dim=64, c_dim=0, w_dim=512, num_ws=16
    (embed): Embedding(1000, 320)
    (fc0): FullyConnectedLayer(in_features=64, out_features=512, activation=lrelu)
    (fc1): FullyConnectedLayer(in_features=512, out_features=512, activation=lrelu)
  )
)