Error loading Vision Transformer weights into a model

I’m trying to load the .pth file of a pre-trained vision transformer (ViTSTR) into a model. The model has the same architecture as the one used for the vision transformer.
I keep getting these errors:
Missing key(s) in state_dict: "module.FeatureExtraction.ConvNet.0.weight", "module.FeatureExtraction.ConvNet.0.bias", "module.FeatureExtraction.ConvNet.3.weight", "module.FeatureExtraction.ConvNet.3.bias", "module.FeatureExtraction.ConvNet.6.weight"…
Unexpected key(s) in state_dict: "module.vitstr.cls_token", "module.vitstr.pos_embed", "module.vitstr.patch_embed.proj.weight", "module.vitstr.patch_embed.proj.bias", "module.vitstr.blocks.0.norm1.weight", "module.vitstr.blocks.0.norm1.bias", "module.vitstr.blocks.0.attn.qkv.weight", "module.vitstr.blocks.0.attn.qkv.bias", "module.vitstr.blocks.0.attn.proj.weight", …

My code is:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model(opt)
file_path = '/content/drive/MyDrive/VitSTR/vitstr_tiny_patch16_224_aug.pth'
model = torch.nn.DataParallel(model).to(device)
model.load_state_dict(torch.load(file_path))

The error is raised by the load_state_dict call.

Try loading the weights before wrapping the model in a DataParallel object:

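import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model(opt)  # Model and opt come from your own project code (not shown)
file_path = '/content/drive/MyDrive/VitSTR/vitstr_tiny_patch16_224_aug.pth'
# Load the weights into the plain model first (map_location lets a GPU checkpoint load on CPU)...
model.load_state_dict(torch.load(file_path, map_location=device))
# ...then wrap it in DataParallel and move it to the device.
model = torch.nn.DataParallel(model).to(device)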

This is because once you wrap the model in DataParallel, your original Model object is only reachable as model.module, so every key in the wrapper's state_dict gains a "module." prefix.
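The effect of that prefix can be reproduced with a small, hypothetical stand-in model (TinyModel below is not the poster's Model(opt)). The sketch assumes the checkpoint was saved from a DataParallel-wrapped model, as the "module." prefix in the error messages suggests. strip_prefix and diff_keys are illustrative helpers written for this example, not PyTorch APIs: the first removes the wrapper prefix, the second lists which keys still disagree.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # Hypothetical stand-in for Model(opt); any nn.Module behaves the same way.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def strip_prefix(state_dict, prefix="module."):
    # Drop the wrapper prefix from every key that carries it; leave other keys untouched.
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def diff_keys(model, state_dict):
    # Keys the model expects but the checkpoint lacks, and keys the checkpoint has
    # that the model does not expect (the same two lists load_state_dict reports).
    expected = set(model.state_dict().keys())
    provided = set(state_dict.keys())
    return sorted(expected - provided), sorted(provided - expected)


# Wrapping adds "module." to every key in the state_dict.
wrapped = nn.DataParallel(TinyModel())
checkpoint = wrapped.state_dict()  # stands in for torch.load(file_path)
print(list(checkpoint.keys()))  # ['module.fc.weight', 'module.fc.bias']

# Loading into an unwrapped model: strip the prefix first, then check for leftovers.
model = TinyModel()
cleaned = strip_prefix(checkpoint)
missing, unexpected = diff_keys(model, cleaned)
print(missing, unexpected)  # [] []
model.load_state_dict(cleaned)  # strict=True by default
```

With the real file, torch.load(file_path, map_location='cpu') would take the place of wrapped.state_dict(). If diff_keys still reports keys after the prefix is stripped, the prefix is not the whole story: in the error above, the names after "module." also differ (FeatureExtraction.ConvNet… versus vitstr…), which may indicate that Model(opt) and the checkpoint were not built with the same feature extractor.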

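Also, remember that PyTorch recommends using DistributedDataParallel instead of DataParallel whenever possible, even for multi-GPU training on a single machine, because it runs one process per GPU instead of splitting work across threads in a single process.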

This is a guide that I found really useful for learning how to use DistributedDataParallel for multi-GPU/multi-node training.
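For a quick look at the API shape, here is a minimal single-process sketch, assuming a CPU-only machine and the gloo backend with a file-based rendezvous; it is not a multi-GPU training recipe. The nn.Linear model is a hypothetical stand-in for Model(opt). It also shows that DistributedDataParallel adds the same "module." prefix, and that saving ddp_model.module.state_dict() keeps checkpoint keys prefix-free.

```python
import os
import tempfile

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    # Single-process, CPU-only setup with the gloo backend, only to show the API shape.
    # A real multi-GPU job starts one process per GPU (for example with torchrun)
    # and passes device_ids=[local_rank] when wrapping the model.
    init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
    dist.init_process_group(
        backend="gloo", init_method=f"file://{init_file}", rank=0, world_size=1
    )
    try:
        model = nn.Linear(4, 2)  # hypothetical stand-in for Model(opt)
        ddp_model = DDP(model)

        # Like DataParallel, DDP prefixes the wrapped module's keys with "module."
        wrapped_keys = list(ddp_model.state_dict().keys())

        # Saving the inner module's state_dict keeps the keys prefix-free,
        # so the file can later be loaded straight into an unwrapped model.
        plain_keys = list(ddp_model.module.state_dict().keys())
        return wrapped_keys, plain_keys
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    print(main())  # (['module.weight', 'module.bias'], ['weight', 'bias'])
```

Saving from ddp_model.module (or model.module with DataParallel) is one way to avoid the prefix mismatch discussed above in the first place.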
