Quantization of multi_head_attention_forward

Hi Team,

Could someone help me with quantization of multi-head attention layers in PyTorch?

I am new to PyTorch and have been experimenting with quantization of OpenAI’s CLIP model. Specifically, I’m trying to quantize the (modified) ResNet encoder of CLIP, which has CNN blocks followed by a final F.multi_head_attention_forward layer.
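
For context, the encoder is obtained roughly as follows (a minimal sketch assuming OpenAI's `clip` package; `model.visual` is the modified ResNet image encoder whose attention-pooling layer calls F.multi_head_attention_forward):

# Rough sketch of how the encoder under discussion is obtained
# (assumes the `clip` package from https://github.com/openai/CLIP).
import torch
import clip

clip_model, preprocess = clip.load("RN50", device="cpu")
RN50 = clip_model.visual.float()  # modified ResNet image encoder with attnpool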

I’m using FX Graph Mode Quantization to quantize the RN50 model as follows:

import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

RN50.eval()
qconfig = get_default_qconfig("x86")
qconfig_mapping = QConfigMapping().set_global(qconfig)
example_inputs = torch.rand(1, 3, 224, 224)  # get an example input
prepared_RN50 = prepare_fx(RN50, qconfig_mapping, example_inputs)  # fuse modules and insert observers
calibrate(prepared_RN50)  # run calibration on sample data
quantized_RN50 = convert_fx(prepared_RN50)  # convert the calibrated model to a quantized model
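
Here, calibrate is just a small helper that runs a few representative batches through the prepared model so the observers can record activation ranges; a minimal sketch of such a helper (the random tensors below are only stand-ins for real calibration images):

# Minimal sketch of a calibration helper: run representative inputs through
# the prepared model so the inserted observers can record activation ranges.
def calibrate(prepared_model, num_batches=32):
    prepared_model.eval()
    with torch.no_grad():
        for _ in range(num_batches):
            prepared_model(torch.rand(1, 3, 224, 224))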

While the quantized model runs as expected, we don’t see the expected decrease in model size. Checking further, we see that the attention-layer parameters are still in float32:

for name, param in quantized_RN50.named_parameters():
    print(name, param.dtype)

attnpool.positional_embedding torch.float32
attnpool.q_proj.weight torch.float32
attnpool.q_proj.bias torch.float32
attnpool.k_proj.weight torch.float32
attnpool.k_proj.bias torch.float32
attnpool.v_proj.weight torch.float32
attnpool.v_proj.bias torch.float32
attnpool.c_proj.weight torch.float32
attnpool.c_proj.bias torch.float32

In the document below, nn.MultiheadAttention is mentioned both as supported through custom modules (under the section "Quantization API Summary") and as not supported for static quantization (under the section "Operator Support"). However, I didn’t come across any documentation or examples regarding quantization of the MultiheadAttention layer. It would be great if someone could point me to one.
https://pytorch.org/docs/stable/quantization.html

First, do you need dynamic quantization or static quantization?

For static quantization + FX graph mode quantization, you can probably do something similar to this LSTM example (replace LSTM with MultiheadAttention): https://github.com/pytorch/pytorch/blob/main/test/quantization/fx/test_quantize_fx.py#L4763-L4780

Hi Jerry,

Thanks for the response. I’m looking for static quantization. I tried the code below but got the error AttributeError: 'tuple' object has no attribute 'numel' while passing inputs to the prepared model. I suspect it is because MultiheadAttention returns a tuple of output and weights. Would you be able to share a working example or point out what I need to change in the code below?

import torch
from torch import nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig, ConvertCustomConfig

class m_AttentionPool2d(nn.Module):
    def __init__(self):
        super(m_AttentionPool2d, self).__init__()
        self.m_atten = nn.MultiheadAttention(embed_dim=2048, num_heads=32, bias=True, add_bias_kv=False)

    def forward(self, x):
        (op , _) = self.m_atten(query=x[:1], key=x, value=x, need_weights=False, average_attn_weights=False, is_causal=False)
        return op
    
model = m_AttentionPool2d()

# test model with random input
example_inputs = torch.rand(50,20,2048)
example_outputs = model(example_inputs)
print(example_outputs.shape)     # output - torch.Size([1, 20, 2048])

qconfig_mapping = get_default_qconfig_mapping()

prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(torch.nn.MultiheadAttention, torch.ao.nn.quantizable.MultiheadAttention)
convert_custom_config = ConvertCustomConfig().set_observed_to_quantized_mapping(torch.ao.nn.quantizable.MultiheadAttention, torch.ao.nn.quantized.MultiheadAttention)

example_inputs = torch.rand(4,20,2048)

# quantize model
model.eval()
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config)
prepared_model(example_inputs)   # attribute error seen for this line during execution

quantized_model = convert_fx(prepared_model, convert_custom_config=convert_custom_config)
quantized_model(example_inputs)

Actually, static quantization for MultiheadAttention is only available in eager mode quantization, I think, so you can do something like:

custom_module_config = {
    "float_to_observed_custom_module_class": {
        torch.nn.MultiheadAttention: torch.nn.quantizable.MultiheadAttention
    },
    "observed_to_quantized_custom_module_class": {
        torch.nn.quantizable.MultiheadAttention: torch.nn.quantizable.MultiheadAttention
    },
}
torch.ao.quantization.prepare(
    model, inplace=True, prepare_custom_config_dict=custom_module_config
)

torch.ao.quantization.convert(
    model,
    inplace=True,
    convert_custom_config_dict=custom_module_config,
)

Also, set the qconfig properly for the torch.nn.MultiheadAttention module. Here is the general static quantization tutorial for eager mode quantization: (beta) Static Quantization with Eager Mode in PyTorch (https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html).
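
For example, setting the qconfig before prepare could look roughly like this (just a sketch; "x86" is one backend choice, and model.m_atten refers to the attention submodule name from your snippet above):

# Sketch: in eager mode, prepare() reads the qconfig attribute, so attach one
# before calling prepare. Per-submodule qconfig assignments are also respected.
import torch

model.eval()
model.qconfig = torch.ao.quantization.get_default_qconfig("x86")
# alternatively, target only the attention submodule from the snippet above:
# model.m_atten.qconfig = torch.ao.quantization.get_default_qconfig("x86")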

Hi Jerry,

Thanks for getting back. I tried static eager-mode quantization of MultiheadAttention, but the weights are still in float32 after quantization. Moreover, the model size is the same before and after quantization. While there are no errors, the model doesn’t seem to be quantized. Sharing the code below:

import os 
import torch 
from torch import nn, Tensor
from torch.ao.quantization import QuantStub, DeQuantStub

class m_AttentionPool2d(nn.Module):
    def __init__(self):
        super(m_AttentionPool2d, self).__init__()
        self.m_atten = nn.MultiheadAttention(embed_dim=2048, num_heads=32, bias=True, add_bias_kv=False)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        (op , _) = self.m_atten(query=x[:1], key=x, value=x, need_weights=False, average_attn_weights=False, is_causal=False)
        op = self.dequant(op) 
        return op
    
model = m_AttentionPool2d()

example_inputs = torch.rand(50,20,2048)
example_outputs = model(example_inputs)
print(example_outputs.shape)   # op - torch.Size([1, 20, 2048])

model.eval()

custom_module_config = {
        "float_to_observed_custom_module_class": {torch.nn.MultiheadAttention: torch.nn.quantizable.MultiheadAttention},
        "observed_to_quantized_custom_module_class": {torch.nn.quantizable.MultiheadAttention: torch.nn.quantizable.MultiheadAttention},
    }

model.qconfig = torch.ao.quantization.get_default_qconfig('x86')
model_prepared = torch.ao.quantization.prepare(model, prepare_custom_config_dict=custom_module_config)
model_prepared(example_inputs)
model_int8 = torch.ao.quantization.convert(model,convert_custom_config_dict=custom_module_config,)

print(model_int8(example_inputs).shape) # op - torch.Size([1, 20, 2048])

for param in model_int8.state_dict().keys():
    print(param, model_int8.state_dict()[param].dtype)

'''
Output:
m_atten.in_proj_weight torch.float32
m_atten.in_proj_bias torch.float32
m_atten.out_proj.weight torch.float32
m_atten.out_proj.bias torch.float32
'''

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize("temp.p")/1e6
    os.remove('temp.p')
    return size

print('model Size (MB):', print_size_of_model(model))
print('model_int8 Size (MB):', print_size_of_model(model_int8))

"""
Output:
model Size (MB): 67.143117
model_int8 Size (MB): 67.143117
"""

There’s a typo; it should be:

custom_module_config = {
    "float_to_observed_custom_module_class": {
        torch.nn.MultiheadAttention: torch.nn.quantizable.MultiheadAttention
    },
    "observed_to_quantized_custom_module_class": {
        torch.nn.quantizable.MultiheadAttention: torch.nn.quantized.MultiheadAttention
    },
}

import os 
import torch 
from torch import nn, Tensor
from torch.ao.quantization import QuantStub, DeQuantStub

class m_AttentionPool2d(nn.Module):
    def __init__(self):
        super(m_AttentionPool2d, self).__init__()
        self.m_atten = nn.MultiheadAttention(embed_dim=2048, num_heads=32, bias=True, add_bias_kv=False)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        (op , _) = self.m_atten(query=x[:1], key=x, value=x, need_weights=False, average_attn_weights=False, is_causal=False)
        op = self.dequant(op) 
        return op
    
model = m_AttentionPool2d()

example_inputs = torch.rand(50,20,2048)
example_outputs = model(example_inputs)
print(example_outputs.shape)   # op - torch.Size([1, 20, 2048])

model.eval()

custom_module_config = {
        "float_to_observed_custom_module_class": {torch.nn.MultiheadAttention: torch.nn.quantizable.MultiheadAttention},
        "observed_to_quantized_custom_module_class": {torch.nn.quantizable.MultiheadAttention: torch.nn.quantized.MultiheadAttention},
    }
print("before", model)
model.qconfig = torch.ao.quantization.get_default_qconfig()
model_prepared = torch.ao.quantization.prepare(model, prepare_custom_config_dict=custom_module_config)
print("prepared", model_prepared)
model_prepared(example_inputs)
model_int8 = torch.ao.quantization.convert(model_prepared,convert_custom_config_dict=custom_module_config,)
print("converted", model_int8)

Result

torch.Size([1, 20, 2048])
before m_AttentionPool2d(
  (m_atten): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=2048, out_features=2048, bias=True)
  )
  (quant): QuantStub()
  (dequant): DeQuantStub()
)
prepared m_AttentionPool2d(
  (m_atten): QuantizableMultiheadAttention(
    (out_proj): Linear(
      in_features=2048, out_features=2048, bias=True
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (linear_Q): Linear(
      in_features=2048, out_features=2048, bias=True
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (linear_K): Linear(
      in_features=2048, out_features=2048, bias=True
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (linear_V): Linear(
      in_features=2048, out_features=2048, bias=True
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (q_scaling_product): FloatFunctional(
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (quant_attn_output): QuantStub(
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (quant_attn_output_weights): QuantStub(
      (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
    )
    (dequant_q): DeQuantStub()
    (dequant_k): DeQuantStub()
    (dequant_v): DeQuantStub()
  )
  (quant): QuantStub(
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (dequant): DeQuantStub()
)
converted m_AttentionPool2d(
  (m_atten): QuantizedMultiheadAttention(
    (out_proj): QuantizedLinear(in_features=2048, out_features=2048, scale=0.011808564886450768, zero_point=59, qscheme=torch.per_channel_affine)
    (linear_Q): QuantizedLinear(in_features=2048, out_features=2048, scale=0.025253426283597946, zero_point=65, qscheme=torch.per_channel_affine)
    (linear_K): QuantizedLinear(in_features=2048, out_features=2048, scale=0.025217663496732712, zero_point=66, qscheme=torch.per_channel_affine)
    (linear_V): QuantizedLinear(in_features=2048, out_features=2048, scale=0.0245734341442585, zero_point=64, qscheme=torch.per_channel_affine)
    (q_scaling_product): QFunctional(
      scale=1.0, zero_point=0
      (activation_post_process): Identity()
    )
    (quant_attn_output): Quantize(scale=tensor([0.0197]), zero_point=tensor([65]), dtype=torch.quint8)
    (quant_attn_output_weights): Quantize(scale=tensor([0.0002]), zero_point=tensor([0]), dtype=torch.quint8)
    (dequant_q): DeQuantize()
    (dequant_k): DeQuantize()
    (dequant_v): DeQuantize()
  )
  (quant): Quantize(scale=tensor([0.0079]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
)

Hi Charles / Jerry,

Thanks for the response. I’m able to get the quantized version of the multi-head attention module after fixing the above-mentioned typo and another error (I was converting the original model instead of the prepared model in my code).

Though the Q, K, V, and output projection weights/biases are in torch.qint8 after quantization, the quantized module is the same size as the original module. I believe it should be closer to 1/4 of the original model size.

Could you let me know whether it is expected for the quantized attention module to be the same size as the original attention module? If so, can you share more details on why this is the case? (A rough per-entry size breakdown is included at the end of this post.)

import os 
import torch 
from torch import nn, Tensor
from torch.ao.quantization import QuantStub, DeQuantStub

class m_AttentionPool2d(nn.Module):
    def __init__(self):
        super(m_AttentionPool2d, self).__init__()
        self.m_atten = nn.MultiheadAttention(embed_dim=2048, num_heads=32, bias=True, add_bias_kv=False)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        (op , _) = self.m_atten(query=x[:1], key=x, value=x, need_weights=False, average_attn_weights=False, is_causal=False)
        op = self.dequant(op)  # dequantize the output
        return op
    
model = m_AttentionPool2d()
model.eval()
example_inputs = torch.rand(50,20,2048)

custom_module_config = {
        "float_to_observed_custom_module_class": {torch.nn.MultiheadAttention: torch.nn.quantizable.MultiheadAttention},
        "observed_to_quantized_custom_module_class": {torch.nn.quantizable.MultiheadAttention: torch.nn.quantized.MultiheadAttention},
    }

model.qconfig = torch.ao.quantization.get_default_qconfig('x86')
model_prepared = torch.ao.quantization.prepare(model, prepare_custom_config_dict=custom_module_config)
model_prepared(example_inputs)
model_int8 = torch.ao.quantization.convert(model_prepared,convert_custom_config_dict=custom_module_config,)

print(model_int8(example_inputs).shape)  # op - torch.Size([1, 20, 2048])

def print_size_of_model(mdl, name):
    torch.save(mdl.state_dict(), name + '.p')
    size = os.path.getsize(name + '.p')/1e6
    #os.remove(name + '.p')
    return size

print('model Size (MB):', print_size_of_model(model,'model'))
print('model_int8 Size (MB):', print_size_of_model(model_int8, 'model_int8'))

####
# Output:
# model Size (MB): 67.143123
# model_int8 Size (MB): 67.308203
####

# Output of datatype of quantized module's parameters:

for param in model_int8.state_dict().keys():
    if '_packed_params._packed_params' in param:
        w,b = model_int8.state_dict()[param]
        print(param,end = ' | ')
        print('weight shape : ', w.shape, end = ' | ')
        print('bias shape : ', b.shape, end = ' | ')
        print('dtype : ', model_int8.state_dict()['.'.join(param.split('.')[:-1])+'.dtype'])

####
# m_atten.out_proj._packed_params._packed_params | weight shape :  torch.Size([2048, 2048]) | bias shape :  torch.Size([2048]) | dtype :  torch.qint8
# m_atten.linear_Q._packed_params._packed_params | weight shape :  torch.Size([2048, 2048]) | bias shape :  torch.Size([2048]) | dtype :  torch.qint8
# m_atten.linear_K._packed_params._packed_params | weight shape :  torch.Size([2048, 2048]) | bias shape :  torch.Size([2048]) | dtype :  torch.qint8
# m_atten.linear_V._packed_params._packed_params | weight shape :  torch.Size([2048, 2048]) | bias shape :  torch.Size([2048]) | dtype :  torch.qint8
####

# Output of datatype of original module's parameters:
for param in model.state_dict().keys():
    print(param)
    print(model.state_dict()[param].shape, ' | dtype : ', model.state_dict()[param].dtype)

####
# m_atten.in_proj_weight
# torch.Size([6144, 2048])  | dtype :  torch.float32
# m_atten.in_proj_bias
# torch.Size([6144])  | dtype :  torch.float32
# m_atten.out_proj.weight
# torch.Size([2048, 2048])  | dtype :  torch.float32
# m_atten.out_proj.bias
# torch.Size([2048])  | dtype :  torch.float32
####
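
For reference, the rough per-entry size breakdown mentioned above (a sketch; packed quantized params are stored as (weight, bias) tuples in the state_dict, so both plain tensors and tuples are handled):

# Rough per-entry size breakdown (a sketch): packed quantized params are
# stored as (weight, bias) tuples, so handle both tensors and tuples.
def state_dict_sizes_mb(module):
    sizes = {}
    for key, value in module.state_dict().items():
        tensors = value if isinstance(value, tuple) else (value,)
        nbytes = sum(t.numel() * t.element_size()
                     for t in tensors if torch.is_tensor(t))
        sizes[key] = nbytes / 1e6
    return sizes

for key, mb in state_dict_sizes_mb(model_int8).items():
    if mb > 1.0:  # show only the large entries
        print(key, f"{mb:.2f} MB")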

I think you’re modifying the model in place. Try printing the model size at the start rather than after you do the quantization, or use copy.deepcopy before applying prepare to preserve the state of the original model, as sketched below.
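
For example (a sketch, reusing model and custom_module_config from your snippet):

# Sketch: quantize a deep copy so the original fp32 module is definitely untouched
# and can be used for the size comparison afterwards.
import copy

model_to_quantize = copy.deepcopy(model)
model_to_quantize.qconfig = torch.ao.quantization.get_default_qconfig("x86")
model_prepared = torch.ao.quantization.prepare(
    model_to_quantize, prepare_custom_config_dict=custom_module_config
)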

You can also see an example of a tutorial where they check the model size:

https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html?highlight=quantization

Hi Charles,

I checked the size of the original model before quantization and it is still 67.143123 MB. The size of the quantized model is 67.308203 MB.

inplace is set to False by default for the prepare and convert operations, so the original model should not be modified by the above code. We don’t see quantized layers when printing the original model, and its parameters are in float32 (output attached in the code above).

https://pytorch.org/docs/stable/generated/torch.ao.quantization.convert.html
https://pytorch.org/docs/stable/generated/torch.ao.quantization.prepare.html
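
As an extra sanity check (a sketch), the original model still holds the float module after prepare/convert:

# Sanity check (sketch): the original model is untouched by prepare/convert
# with the default inplace=False.
print(type(model.m_atten))                 # <class 'torch.nn.modules.activation.MultiheadAttention'>
print(model.m_atten.in_proj_weight.dtype)  # torch.float32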

Also, I’m using the same code as in the shared tutorial for estimating model size, slightly modified to not delete the saved file, to save it under a unique name (passed as a parameter), and to return the model size.

Yeah, that’s a pretty big issue, thanks for calling it out: we’re not deleting the in_proj weight and bias after we quantize the module. Here is a workaround in the meantime:

import os 
import torch 
from torch import nn, Tensor
from torch.ao.quantization import QuantStub, DeQuantStub

class m_AttentionPool2d(nn.Module):
    def __init__(self):
        super(m_AttentionPool2d, self).__init__()
        self.m_atten = nn.MultiheadAttention(embed_dim=2048, num_heads=32, bias=True, add_bias_kv=False)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        (op , _) = self.m_atten(query=x[:1], key=x, value=x, need_weights=False, average_attn_weights=False, is_causal=False)
        op = self.dequant(op)  # dequantize the output
        return op
    
def print_size_of_model(mdl, name):
    torch.save(mdl.state_dict(), name + '.p')
    size = os.path.getsize(name + '.p')/1e6
    os.remove(name + '.p')
    return size

model = m_AttentionPool2d()

model.eval()
example_inputs = torch.rand(50,20,2048)

custom_module_config = {
        "float_to_observed_custom_module_class": {torch.nn.MultiheadAttention: torch.nn.quantizable.MultiheadAttention},
        "observed_to_quantized_custom_module_class": {torch.nn.quantizable.MultiheadAttention: torch.nn.quantized.MultiheadAttention},
}

model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
model_prepared = torch.ao.quantization.prepare(model, prepare_custom_config_dict=custom_module_config)

model_prepared(example_inputs)
model_int8 = torch.ao.quantization.convert(model_prepared,convert_custom_config_dict=custom_module_config,)

print('fp32 model Size (MB):', print_size_of_model(model,'model'))
print('model_int8 Size (MB):', print_size_of_model(model_int8, 'model_int8'))
del model_int8.m_atten.in_proj_weight
del model_int8.m_atten.in_proj_bias
print('model_int8 Size (MB):', print_size_of_model(model_int8, 'model_int8'))

model_int8(example_inputs)

print("int8 params")

for key, p in model_int8.named_parameters():
    try:
        print(key, p.shape, p.dtype)
    except:
        print(key, p)

print("int8 buffers")

for name, p in model_int8.named_buffers():
    try:
        print(name, p.shape, p.dtype)
    except:
        print(name, p)

Here is a fix, landing now:

Thanks for the workaround and fix, HDCharles and Jerry.