Size-mismatch error when training my pretrained model on a different dataset

Hello all,
I trained my model from scratch on the PubLayNet dataset, and now I want to use this pretrained model to train on a different dataset (a small dataset of around 1000 images). So I am doing classic fine-tuning of the pretrained model, but I get a size-mismatch error; it seems the embedding sizes differ between the checkpoint and the new model.
This is the error I got:
RuntimeError: Error(s) in loading state_dict for Generator:
size mismatch for crop_encoder.bn1.embed.weight: copying a param with shape torch.Size([6, 128]) from checkpoint, the shape in current model is torch.Size([7, 128]).
size mismatch for crop_encoder.bn2.embed.weight: copying a param with shape torch.Size([6, 256]) from checkpoint, the shape in current model is torch.Size([7, 256]).
size mismatch for crop_encoder.bn3.embed.weight: copying a param with shape torch.Size([6, 512]) from checkpoint, the shape in current model is torch.Size([7, 512]).
size mismatch for crop_encoder.bn4.embed.weight: copying a param with shape torch.Size([6, 1024]) from checkpoint, the shape in current model is torch.Size([7, 1024]).
size mismatch for crop_encoder.bn5.embed.weight: copying a param with shape torch.Size([6, 2048]) from checkpoint, the shape in current model is torch.Size([7, 2048]).
size mismatch for layout_encoder.embedding.weight: copying a param with shape torch.Size([6, 64]) from checkpoint, the shape in current model is torch.Size([7, 64]).
size mismatch for layout_encoder.bn1.embed.weight: copying a param with shape torch.Size([6, 128]) from checkpoint, the shape in current model is torch.Size([7, 128]).
size mismatch for layout_encoder.bn2.embed.weight: copying a param with shape torch.Size([6, 256]) from checkpoint, the shape in current model is torch.Size([7, 256]).
size mismatch for layout_encoder.bn3.embed.weight: copying a param with shape torch.Size([6, 512]) from checkpoint, the shape in current model is torch.Size([7, 512]).
size mismatch for layout_encoder.bn4.embed.weight: copying a param with shape torch.Size([6, 1024]) from checkpoint, the shape in current model is torch.Size([7, 1024]).

Could you please help me solve this error?
This is my code:

import torch
import argparse
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from models.generator_128 import Generator
from models.discriminator import ImageDiscriminator
from models.discriminator import ObjectDiscriminator
from models.discriminator import add_sn
from data.coco_custom_mask import get_dataloader as get_dataloader_coco
from data.vg_custom_mask import get_dataloader as get_dataloader_vg
from data.publaynet_custom_mask import get_dataloader as get_dataloader_publaynet
from utils.model_saver import load_model, save_model, prepare_dir
from utils.data import imagenet_deprocess_batch
from utils.miscs import str2bool
import torch.backends.cudnn as cudnn

def main(config):
    cudnn.benchmark = True
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(config.exp_name)

    if config.dataset == 'publaynet':
        data_loader, _ = get_dataloader_publaynet(batch_size=config.batch_size, COCO_DIR=config.coco_dir)
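    # vocab_num is dataset-dependent: it becomes num_embeddings for every
    # class-embedding table in the Generator, so it has to match the vocabulary
    # of any checkpoint loaded below (the checkpoint in the error above was
    # saved with 6 classes, while the new dataset reports 7).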
    vocab_num = data_loader.dataset.num_objects

    assert config.clstm_layers > 0

    netG = Generator(num_embeddings=vocab_num,
                     embedding_dim=config.embedding_dim,
                     z_dim=config.z_dim,
                     clstm_layers=config.clstm_layers).to(device)
    netD_image = ImageDiscriminator(conv_dim=config.embedding_dim).to(device)
    netD_object = ObjectDiscriminator(n_class=vocab_num).to(device)

    netD_image = add_sn(netD_image)
    netD_object = add_sn(netD_object)

    netG_optimizer = torch.optim.Adam(netG.parameters(), config.learning_rate, [0.5, 0.999])
    netD_image_optimizer = torch.optim.Adam(netD_image.parameters(), config.learning_rate, [0.5, 0.999])
    netD_object_optimizer = torch.optim.Adam(netD_object.parameters(), config.learning_rate, [0.5, 0.999])

    print('load model from: /home/user/PycharmProjects/synth_doc_layout/layout2im/checkpoints/pretrained')

    netG.load_state_dict(torch.load("/home/user/PycharmProjects/synth_doc_layout/layout2im/checkpoints/pretrained/iter-300000_netG.pkl"))
    netD_image.load_state_dict(torch.load("/home/user/PycharmProjects/synth_doc_layout/layout2im/checkpoints/pretrained/iter-300000_netD_image.pkl"))
    netD_object.load_state_dict(torch.load("/home/user/PycharmProjects/synth_doc_layout/layout2im/checkpoints/pretrained/iter-300000_netD_object.pkl"))

    data_iter = iter(data_loader)

    if config.use_tensorboard: writer = SummaryWriter(log_save_dir)

    for i in range(config.niter):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(data_loader)
            batch = next(data_iter)

        # =================================================================================== #
        #                             1. Preprocess input data                                #
        # =================================================================================== #
        imgs, objs, boxes, masks, obj_to_img = batch
        z = torch.randn(objs.size(0), config.z_dim)
        imgs, objs, boxes, masks, obj_to_img, z = imgs.to(device), objs.to(device), boxes.to(device), \
                                                  masks.to(device), obj_to_img.to(device), z.to(device)

        # =================================================================================== #
        #                             2. Train the discriminator                              #
        # =================================================================================== #

        # Generate fake image
        output = netG(imgs, objs, boxes, masks, obj_to_img, z)
        crops_input, crops_input_rec, crops_rand, img_rec, img_rand, mu, logvar, z_rand_rec = output

        # Compute image adv loss with fake images.
        out_logits = netD_image(img_rec.detach())
        d_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(out_logits, torch.full_like(out_logits, 0))

        out_logits = netD_image(img_rand.detach())
        d_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(out_logits, torch.full_like(out_logits, 0))

        d_image_adv_loss_fake = 0.5 * d_image_adv_loss_fake_rec + 0.5 * d_image_adv_loss_fake_rand

        # Compute image src loss with real images.
        out_logits = netD_image(imgs)
        d_image_adv_loss_real = F.binary_cross_entropy_with_logits(out_logits, torch.full_like(out_logits, 1))

        # Compute object sn adv loss with fake rec crops
        out_logits, _ = netD_object(crops_input_rec.detach(), objs)
        d_object_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(out_logits, torch.full_like(out_logits, 0))

        # Compute object sn adv loss with fake rand crops
        out_logits, _ = netD_object(crops_rand.detach(), objs)
        d_object_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(out_logits, torch.full_like(out_logits, 0))

        d_object_adv_loss_fake = 0.5 * d_object_adv_loss_fake_rec + 0.5 * d_object_adv_loss_fake_rand

        # Compute object sn adv loss with real crops.
        out_logits_src, out_logits_cls = netD_object(crops_input.detach(), objs)
        d_object_adv_loss_real = F.binary_cross_entropy_with_logits(out_logits_src,
                                                                    torch.full_like(out_logits_src, 1))
        d_object_cls_loss_real = F.cross_entropy(out_logits_cls, objs)

        # Backward and optimize.
        d_loss = 0
        d_loss += config.lambda_img_adv * (d_image_adv_loss_fake + d_image_adv_loss_real)
        d_loss += config.lambda_obj_adv * (d_object_adv_loss_fake + d_object_adv_loss_real)
        d_loss += config.lambda_obj_cls * d_object_cls_loss_real

        netD_image.zero_grad()
        netD_object.zero_grad()

        d_loss.backward()

        netD_image_optimizer.step()
        netD_object_optimizer.step()

        # Logging.
        loss = {}
        loss['D/loss'] = d_loss.item()
        loss['D/image_adv_loss_real'] = d_image_adv_loss_real.item()
        loss['D/image_adv_loss_fake'] = d_image_adv_loss_fake.item()
        loss['D/object_adv_loss_real'] = d_object_adv_loss_real.item()
        loss['D/object_adv_loss_fake'] = d_object_adv_loss_fake.item()
        loss['D/object_cls_loss_real'] = d_object_cls_loss_real.item()

        # =================================================================================== #
        #                               3. Train the generator                                #
        # =================================================================================== #
        # Generate fake image
        output = netG(imgs, objs, boxes, masks, obj_to_img, z)
        crops_input, crops_input_rec, crops_rand, img_rec, img_rand, mu, logvar, z_rand_rec = output

        # reconstruction loss of ae and img
        # g_img_rec_loss = torch.abs(img_rec - imgs).view(imgs.shape[0], -1).mean(1)
        g_img_rec_loss = torch.abs(img_rec - imgs).mean()
        g_z_rec_loss = torch.abs(z_rand_rec - z).mean()

        # kl loss
        kl_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
        g_kl_loss = torch.sum(kl_element).mul_(-0.5)

        # Compute image adv loss with fake images.
        out_logits = netD_image(img_rec)
        g_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(out_logits, torch.full_like(out_logits, 1))

        out_logits = netD_image(img_rand)
        g_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(out_logits, torch.full_like(out_logits, 1))

        g_image_adv_loss_fake = 0.5 * g_image_adv_loss_fake_rec + 0.5 * g_image_adv_loss_fake_rand

        # Compute object adv loss with fake images.
        out_logits_src, out_logits_cls = netD_object(crops_input_rec, objs)
        g_object_adv_loss_rec = F.binary_cross_entropy_with_logits(out_logits_src,
                                                                   torch.full_like(out_logits_src, 1))
        g_object_cls_loss_rec = F.cross_entropy(out_logits_cls, objs)

        out_logits_src, out_logits_cls = netD_object(crops_rand, objs)
        g_object_adv_loss_rand = F.binary_cross_entropy_with_logits(out_logits_src,
                                                                    torch.full_like(out_logits_src, 1))
        g_object_cls_loss_rand = F.cross_entropy(out_logits_cls, objs)

        g_object_adv_loss = 0.5 * g_object_adv_loss_rec + 0.5 * g_object_adv_loss_rand
        g_object_cls_loss = 0.5 * g_object_cls_loss_rec + 0.5 * g_object_cls_loss_rand

        # Backward and optimize.
        g_loss = 0
        g_loss += config.lambda_img_rec * g_img_rec_loss
        g_loss += config.lambda_z_rec * g_z_rec_loss
        g_loss += config.lambda_img_adv * g_image_adv_loss_fake
        g_loss += config.lambda_obj_adv * g_object_adv_loss
        g_loss += config.lambda_obj_cls * g_object_cls_loss
        g_loss += config.lambda_kl * g_kl_loss

        netG.zero_grad()
        g_loss.backward()
        netG_optimizer.step()

        loss['G/loss'] = g_loss.item()
        loss['G/image_adv_loss'] = g_image_adv_loss_fake.item()
        loss['G/object_adv_loss'] = g_object_adv_loss.item()
        loss['G/object_cls_loss'] = g_object_cls_loss.item()
        loss['G/rec_img'] = g_img_rec_loss.item()
        loss['G/rec_z'] = g_z_rec_loss.item()
        loss['G/kl'] = g_kl_loss.item()

        # =================================================================================== #
        #                               4. Log                                                #
        # =================================================================================== #
        if (i + 1) % config.log_step == 0:
            log = 'iter [{:06d}/{:06d}]'.format(i + 1, config.niter)
            for tag, roi_value in loss.items():
                log += ", {}: {:.4f}".format(tag, roi_value)
            print(log)

        if (i + 1) % config.tensorboard_step == 0 and config.use_tensorboard:
            for tag, roi_value in loss.items():
                writer.add_scalar(tag, roi_value, i + 1)
            writer.add_image('Result/crop_real', imagenet_deprocess_batch(crops_input).float() / 255, i + 1)
            writer.add_image('Result/crop_real_rec', imagenet_deprocess_batch(crops_input_rec).float() / 255, i + 1)
            writer.add_image('Result/crop_rand', imagenet_deprocess_batch(crops_rand).float() / 255, i + 1)
            writer.add_image('Result/img_real', imagenet_deprocess_batch(imgs).float() / 255, i + 1)
            writer.add_image('Result/img_real_rec', imagenet_deprocess_batch(img_rec).float() / 255, i + 1)
            writer.add_image('Result/img_fake_rand', imagenet_deprocess_batch(img_rand).float() / 255, i + 1)

        if (i + 1) % config.save_step == 0:
            save_model(netG, model_dir=model_save_dir, appendix='netG', iter=i + 1, save_num=5,
                       save_step=config.save_step)
            save_model(netD_image, model_dir=model_save_dir, appendix='netD_image', iter=i + 1, save_num=5,
                       save_step=config.save_step)
            save_model(netD_object, model_dir=model_save_dir, appendix='netD_object', iter=i + 1, save_num=5,
                       save_step=config.save_step)

    if config.use_tensorboard: writer.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Training configuration
    parser.add_argument('--dataset', type=str, default='publaynet')
    # parser.add_argument('--vg_dir', type=str, default='datasets/vg')
    parser.add_argument('--coco_dir', type=str, default='/home/tahani/PycharmProjects/synth_doc_layout/layout2im/datasets/annotations/')
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--niter', type=int, default=16000, help='number of training iterations')
    parser.add_argument('--image_size', type=int, default=128, help='image size')
    parser.add_argument('--object_size', type=int, default=64, help='object size')
    parser.add_argument('--embedding_dim', type=int, default=64)
    parser.add_argument('--z_dim', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--resi_num', type=int, default=6)
    parser.add_argument('--clstm_layers', type=int, default=3)

    # Loss weights
    parser.add_argument('--lambda_img_adv', type=float, default=1.0, help='weight of adv img')
    parser.add_argument('--lambda_obj_adv', type=float, default=1.0, help='weight of adv obj')
    parser.add_argument('--lambda_obj_cls', type=float, default=1.0, help='weight of aux obj')
    parser.add_argument('--lambda_z_rec', type=float, default=10.0, help='weight of z rec')
    parser.add_argument('--lambda_img_rec', type=float, default=1.0, help='weight of image rec')
    parser.add_argument('--lambda_kl', type=float, default=0.01, help='weight of kl')

    # Log settings
    parser.add_argument('--resume_iter', type=str, default='l',
                        help='l: from latest; s: from scratch; xxx: from iteration xxx')
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--tensorboard_step', type=int, default=100)
    parser.add_argument('--save_step', type=int, default=1000)
    parser.add_argument('--use_tensorboard', type=str2bool, default='true')

    config = parser.parse_args()
    config.exp_name = 'layout2im_{}'.format(config.dataset)
    print(config)
    main(config)

Thank you in advance

The model state that you are loading and the target model are not identical. Most likely your save_model function is saving the model incorrectly. Go through the save_model function to see if there is something finicky happening there.
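For reference, here is a minimal sketch of how to pin down exactly which parameters disagree and, if appropriate, load only the compatible ones (assuming netG is the Generator already built for the new dataset, and reusing the iter-300000_netG.pkl path from your post):

import torch

# Compare every parameter in the checkpoint against the freshly built model.
ckpt = torch.load('/home/user/PycharmProjects/synth_doc_layout/layout2im/checkpoints/pretrained/iter-300000_netG.pkl', map_location='cpu')
model_state = netG.state_dict()

for name, saved in ckpt.items():
    if name in model_state and model_state[name].shape != saved.shape:
        print('{}: checkpoint {} vs model {}'.format(name, tuple(saved.shape), tuple(model_state[name].shape)))

# Workaround for fine-tuning across different vocabularies: keep only the
# parameters whose shapes match, so the new class-embedding rows train from
# their fresh initialization.
compatible = {k: v for k, v in ckpt.items()
              if k in model_state and model_state[k].shape == v.shape}
model_state.update(compatible)
netG.load_state_dict(model_state)

Note that load_state_dict(..., strict=False) does not help here: strict=False only ignores missing or unexpected keys, while shape mismatches still raise a RuntimeError.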

Thank you for your reply. Actually, I trained the model with this same code; after training finished, I loaded the checkpoint to train it on the other dataset (the smaller one). This is my code for saving and loading the trained weights, which I reloaded after training:

def prepare_dir(name):
    log_save_dir = 'checkpoints/{}/logs'.format(name)
    model_save_dir = 'checkpoints/{}/models'.format(name)
    sample_save_dir = 'checkpoints/{}/samples'.format(name)
    result_save_dir = 'checkpoints/{}/results'.format(name)

    if not Path(log_save_dir).exists(): Path(log_save_dir).mkdir(parents=True)
    if not Path(model_save_dir).exists(): Path(model_save_dir).mkdir(parents=True)
    if not Path(sample_save_dir).exists(): Path(sample_save_dir).mkdir(parents=True)
    if not Path(result_save_dir).exists(): Path(result_save_dir).mkdir(parents=True)

    return log_save_dir, model_save_dir, sample_save_dir, result_save_dir

def load_model(model, model_dir=None, appendix=None, iter='l'):

    load_iter = None
    load_model = None

    if iter == 's' or not os.path.isdir(model_dir) or len(os.listdir(model_dir)) == 0:
        load_iter = 0
        if not os.path.isdir(model_dir):
            print('models dir does not exist')
        elif len(os.listdir(model_dir)) == 0:
            print('models dir is empty')

        print('train from scratch.')
        return load_iter

    # load latest iteration
    if iter == 'l':
        for file in os.listdir(model_dir):
            if appendix is not None and appendix not in file:
                continue

            if file.endswith('.pkl'):
                current_iter = re.search(r'iter-\d+', file).group(0).split('-')[1]

                if len(current_iter) > 0:
                    current_iter = int(current_iter)

                    if load_iter is None or current_iter > load_iter:
                        load_iter = current_iter
                        load_model = os.path.join(model_dir, file)
                else:
                    continue

        print('load from iter: %d' % load_iter)
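        # NOTE: the two lines below freeze every parameter after loading; for
        # fine-tuning you would normally keep requires_grad=True (or re-enable
        # it after loading the weights).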
        for param in model.parameters():
            param.requires_grad = False
        model.load_state_dict(torch.load(load_model))

        return load_iter
    # from given iter
    else:
        iter = int(iter)
        for file in os.listdir(model_dir):
            if file.endswith('.pkl'):
                current_iter = re.search(r'iter-\d+', file).group(0).split('-')[1]
                if len(current_iter) > 0:
                    if int(current_iter) == iter:
                        load_iter = iter
                        load_model = os.path.join(model_dir, file)
                        break
        if load_model:
            model.load_state_dict(torch.load(load_model))
            print('load from iter: %d' % load_iter)
        else:
            load_iter = 0
            print('there is no saved model for iter %d' % iter)
            print('train from scratch.')
        return load_iter

def save_model(model, model_dir=None, appendix=None, iter=1, save_num=5, save_step=1000):
    iter_idx = range(iter, iter - save_num * save_step, -save_step)

    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    for file in os.listdir(model_dir):
        if file.endswith('.pkl'):
            current_iter = re.search(r'iter-\d+', file).group(0).split('-')[1]
            if len(current_iter) > 0:
                if int(current_iter) not in iter_idx:
                    os.remove(os.path.join(model_dir, file))
            else:
                continue

    if appendix:
        model_name = os.path.join(model_dir, 'iter-%d_%s.pkl' % (iter, appendix))
    else:
        model_name = os.path.join(model_dir, 'iter-%d.pkl' % iter)
    torch.save(model.state_dict(), model_name)
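
For what it's worth, here is a quick way to check what save_model actually wrote (a minimal sketch; the checkpoint path is the one from your training script, and the key name comes from the error message above):

import torch

# Load the raw state_dict written by save_model and inspect the class
# embedding that the error message complains about.
state = torch.load('/home/user/PycharmProjects/synth_doc_layout/layout2im/checkpoints/pretrained/iter-300000_netG.pkl', map_location='cpu')
print(state['layout_encoder.embedding.weight'].shape)
# torch.Size([6, 64]) here would mean the weights were saved faithfully and the
# checkpoint was simply trained with num_objects = 6, while the new dataset's
# loader reports num_objects = 7 (see vocab_num = data_loader.dataset.num_objects).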