I am using using a transformer model for my project where I am using MultiheadAttention. I want to quantize it. In the pytorch/torch/nn/quantizable/modules/activation.py at main · pytorch/pytorch · GitHub it has been mentioned to use the following
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention
This is my code which gives an error:
import torch
import torch.nn as nn
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention
class DummyModel(nn.Module):
def __init__(self, embed_dim, num_heads):
super(DummyModel, self).__init__()
self.embedding = nn.Embedding(10, embed_dim)
self.attention = MultiheadAttention(embed_dim, num_heads)
self.linear = nn.Linear(embed_dim, embed_dim)
self.output = nn.Linear(embed_dim, 1)
def forward(self, x):
x = self.embedding(x)
x, _ = self.attention(x, x, x)
x = self.linear(x)
x = torch.relu(x)
x = self.output(x)
return x
# Instantiate the model
embed_dim = 16
num_heads = 2
model = DummyModel(embed_dim, num_heads)
from torch.ao.quantization import quantize_dynamic
# Switch the model to evaluation mode
model.eval()
# Apply dynamic quantization
quantized_model = quantize_dynamic(
model,
{nn.Linear, MultiheadAttention},
dtype=torch.qint8
)