Loading state to a quantized Inception_v3 model

Hello,

I’ve read the excellent static quantization tutorial, which worked really well with my MobileNet_v2 pretrained weights
https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html?highlight=quantization

Now I want to do the same with the Inception_v3 model. The problem is that when I try to load the state dictionary, I get the following error (full trace below). Why does this happen, and how do I fix it? Do I need to create a new model class as they did in the tutorial?

per_channel_quantized_model = torchvision.models.quantization.inception_v3(pretrained=True, aux_logits=False, quantize=True)
per_channel_quantized_model.fc = nn.Linear(2048, 2)
state_dict = torch.load(float_model_file, map_location=torch.device('cpu'))
per_channel_quantized_model.load_state_dict(state_dict)

File "", line 5, in <module>
per_channel_quantized_model.load_state_dict(state_dict)
File "/home/nimrod/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1030, in load_state_dict
load(self)
File "/home/nimrod/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1028, in load
load(child, prefix + name + '.')
File "/home/nimrod/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1028, in load
load(child, prefix + name + '.')
File "/home/nimrod/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1025, in load
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
File "/home/nimrod/anaconda3/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py", line 120, in _load_from_state_dict
state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
KeyError: 'Conv2d_1a_3x3.conv.bias'

Eager mode quantization often changes the module hierarchy of the model, so it’s possible that the hierarchy of your quantized model no longer matches your fp32 state_dict. What is the origin of the fp32 state_dict? Is it also from torchvision?

In practice, people usually fix this either by loading the weights before fusing the model, or by writing a custom state_dict mapper that renames the keys according to the module hierarchy changes.
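
A minimal sketch of the first option, assuming the torchvision quantizable Inception_v3 with quantize=False as the starting point (float_model_file is the checkpoint from the question above):

import torch
import torch.nn as nn
import torchvision

# quantizable model, but not yet fused or quantized
model = torchvision.models.quantization.inception_v3(
    pretrained=False, aux_logits=False, quantize=False)
model.fc = nn.Linear(2048, 2)

# the module hierarchy still matches the fp32 model at this point,
# so the fp32 checkpoint can be loaded directly
state_dict = torch.load(float_model_file, map_location='cpu')
model.load_state_dict(state_dict)

# only fuse / prepare / convert after the weights are loaded

If a few key names still differ between the two models, a small remapping (like the one further down) or strict=False may be needed.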

Yes, the original state_dict is an fp32 state_dict from a torchvision model.

I loaded the state_dict before fusing the model; I only fuse it afterward. So is writing a state_dict mapper (or creating the model class “from scratch” as they did in the tutorial) the only solution?
Can you give an example of how to create such a mapper?

per_channel_quantized_model = torchvision.models.quantization.inception_v3(pretrained=True, aux_logits=False, quantize=True)

This line should already load a pretrained model. To clarify, are you loading the weights again to populate the new fc layer, or for something else?

It loads ImageNet weights, but I need to load my own weights for the whole network, not just the fc layer. You can ignore the pretrained value.

Makes sense. The torchvision.models.quantization.inception_v3(pretrained=True, aux_logits=False, quantize=True) line is torchvision’s best effort to provide a pretrained model ready for quantization for use cases where the default fp32 pretrained weights are fine. Unfortunately, if you need to load a different version of floating point weights, a mapping of the state dict is required.

Here is a code snippet which does this for an unrelated model, but the principle is the same:

def get_new_bn_key(old_bn_key):
    # tries to adjust the key for conv-bn fusion, where
    # root
    #   - conv
    #   - bn
    #
    # becomes
    #
    # root
    #   - conv
    #     - bn
    return old_bn_key.replace(".bn.", ".conv.bn.")

non_qat_to_qat_state_dict_map = {}
for key in original_state_dict.keys():
    if key in new_state_dict.keys():
        non_qat_to_qat_state_dict_map[key] = key
    else:
        maybe_new_bn_key = get_new_bn_key(key)
        if maybe_new_bn_key in new_state_dict.keys():
            non_qat_to_qat_state_dict_map[key] = maybe_new_bn_key
...
# when loading the state dict, use the mapping created above
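
To make that last step concrete, here is a minimal sketch of applying the mapping, assuming prepared_model (a placeholder name) is the prepared, not-yet-converted, still floating point model whose state_dict produced new_state_dict above:

# remap the fp32 keys to the names used by the prepared model
remapped_state_dict = {
    new_key: original_state_dict[old_key]
    for old_key, new_key in non_qat_to_qat_state_dict_map.items()
}

# strict=False because the prepared model usually has extra entries
# (observer / fake-quant state) that the fp32 checkpoint does not contain
missing, unexpected = prepared_model.load_state_dict(remapped_state_dict, strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)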

We are planning to release a tool soon (hopefully v1.8) to automate all of this, so this should get easier in the near future.

Yeah, I get the idea. You copy a key as-is if it has the same name, or rename it and then add it to the new map. Afterward, you load the weights into the quantized model according to the new mapping.

If I look at the state_dicts of Inception and quantized Inception (before fusing, of course), they have completely different names and also different lengths. How can I tackle that?

Here is a high-level view of the differences:

# fp32
# torchvision.models.inception_v3(pretrained=False, aux_logits=False)
# state_dict keys of Conv2d_1a_3x3
# conv
Conv2d_1a_3x3.conv.weight           
# bn
Conv2d_1a_3x3.bn.weight             
Conv2d_1a_3x3.bn.bias               
Conv2d_1a_3x3.bn.running_mean       
Conv2d_1a_3x3.bn.running_var        
Conv2d_1a_3x3.bn.num_batches_tracked

# ready for quantization but not quantized
# mq = torchvision.models.quantization.inception_v3(pretrained=False, aux_logits=False, quantize=False)
# state_dict keys of Conv2d_1a_3x3
# conv
Conv2d_1a_3x3.conv.weight   
# bn        
Conv2d_1a_3x3.bn.weight             
Conv2d_1a_3x3.bn.bias               
Conv2d_1a_3x3.bn.running_mean       
Conv2d_1a_3x3.bn.running_var        
Conv2d_1a_3x3.bn.num_batches_tracked

# quantized and fused
# mq = torchvision.models.quantization.inception_v3(pretrained=False, aux_logits=False, quantize=True)
# state_dict keys of Conv2d_1a_3x3
# conv, including quantization-specific scale+zp
Conv2d_1a_3x3.conv.weight    
Conv2d_1a_3x3.conv.bias      
Conv2d_1a_3x3.conv.scale     
Conv2d_1a_3x3.conv.zero_point
# no bn, it was fused into the conv

You could use the ready-for-quantization but not-yet-quantized model if you are running quantization yourself; its state_dict will be closer to the fp32 version since it is before fusion. Then you’d have to go block by block and see if there are additional differences.
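
For reference, a rough sketch of the remaining post-training static quantization steps, following the eager-mode recipe from the tutorial linked above; it assumes model is the quantize=False model with your fp32 weights already loaded (as in the earlier sketch), and calibration_loader is a placeholder for your calibration data:

import torch

# post-training static quantization: fuse, prepare, calibrate, convert
model.eval()
model.fuse_model()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)

# run a few calibration batches through the observers
with torch.no_grad():
    for images, _ in calibration_loader:  # placeholder calibration data
        model(images)

torch.quantization.convert(model, inplace=True)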

Which names in particular are different? They should be pretty similar.

The state_dicts actually have the same lengths; that was my mistake. It seemed quite odd and confusing, but the lengths make sense now. The naming differences shouldn’t be much of a problem.
I’ll deal with that and then perform the fusion.

It would have been great if I could perform QAT using CUDA instead of dealing with the conversion; hopefully CUDA support will be available soon.
So far, PTQ with optimized calibration works really well with my data and I get almost no drop in AP :), but in slightly different cases QAT might be useful to avoid a drop in AP.

Thanks.

Calling prepare and running the QAT fine-tuning is supported on CUDA. The only thing not supported is calling convert and running the quantized kernels. Not sure if that is what you were referring to.
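
A rough sketch of that split, assuming the eager-mode QAT recipe from the tutorial linked above; train_one_epoch and train_loader are placeholders for your own fine-tuning code:

import torch
import torch.nn as nn
import torchvision

model = torchvision.models.quantization.inception_v3(
    pretrained=True, aux_logits=False, quantize=False)
model.fc = nn.Linear(2048, 2)
# (load your own fp32 state_dict here, before fusing, if needed)

# QAT preparation and fine-tuning can run on CUDA
model.train()
model.fuse_model()  # fuse following the tutorial's QAT recipe
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)
model.cuda()

num_epochs = 3  # placeholder number of fine-tuning epochs
for epoch in range(num_epochs):
    train_one_epoch(model, train_loader)  # placeholder training function and loader

# convert and quantized inference must happen on CPU
model.cpu().eval()
quantized_model = torch.quantization.convert(model)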

Good to know. I thought QAT also lacked CUDA support. Thanks.