Hi, I was trying to quantize the attention module of a Vision Transformer from timm as follows:
```python
import torch
import torch.nn as nn
from torch.ao.quantization import QuantStub, DeQuantStub


class QuantAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # quantization stubs
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        B, N, C = x.shape
        x = self.quant(x)  # float -> quantized
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dequant(x)  # quantized -> float
        x = self.proj_drop(x)
        return x
```
However, some operations are not supported by the QuantizedCPU backend, in particular:
- `q = q * self.scale` (multiplication by a scalar)
- `attn = q @ k.transpose(-2, -1)` (batched matmul)

The error message is:
```
NotImplementedError: Could not run 'aten::bmm.out' with arguments from the 'QuantizedCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build).
```
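For background on why these two operations need dedicated quantized kernels, here is a pure-Python sketch of per-tensor affine (uint8) quantization. It is a simplified illustration, not PyTorch's kernel code, and the scales and zero points are arbitrary values picked for the example. Scaling by a positive constant can be absorbed into the scale, while a matmul of two quantized operands accumulates integer products at scale `sa * sb` and then has to be requantized to an output scale and zero point, which static quantization obtains from calibration.

```python
def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Affine uint8 quantization of one float: round(x / scale) + zero_point, clamped."""
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    """Map an integer code back to (approximately) the original float."""
    return (q - zero_point) * scale


# 1) Multiplying by a positive scalar c: keep the integer code, rescale.
s, z = 0.1, 128
q = quantize(1.5, s, z)
c = 0.125
scaled = dequantize(q, s * c, z)  # same code, new scale s * c
assert abs(scaled - c * dequantize(q, s, z)) < 1e-12

# 2) Dot product (one entry of a matmul) of two quantized vectors.
a, b = [0.5, -1.0, 2.0], [1.0, 0.25, -0.5]
sa, za, sb, zb = 0.02, 128, 0.01, 128
qa = [quantize(v, sa, za) for v in a]
qb = [quantize(v, sb, zb) for v in b]
# Integer accumulation (int32 in real kernels); its scale is sa * sb.
acc = sum((x - za) * (y - zb) for x, y in zip(qa, qb))
approx = acc * sa * sb
exact = sum(x * y for x, y in zip(a, b))

# Storing the result as uint8 again needs an output scale/zero point,
# which a static-quantized kernel would receive from calibration.
s_out, z_out = 0.01, 128
q_out = quantize(approx, s_out, z_out)
print(acc, approx, exact, q_out)
```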
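The quantization pipeline is as follows:
```python
device = "cpu"  # the error mentions QuantizedCPU, so the tensors live on the CPU
quant_attn = QuantAttention(dim=128)  # construction assumed; dim matches the 128-dim inputs

quant_attn.qconfig = torch.ao.quantization.get_default_qconfig('x86')
quant_attn = torch.ao.quantization.prepare(quant_attn)
# calibration: a few forward passes so the observers record activation ranges
for _ in range(10):
    inputs = torch.randn(1, 16, 128, device=device)
    quant_attn(inputs)
quant_attn = torch.ao.quantization.convert(quant_attn)
```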
The error is raised on this call:
```python
inputs = torch.randn(1, 16, 128, device=device)
outputs = quant_attn(inputs)
```
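A commonly used eager-mode workaround, sketched here under assumptions rather than as an official fix: keep only the two `nn.Linear` layers quantized and run the unsupported operations (scalar multiply, batched matmuls, softmax) in float by placing a `DeQuantStub`/`QuantStub` pair around them. Each quantize boundary gets its own stub so it gets its own observer. The class name `PartiallyQuantAttention` is hypothetical, and `qk_norm` and dropout are omitted for brevity.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import QuantStub, DeQuantStub


class PartiallyQuantAttention(nn.Module):
    """Hypothetical variant: only qkv and proj run quantized; attention core stays float."""

    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        # One stub per quantize/dequantize boundary.
        self.quant_in = QuantStub()
        self.dequant_qkv = DeQuantStub()
        self.quant_proj = QuantStub()
        self.dequant_out = DeQuantStub()

    def forward(self, x):
        B, N, C = x.shape
        x = self.quant_in(x)                   # float -> quint8
        qkv = self.dequant_qkv(self.qkv(x))    # quantized Linear, then back to float
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q * self.scale) @ k.transpose(-2, -1)  # float ops only
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.quant_proj(x)                 # float -> quint8
        x = self.proj(x)                       # quantized Linear
        return self.dequant_out(x)             # quint8 -> float


torch.manual_seed(0)
model = PartiallyQuantAttention(dim=128).eval()
model.qconfig = torch.ao.quantization.get_default_qconfig("x86")
model = torch.ao.quantization.prepare(model)
with torch.no_grad():
    for _ in range(10):  # calibration on random data, as in the pipeline above
        model(torch.randn(1, 16, 128))
model = torch.ao.quantization.convert(model)

with torch.no_grad():
    out = model(torch.randn(1, 16, 128))
print(out.shape)
```

This gives up the speedup on the attention core itself, but every operation it runs has a regular float kernel, and the two projections, which hold most of the module's weights, are still quantized.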
Where can one find a list of supported operations to better understand the limitations of the current Quantization API?
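Short of an exhaustive list, one empirical way to check a specific operator is to call it on a quantized tensor and catch the error. A minimal sketch follows; the helper name `runs_on_quantized` is hypothetical, and it only reports whether the call raised, not why (a shape error would also count as "not supported").

```python
import torch


def runs_on_quantized(fn, *args):
    """Return True if fn(*args) runs without raising on the given inputs."""
    try:
        fn(*args)
        return True
    except RuntimeError:  # NotImplementedError is a subclass of RuntimeError
        return False


x = torch.randn(2, 4, 4)
qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=64, dtype=torch.quint8)

candidates = {
    "torch.bmm": (torch.bmm, qx, qx),
    "torch.matmul (@)": (torch.matmul, qx, qx),
    "multiply by float": (torch.mul, qx, 0.5),
    "transpose": (torch.transpose, qx, -2, -1),
    "dequantize": (torch.dequantize, qx),
}

for name, (fn, *args) in candidates.items():
    status = "ok" if runs_on_quantized(fn, *args) else "not supported"
    print(f"{name:20s} {status}")
```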
Environment:
- torch 2.1.0+cu118
- torchvision 0.16.0+cu118