Quantization of multi_head_attention_forward

Hi Team,

Could someone help me with quantization of multi-head attention layers in PyTorch?

I am new to PyTorch and have been experimenting with quantization of OpenAI’s CLIP model. Specifically, I’m trying to quantize the (modified) ResNet encoder of CLIP, which consists of CNN blocks followed by a final attention-pooling layer that calls F.multi_head_attention_forward.
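
For context, that attention-pooling head looks roughly like the sketch below (a simplified version of CLIP's AttentionPool2d, with the output dimension kept equal to embed_dim; the q_proj/k_proj/v_proj/c_proj names match the parameters listed further down). Because the projections enter through a functional call with explicit weight tensors rather than as nn.Linear submodules, the default FX quantization patterns appear not to cover them, which would explain the float32 parameters reported below.

import torch
import torch.nn.functional as F
from torch import nn

class AttentionPool2dSketch(nn.Module):
    def __init__(self, spacial_dim, embed_dim, num_heads):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        # NCHW -> (HW, N, C), prepend a mean "token", add positional embedding
        x = x.flatten(start_dim=2).permute(2, 0, 1)
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
        x = x + self.positional_embedding[:, None, :]
        # single functional attention call with separate projection weights
        out, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,
            embed_dim_to_check=x.shape[-1], num_heads=self.num_heads,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None, bias_v=None, add_zero_attn=False, dropout_p=0.0,
            out_proj_weight=self.c_proj.weight, out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            training=self.training, need_weights=False)
        return out.squeeze(0)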

I’m using FX Graph Mode Quantization to quantize the RN50 model as follows:

import torch
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

RN50.eval()
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
example_inputs = torch.rand(1, 3, 224, 224)  # get an example input
prepared_RN50 = prepare_fx(RN50, qconfig_mapping, example_inputs)  # fuse modules and insert observers
calibrate(prepared_RN50)  # run calibration on sample data
quantized_RN50 = convert_fx(prepared_RN50)  # convert the calibrated model to a quantized model
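
calibrate is not defined in the snippet above; a minimal sketch of such a helper (my own version, assuming some iterable of sample batches named calibration_loader):

def calibrate(model):
    # run a handful of representative batches through the observed model
    # so that the inserted observers can record activation ranges
    model.eval()
    with torch.no_grad():
        for images in calibration_loader:   # hypothetical calibration data source
            model(images)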

While the quantized model runs as expected, the model size does not decrease as much as expected. Checking further, we see that the parameters of the attention layer are still in the float32 datatype.

for name, param in quantized_RN50.named_parameters():
    print(name, param.dtype)

attnpool.positional_embedding torch.float32
attnpool.q_proj.weight torch.float32
attnpool.q_proj.bias torch.float32
attnpool.k_proj.weight torch.float32
attnpool.k_proj.bias torch.float32
attnpool.v_proj.weight torch.float32
attnpool.v_proj.bias torch.float32
attnpool.c_proj.weight torch.float32
attnpool.c_proj.bias torch.float32
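
For comparison, walking the converted module tree (my own check, not from the original post) makes it easy to see which submodules were actually swapped for quantized modules and which, like attnpool, stayed float:

for name, module in quantized_RN50.named_modules():
    # converted modules come from the torch.ao.nn.quantized / torch.ao.nn.intrinsic.quantized namespaces
    print(name, type(module).__module__, type(module).__name__)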

In the document below, nn.MultiheadAttention is listed both as supported through custom modules (under the section Quantization API Summary) and as not supported for static quantization (under the section Operator Support). However, I haven’t come across any documentation or examples regarding quantization of the MultiheadAttention layer. It would be great if someone could point me to one.
https://pytorch.org/docs/stable/quantization.html

First, do you need dynamic quantization or static quantization?
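
(For context on that question: the dynamic path would look roughly like the sketch below, which stores the weights of the listed module types as int8 and quantizes activations on the fly at runtime; whether it reaches a given attention implementation depends on how its projections are defined. The rest of this thread assumes static quantization.)

import torch

model_dynamic = torch.ao.quantization.quantize_dynamic(
    model,               # a float model
    {torch.nn.Linear},   # module types to quantize dynamically
    dtype=torch.qint8,
)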

For static quantization with FX graph mode quantization, here is an example for LSTM: https://github.com/pytorch/pytorch/blob/main/test/quantization/fx/test_quantize_fx.py#L4763-L4780. You can probably do something similar with MultiheadAttention (replace LSTM with MultiheadAttention).
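
The gist of that linked LSTM test, paraphrased as a sketch (not the exact test code), is to route the module through its quantizable (observed) and quantized counterparts via custom configs:

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig, ConvertCustomConfig

qconfig_mapping = get_default_qconfig_mapping()
prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(
    torch.nn.LSTM, torch.ao.nn.quantizable.LSTM)
convert_custom_config = ConvertCustomConfig().set_observed_to_quantized_mapping(
    torch.ao.nn.quantizable.LSTM, torch.ao.nn.quantized.LSTM)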

Hi Jerry,

Thanks for the response. I’m looking for static quantization. I tried the code below but got an error, AttributeError: 'tuple' object has no attribute 'numel', while passing inputs to the prepared model. I suspect this is because MultiheadAttention returns a tuple of output and weights. Would you be able to share a working example or point out what I need to change in the code below?

import torch
from torch import nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig, ConvertCustomConfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class m_AttentionPool2d(nn.Module):
    def __init__(self):
        super(m_AttentionPool2d, self).__init__()
        self.m_atten = nn.MultiheadAttention(embed_dim=2048, num_heads=32, bias=True, add_bias_kv=False)

    def forward(self, x):
        (op, _) = self.m_atten(query=x[:1], key=x, value=x, need_weights=False, average_attn_weights=False, is_causal=False)
        return op

model = m_AttentionPool2d()

# test model with random input
example_inputs = torch.rand(50, 20, 2048)
example_outputs = model(example_inputs)
print(example_outputs.shape)     # output - torch.Size([1, 20, 2048])

qconfig_mapping = get_default_qconfig_mapping()

prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(torch.nn.MultiheadAttention, torch.ao.nn.quantizable.MultiheadAttention)
convert_custom_config = ConvertCustomConfig().set_observed_to_quantized_mapping(torch.ao.nn.quantizable.MultiheadAttention, torch.ao.nn.quantized.MultiheadAttention)

example_inputs = torch.rand(4, 20, 2048)

# quantize model
model.eval()
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config)
prepared_model(example_inputs)   # AttributeError raised on this line during execution

quantized_model = convert_fx(prepared_model, convert_custom_config=convert_custom_config)
quantized_model(example_inputs)
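
For illustration, the same failure can be reproduced directly on an observer: the observers inserted by prepare_fx call .numel() on whatever they receive, and here they receive the (attn_output, attn_weights) tuple instead of a tensor.

import torch
from torch.ao.quantization.observer import HistogramObserver

obs = HistogramObserver()
obs(torch.rand(2, 2))           # fine: observers expect a tensor
obs((torch.rand(2, 2), None))   # AttributeError: 'tuple' object has no attribute 'numel'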

Actually, static quant for multihead attention is only available in eager mode quantization, I think, so you can do something like:

custom_module_config = {
    "float_to_observed_custom_module_class": {
        torch.nn.MultiheadAttention: torch.nn.quantizable.MultiheadAttention
    },
    "observed_to_quantized_custom_module_class": {
        torch.nn.quantizable.MultiheadAttention: torch.nn.quantizable.MultiheadAttention
    },
}
torch.ao.quantization.prepare(
    model, inplace=True, prepare_custom_config_dict=custom_module_config
)

torch.ao.quantization.convert(
    model,
    inplace=True,
    convert_custom_config_dict=custom_module_config,
)

You also need to set the qconfig properly for the torch.nn.MultiheadAttention module. Here is the general static quantization tutorial for eager mode quantization: "(beta) Static Quantization with Eager Mode in PyTorch" in the PyTorch tutorials.
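
A minimal sketch of the qconfig part, using the module names from the snippets above (setting it on the root module propagates to submodules during prepare(); it can also be scoped to the attention submodule only):

model.eval()
model.qconfig = torch.ao.quantization.get_default_qconfig("x86")
# or, to scope it to the attention layer only:
# model.m_atten.qconfig = torch.ao.quantization.get_default_qconfig("x86")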

Hi Jerry,

Thanks for getting back. I tried static eager quantization of MultiheadAttention, but I still find the weights in the float32 datatype after quantization. Moreover, the model size before and after quantization remains the same. While there are no errors, the model doesn’t seem to be quantized. Sharing the code below:

import os 
import torch 
from torch import nn, Tensor
from torch.ao.quantization import QuantStub, DeQuantStub

class m_AttentionPool2d(nn.Module):
    def __init__(self):
        super(m_AttentionPool2d, self).__init__()
        self.m_atten = nn.MultiheadAttention(embed_dim=2048, num_heads=32, bias=True, add_bias_kv=False)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        (op , _) = self.m_atten(query=x[:1], key=x, value=x, need_weights=False, average_attn_weights=False, is_causal=False)
        op = self.dequant(op) 
        return op
    
model = m_AttentionPool2d()

example_inputs = torch.rand(50,20,2048)
example_outputs = model(example_inputs)
print(example_outputs.shape)   # op - torch.Size([1, 20, 2048])

model.eval()

custom_module_config = {
        "float_to_observed_custom_module_class": {torch.nn.MultiheadAttention: torch.nn.quantizable.MultiheadAttention},
        "observed_to_quantized_custom_module_class": {torch.nn.quantizable.MultiheadAttention: torch.nn.quantizable.MultiheadAttention},
    }

model.qconfig = torch.ao.quantization.get_default_qconfig('x86')
model_prepared = torch.ao.quantization.prepare(model, prepare_custom_config_dict=custom_module_config)
model_prepared(example_inputs)
model_int8 = torch.ao.quantization.convert(model,convert_custom_config_dict=custom_module_config,)
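# note: the convert() call above is given `model` (the un-prepared float model) rather than `model_prepared`,
# so the observed/calibrated attention module is not the one being converted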

print(model_int8(example_inputs).shape) # op - torch.Size([1, 20, 2048])

for param in model_int8.state_dict().keys():
    print(param, model_int8.state_dict()[param].dtype)

'''
Output:
m_atten.in_proj_weight torch.float32
m_atten.in_proj_bias torch.float32
m_atten.out_proj.weight torch.float32
m_atten.out_proj.bias torch.float32
'''

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize("temp.p")/1e6
    os.remove('temp.p')
    return size

print('model Size (MB):', print_size_of_model(model))
print('model_int8 Size (MB):', print_size_of_model(model_int8))

"""
Output:
model Size (MB): 67.143117
model_int8 Size (MB): 67.143117
"""

There’s a typo; it should be:

custom_module_config = {
    "float_to_observed_custom_module_class": {
        torch.nn.MultiheadAttention: torch.nn.quantizable.MultiheadAttention
    },
    "observed_to_quantized_custom_module_class": {
        torch.nn.quantizable.MultiheadAttention: torch.nn.quantized.MultiheadAttention
    },
}

import os 
import torch 
from torch import nn, Tensor
from torch.ao.quantization import QuantStub, DeQuantStub

class m_AttentionPool2d(nn.Module):
    def __init__(self):
        super(m_AttentionPool2d, self).__init__()
        self.m_atten = nn.MultiheadAttention(embed_dim=2048, num_heads=32, bias=True, add_bias_kv=False)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        (op , _) = self.m_atten(query=x[:1], key=x, value=x, need_weights=False, average_attn_weights=False, is_causal=False)
        op = self.dequant(op) 
        return op
    
model = m_AttentionPool2d()

example_inputs = torch.rand(50,20,2048)
example_outputs = model(example_inputs)
print(example_outputs.shape)   # op - torch.Size([1, 20, 2048])

model.eval()

custom_module_config = {
        "float_to_observed_custom_module_class": {torch.nn.MultiheadAttention: torch.nn.quantizable.MultiheadAttention},
        "observed_to_quantized_custom_module_class": {torch.nn.quantizable.MultiheadAttention: torch.nn.quantized.MultiheadAttention},
    }
print("before", model)
model.qconfig = torch.ao.quantization.get_default_qconfig()
model_prepared = torch.ao.quantization.prepare(model, prepare_custom_config_dict=custom_module_config)
print("prepared", model_prepared)
model_prepared(example_inputs)
model_int8 = torch.ao.quantization.convert(model_prepared, convert_custom_config_dict=custom_module_config)
print("converted", model_int8)

Result

torch.Size([1, 20, 2048])
before m_AttentionPool2d(
  (m_atten): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=2048, out_features=2048, bias=True)
  )
  (quant): QuantStub()
  (dequant): DeQuantStub()
)
prepared m_AttentionPool2d(
  (m_atten): QuantizableMultiheadAttention(
    (out_proj): Linear(
      in_features=2048, out_features=2048, bias=True
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (linear_Q): Linear(
      in_features=2048, out_features=2048, bias=True
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (linear_K): Linear(
      in_features=2048, out_features=2048, bias=True
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (linear_V): Linear(
      in_features=2048, out_features=2048, bias=True
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (q_scaling_product): FloatFunctional(
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (quant_attn_output): QuantStub(
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (quant_attn_output_weights): QuantStub(
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (dequant_q): DeQuantStub()
    (dequant_k): DeQuantStub()
    (dequant_v): DeQuantStub()
  )
  (quant): QuantStub(
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (dequant): DeQuantStub()
)
converted m_AttentionPool2d(
  (m_atten): QuantizedMultiheadAttention(
    (out_proj): QuantizedLinear(in_features=2048, out_features=2048, scale=0.011808564886450768, zero_point=59, qscheme=torch.per_channel_affine)
    (linear_Q): QuantizedLinear(in_features=2048, out_features=2048, scale=0.025253426283597946, zero_point=65, qscheme=torch.per_channel_affine)
    (linear_K): QuantizedLinear(in_features=2048, out_features=2048, scale=0.025217663496732712, zero_point=66, qscheme=torch.per_channel_affine)
    (linear_V): QuantizedLinear(in_features=2048, out_features=2048, scale=0.0245734341442585, zero_point=64, qscheme=torch.per_channel_affine)
    (q_scaling_product): QFunctional(
      scale=1.0, zero_point=0
      (activation_post_process): Identity()
    )
    (quant_attn_output): Quantize(scale=tensor([0.0197]), zero_point=tensor([65]), dtype=torch.quint8)
    (quant_attn_output_weights): Quantize(scale=tensor([0.0002]), zero_point=tensor([0]), dtype=torch.quint8)
    (dequant_q): DeQuantize()
    (dequant_k): DeQuantize()
    (dequant_v): DeQuantize()
  )
  (quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)
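
As a closing sanity check (a sketch reusing the print_size_of_model helper and the variable names from the snippets above), the size drop and the packed int8 weights can be confirmed on the converted model:

print('model Size (MB):', print_size_of_model(model))            # float baseline
print('model_int8 Size (MB):', print_size_of_model(model_int8))  # should now be roughly 4x smaller

# the projection weights are now packed qint8 tensors inside the quantized Linear modules
print(model_int8.m_atten.linear_Q.weight().dtype)    # expected: torch.qint8
print(model_int8.m_atten.out_proj.weight().dtype)    # expected: torch.qint8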