I’m currently trying to statically quantize an AttentionCell module that contains an LSTMCell, on an x86 system:
class AttentionCell(nn.Module):
    """Attention cell made of three linear projections and an LSTMCell."""

    def __init__(self, input_size, hidden_size, num_embeddings):
        """Create the projection layers and the recurrent cell.

        Args:
            input_size: Feature size fed to ``i2h``; also part of the
                LSTMCell input size.
            hidden_size: Output size of ``i2h``/``h2h`` and hidden size of
                the LSTMCell.
            num_embeddings: Added to ``input_size`` to form the LSTMCell
                input size.
        """
        super().__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        # Either i2h or h2h should have a bias; here it is h2h.
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
        self.hidden_size = hidden_size
This is the custom config I’m using at the moment:
# Maps each float recurrent module class to its quantizable counterpart.
custom_module_config = {
    'float_to_observed_custom_module_class': {
        torch.nn.LSTMCell: torch.nn.quantizable.LSTMCell,
        torch.nn.LSTM: torch.nn.quantizable.LSTM,
    },
}
However, after preparing the model for quantization, I get the following error while forwarding some samples to fine-tune the weights:
File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 149, in forward
return self.module(*inputs, **kwargs)
File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/cgarriga/projects/datascience-common/ocr_poc/mlpococr/poc_ocr/recognition/model.py", line 121, in forward
batch_max_length=self.opt.batch_max_length
File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/cgarriga/projects/datascience-common/ocr_poc/mlpococr/poc_ocr/recognition/modules/prediction.py", line 55, in forward
hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots)
File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/cgarriga/projects/datascience-common/ocr_poc/mlpococr/poc_ocr/recognition/modules/prediction.py", line 91, in forward
cur_hidden = self.rnn(concat_context, prev_hidden)
File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 893, in _call_impl
hook_result = hook(self, input, result)
File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/quantization/quantize.py", line 83, in _observer_forward_hook
return self.activation_post_process(output)
File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/quantization/observer.py", line 900, in forward
if x_orig.numel() == 0:
AttributeError: 'tuple' object has no attribute 'numel'
The forward method for the AttentionCell is the following:
def forward(self, prev_hidden, batch_H, char_onehots):
    """Run one attention step followed by one recurrent-cell step.

    Args:
        prev_hidden: Previous recurrent state; element 0 is projected by
            ``h2h`` and the whole object is passed on to ``self.rnn``.
        batch_H: Encoder features,
            [batch_size x num_encoder_step x num_channel].
        char_onehots: Per-sample vectors concatenated to the context,
            [batch_size x num_embedding].

    Returns:
        A pair ``(cur_hidden, alpha)``: whatever ``self.rnn`` returns, and
        the softmax attention weights over the encoder steps.
    """
    # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
    batch_H_proj = self.i2h(batch_H)
    # unsqueeze(1) lets the projected hidden state broadcast over the encoder steps
    prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
    e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))  # batch_size x num_encoder_step * 1

    alpha = F.softmax(e, dim=1)
    context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel
    concat_context = torch.cat([context, char_onehots], 1)  # batch_size x (num_channel + num_embedding)
    # NOTE: with an LSTMCell as self.rnn this is a tuple (h, c), not a single
    # tensor; a forward hook on self.rnn that expects a tensor will fail on it.
    cur_hidden = self.rnn(concat_context, prev_hidden)
    return cur_hidden, alpha
How could I solve this issue? Apparently the problem comes from the LSTMCell inside the AttentionCell.
Thank you in advance!