Couldn't load MobileViT model (https://github.com/apple/ml-cvnets/blob/main/examples/README-mobilevit.md)

Hi,

I am new to PyTorch.

I want to load the MobileViT model, but I got an unexpected result. I downloaded the model weights, but when I load the file I only get the weights. I want to see the layers so that I can do transfer learning.

import torch

new_net = torch.load("mobilevit_s.pt")
print(new_net)  # shows only the stored weights, not the layers

Thank you for taking time to read my question!

As you’ve described, the checkpoint contains the model parameters and buffers only.
To see the layers, you would need to create the model object and print it.
Similar to:

model = MyModel(args)
print(model)

where MyModel is the actual MobileViT definition.
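To make the distinction concrete, here is a minimal sketch that uses a small hypothetical `TinyNet` module as a stand-in for the MobileViT definition, which is not shown in this thread. Saving `model.state_dict()` stores only a mapping from parameter names to tensors, so `torch.load` returns that mapping. The layers only appear once the model class is instantiated and the weights are loaded into it.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real MobileViT class."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.stem(x).mean(dim=(2, 3))  # global average pooling over H and W
        return self.classifier(x)


# Save only the parameters and buffers, like a plain weights file
buffer = io.BytesIO()
torch.save(TinyNet().state_dict(), buffer)
buffer.seek(0)

# torch.load gives back a dict of tensors, not an nn.Module
loaded = torch.load(buffer, map_location="cpu")
print(type(loaded).__name__)  # OrderedDict
print(list(loaded.keys())[:3])

# To see (and use) the layers, build the model object first, then load the weights into it
model = TinyNet()
model.load_state_dict(loaded)
print(model)
```

The same two steps (instantiate, then `load_state_dict`) apply to the real MobileViT class once its definition is importable.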

Thank you @ptrblck for your response.

The solution you provided works well, and I can now see my model’s architecture and layers. Now I want to do transfer learning, so I need to load the model’s weights from “mobilevit_s.pt”, but I am getting an error. Here is my code:

net = MobileViT_S()
net.load_state_dict(torch.load(MODELS_PATH + "MobileViT_S_model_best.pth.tar", map_location=torch.device('cpu')))
print(net)

I got the following error from the load_state_dict function:

RuntimeError: Error(s) in loading state_dict for MobileViT:
	Missing key(s) in state_dict: "stem.0.weight", "stem.0.bias", "stem.1.conv.0.conv.weight", "stem.1.conv.0.norm_layer.weight", "stem.1.conv.0.norm_layer.bias", "stem.1.conv.0.norm_layer.running_mean", "stem.1.conv.0.norm_layer.running_var", "stem.1.conv.1.conv.weight", "stem.1.conv.1.norm_layer.weight", "stem.1.conv.1.norm_layer.bias", "stem.1.conv.1.norm_layer.running_mean", "stem.1.conv.1.norm_layer.running_var", "stem.1.conv.2.weight", "stem.1.conv.3.weight", "stem.1.conv.3.bias", "stem.1.conv.3.running_mean", "stem.1.conv.3.running_var", "stage1.0.conv.0.conv.weight", "stage1.0.conv.0.norm_layer.weight", "stage1.0.conv.0.norm_layer.bias", "stage1.0.conv.0.norm_layer.running_mean", "stage1.0.conv.0.norm_layer.running_var", "stage1.0.conv.1.conv.weight", "stage1.0.conv.1.norm_layer.weight", "stage1.0.conv.1.norm_layer.bias", "stage1.0.conv.1.norm_layer.running_mean", "stage1.0.conv.1.norm_layer.running_var", "stage1.0.conv.2.weight", "stage1.0.conv.3.weight", "stage1.0.conv.3.bias", "stage1.0.conv.3.running_mean", "stage1.0.conv.3.running_var", "stage1.1.conv.0.conv.weight", "stage1.1.conv.0.norm_layer.weight", "stage1.1.conv.0.norm_layer.bias", "stage1.1.conv.0.norm_layer.running_mean", "stage1.1.conv.0.norm_layer.running_var", "stage1.1.conv.1.conv.weight", "stage1.1.conv.1.norm_layer.weight", "stage1.1.conv.1.norm_layer.bias", "stage1.1.conv.1.norm_layer.running_mean", "stage1.1.conv.1.norm_layer.running_var", "stage1.1.conv.2.weight", "stage1.1.conv.3.weight", "stage...
	Unexpected key(s) in state_dict: "epoch", "state_dict", "best_acc1", "optimizer".

Your checkpoint seems to contain multiple objects, one of which seems to be the actual state_dict.
I guess:

net = MobileViT_S()
net.load_state_dict(torch.load(MODELS_PATH + "MobileViT_S_model_best.pth.tar", map_location=torch.device('cpu'))['state_dict'])
print(net)

should work (i.e. access the loaded checkpoint via ['state_dict']).
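The error message already hints at the layout: a training checkpoint like this one is an ordinary dict whose top-level keys include "epoch", "state_dict", "best_acc1" and "optimizer". The sketch below reproduces that situation with a hypothetical tiny model and placeholder epoch/accuracy values (nothing here comes from the real checkpoint): passing the whole dict to `load_state_dict` fails with unexpected keys, while indexing `['state_dict']` first works.

```python
import io

import torch
import torch.nn as nn

# Hypothetical tiny model standing in for MobileViT_S
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# A checkpoint bundling several objects, mirroring the keys in the error message
# (the epoch and best_acc1 values are placeholders)
buffer = io.BytesIO()
torch.save(
    {
        "epoch": 1,
        "state_dict": model.state_dict(),
        "best_acc1": 0.0,
        "optimizer": optimizer.state_dict(),
    },
    buffer,
)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location=torch.device("cpu"))
print(sorted(checkpoint.keys()))

new_model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

# Loading the whole checkpoint fails: its top-level keys are not layer names
try:
    new_model.load_state_dict(checkpoint)
except RuntimeError as err:
    print("Failed as expected:", str(err).splitlines()[0])

# Loading only the nested state_dict succeeds
result = new_model.load_state_dict(checkpoint["state_dict"])
print(result)  # <All keys matched successfully>
```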


I printed [‘state_dict’] and found that every layer name starts with a ‘module.’ prefix. So I created a function that removes this ‘module.’ prefix from every key and passes the resulting state_dict to my model. Here is the solution:

import torch

def load_mobilevit_weights(model_path):
    # Create an instance of the MobileViT model
    net = MobileViT_S()

    # Load the checkpoint and extract the MobileViT state_dict from it
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))['state_dict']

    # The layer names carry a 'module.' prefix, so strip it from the start of each key only
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

    # Once the keys are fixed, we can pass the state_dict to our MobileViT model
    net.load_state_dict(state_dict)

    return net

net = load_mobilevit_weights(MODELS_PATH + "MobileViT_S_model_best.pth.tar")
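For context, a ‘module.’ prefix usually means the weights were saved from a model wrapped in nn.DataParallel (or DistributedDataParallel), which keeps the original network under a `module` attribute. That is an inference; the checkpoint itself does not record how it was trained. The sketch below reproduces the prefix with a hypothetical tiny model and a hypothetical helper `strip_prefix` that removes only a leading ‘module.’, so a layer name containing ‘module.’ further along would be left untouched (unlike a plain `str.replace`).

```python
import torch
import torch.nn as nn


def strip_prefix(state_dict, prefix="module."):
    """Hypothetical helper: return a copy of state_dict with a leading prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Hypothetical tiny model standing in for MobileViT_S
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))

# Wrapping in DataParallel nests the model under `.module`, which prefixes every key
wrapped = nn.DataParallel(model)
wrapped_keys = list(wrapped.state_dict().keys())
print(wrapped_keys[:2])  # ['module.0.weight', 'module.0.bias']

# After stripping the prefix, the keys match a plain, unwrapped model again
fresh = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
fresh.load_state_dict(strip_prefix(wrapped.state_dict()))
```

Recent PyTorch releases also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix)`, which performs the same stripping in place; check that it exists in your installed version before relying on it.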

Thank you so much for your help!