Hello all,
I'm trying to quantize nn.TransformerEncoder, but I get errors during inference.
The problem is with nn.MultiheadAttention, which is essentially a set of nn.Linear operations and should therefore work fine after quantization.
Minimal example:
import torch

# 8 attention heads over 512-dimensional embeddings
mlth = torch.nn.MultiheadAttention(512, 8)
possible_input = torch.rand((10, 10, 512))

# by default, dynamic quantization replaces the internal nn.Linear layers
quantized = torch.quantization.quantize_dynamic(mlth)
quantized(possible_input, possible_input, possible_input)
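For comparison, the same call on a plain nn.Linear runs fine after quantization. A quick sanity check (the layer is wrapped in nn.Sequential so it sits inside a parent module, mirroring out_proj inside MultiheadAttention):

lin = torch.nn.Sequential(torch.nn.Linear(512, 512))
qlin = torch.quantization.quantize_dynamic(lin)
qlin(torch.rand((10, 512)))  # works, returns a float tensor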
The MultiheadAttention version, however, fails with:
/opt/miniconda/lib/python3.7/site-packages/torch/nn/functional.py in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
3946 assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
3947 attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
-> 3948 attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
3949
3950 if need_weights:
/opt/miniconda/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
1610 ret = torch.addmm(bias, input, weight.t())
1611 else:
-> 1612 output = input.matmul(weight.t())
1613 if bias is not None:
1614 output += bias
AttributeError: 'function' object has no attribute 't'
That's because on the components of the quantized module `.weight` is no longer a parameter but a method.
You can check it like this:
mlth.out_proj.weight
Parameter containing:
tensor([[-0.0280, 0.0016, 0.0163, ..., 0.0375, 0.0153, -0.0435],
[-0.0168, 0.0310, -0.0211, ..., -0.0258, 0.0043, -0.0094],
[ 0.0412, -0.0078, 0.0262, ..., 0.0328, 0.0439, 0.0066],
...,
[-0.0278, 0.0337, 0.0189, ..., -0.0402, 0.0193, -0.0163],
[ 0.0034, -0.0364, -0.0418, ..., -0.0248, -0.0375, -0.0236],
[-0.0312, 0.0236, 0.0404, ..., 0.0266, 0.0255, 0.0265]],
requires_grad=True)
while
quantized.out_proj.weight
<bound method Linear.weight of DynamicQuantizedLinear(in_features=512, out_features=512, qscheme=torch.per_tensor_affine)>
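As far as I can tell, the weight tensor is still there, just behind a method call:

quantized.out_proj.weight()  # note the parentheses: returns the weight tensor

but MultiheadAttention passes out_proj.weight as an attribute into F.multi_head_attention_forward, so the function receives the bound method and weight.t() blows up, as in the traceback above.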
Can you please guide me on this? Is this expected behavior? Should I report it to the PyTorch GitHub issues?
It looks like quantization breaks any module that accesses .weight internally.
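In the meantime, a workaround seems to be to quantize only safe submodules by name and leave the attention in float. A sketch (the model and the "fc" name are made up for illustration; the docs say qconfig_spec also accepts a set of submodule names):

import torch
import torch.nn as nn

class TinyModel(nn.Module):  # hypothetical model, just for illustration
    def __init__(self):
        super().__init__()
        self.attn = nn.MultiheadAttention(512, 8)
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        return self.fc(attn_out)

model = TinyModel()

# quantize only the submodule named "fc"; the attention block stays float
quantized_model = torch.quantization.quantize_dynamic(
    model, qconfig_spec={"fc"}, dtype=torch.qint8
)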
Thanks in advance