I want to load pre-trained weights into a SwinV2 transformer and then transfer those weights into another Swin model with a modified architecture (here, a smaller mlp_ratio), as below:
import timm
import torch

n_class = 8
img_size = 224
mlp_ratio = 2
depths = [2, 2, 6, 2]
num_heads = [3, 6, 12, 24]
window_size = 7
qkv_bias = True
drop_rate = 0.2
attn_drop_rate = 0
drop_path_rate = 0.1

model_arch = 'timm/swinv2_cr_tiny_ns_224.sw_in1k'
pre_model = timm.create_model(model_arch, num_classes=n_class, pretrained=True,
                              image_size=img_size, drop_rate=drop_rate,
                              attn_drop_rate=attn_drop_rate,
                              drop_path_rate=drop_path_rate)
model = timm.create_model(model_arch, num_classes=n_class, pretrained=False,
                          image_size=img_size, mlp_ratio=mlp_ratio,
                          depths=depths, num_heads=num_heads,
                          window_size=window_size, qkv_bias=qkv_bias,
                          drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                          drop_path_rate=drop_path_rate)
model.load_state_dict(pre_model.state_dict(), strict=False)
But load_state_dict raises a size-mismatch RuntimeError:
Cell In[262], line 1
model.load_state_dict(pre_model.state_dict(),strict=False)
File ~\anaconda3\lib\site-packages\torch\nn\modules\module.py:2041 in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SwinTransformerV2Cr:
size mismatch for stages.0.blocks.0.mlp.fc1.weight: copying a param with shape torch.Size([384, 96]) from checkpoint, the shape in current model is torch.Size([192, 96]).
size mismatch for stages.0.blocks.0.mlp.fc1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
size mismatch for stages.0.blocks.0.mlp.fc2.weight: copying a param with shape torch.Size([96, 384]) from checkpoint, the shape in current model is torch.Size([96, 192]).
size mismatch for stages.0.blocks.1.mlp.fc1.weight: copying a param with shape torch.Size([384, 96]) from checkpoint, the shape in current model is torch.Size([192, 96]).
size mismatch for stages.0.blocks.1.mlp.fc1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
size mismatch for stages.0.blocks.1.mlp.fc2.weight: copying a param with shape torch.Size([96, 384]) from checkpoint, the shape in current model is torch.Size([96, 192]).
size mismatch for stages.1.blocks.0.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([384, 192]).
size mismatch for stages.1.blocks.0.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
size mismatch for stages.1.blocks.0.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([192, 384]).
size mismatch for stages.1.blocks.1.mlp.fc1.weight: copying a param with shape torch.Size([768, 192]) from checkpoint, the shape in current model is torch.Size([384, 192]).
size mismatch for stages.1.blocks.1.mlp.fc1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
size mismatch for stages.1.blocks.1.mlp.fc2.weight: copying a param with shape torch.Size([192, 768]) from checkpoint, the shape in current model is torch.Size([192, 384]).
size mismatch for stages.2.blocks.0.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 384]).
size mismatch for stages.2.blocks.0.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for stages.2.blocks.0.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([384, 768]).
size mismatch for stages.2.blocks.1.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 384]).
size mismatch for stages.2.blocks.1.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for stages.2.blocks.1.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([384, 768]).
size mismatch for stages.2.blocks.2.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 384]).
size mismatch for stages.2.blocks.2.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for stages.2.blocks.2.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([384, 768]).
size mismatch for stages.2.blocks.3.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 384]).
size mismatch for stages.2.blocks.3.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for stages.2.blocks.3.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([384, 768]).
size mismatch for stages.2.blocks.4.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 384]).
size mismatch for stages.2.blocks.4.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for stages.2.blocks.4.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([384, 768]).
size mismatch for stages.2.blocks.5.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([768, 384]).
size mismatch for stages.2.blocks.5.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for stages.2.blocks.5.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([384, 768]).
size mismatch for stages.3.blocks.0.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 768]).
size mismatch for stages.3.blocks.0.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
size mismatch for stages.3.blocks.0.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([768, 1536]).
size mismatch for stages.3.blocks.1.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 768]).
size mismatch for stages.3.blocks.1.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
size mismatch for stages.3.blocks.1.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([768, 1536]).
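For context on the failure mode: strict=False only tolerates missing or unexpected keys; it does not skip parameters whose names match but whose shapes differ, which is exactly what happens to the MLP layers when mlp_ratio changes from the default 4 to 2. A common workaround is to filter the checkpoint down to shape-compatible entries before loading. A minimal sketch with toy modules standing in for the two Swin variants (the helper name load_matching_weights is mine, not a timm API):

```python
import torch
import torch.nn as nn

def load_matching_weights(model, src_state):
    """Copy only parameters whose name AND shape match the target model."""
    dst_state = model.state_dict()
    filtered = {k: v for k, v in src_state.items()
                if k in dst_state and v.shape == dst_state[k].shape}
    model.load_state_dict(filtered, strict=False)
    return filtered  # the entries that were actually transferred

# Toy stand-ins: same layer layout, different hidden width,
# mimicking an MLP built with mlp_ratio=4 vs mlp_ratio=2.
big = nn.Sequential(nn.Linear(96, 384), nn.ReLU(), nn.Linear(384, 96))
small = nn.Sequential(nn.Linear(96, 192), nn.ReLU(), nn.Linear(192, 96))

copied = load_matching_weights(small, big.state_dict())
print(sorted(copied))  # → ['2.bias'] — only the shape-compatible entry survives
```

Applied to the real models, load_matching_weights(model, pre_model.state_dict()) would transfer the attention, norm, and embedding weights and leave the resized mlp.fc1/fc2 layers randomly initialized.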
Also, can I implement the same thing with the SwinV2 models from torchvision?