Cannot quantize LSTMCell

I’m currently trying to statically quantize an AttentionCell that contains an LSTMCell, targeting an x86 system:

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionCell(nn.Module):

    def __init__(self, input_size, hidden_size, num_embeddings):
        super(AttentionCell, self).__init__()
        self.i2h = nn.Linear(input_size, hidden_size, bias=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)  # either i2h or h2h should have bias
        self.score = nn.Linear(hidden_size, 1, bias=False)
        self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
        self.hidden_size = hidden_size

This is the custom config I’m using at the moment:

custom_module_config = {
    'float_to_observed_custom_module_class': {
        torch.nn.LSTMCell: torch.nn.quantizable.LSTMCell,
        torch.nn.LSTM: torch.nn.quantizable.LSTM
    }
}
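
For context, I pass this config to eager-mode prepare roughly as follows (a simplified sketch; `model` and the qconfig choice stand in for my actual setup):

model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # x86 backend
prepared = torch.quantization.prepare(
    model,
    prepare_custom_config_dict=custom_module_config
)
# calibration samples are then forwarded through `prepared`, and
# torch.quantization.convert(prepared) would produce the quantized model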

However, after preparing the model for quantization, I get the following error while forwarding some samples to fine-tune the weights:

  File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 149, in forward
    return self.module(*inputs, **kwargs)
  File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/cgarriga/projects/datascience-common/ocr_poc/mlpococr/poc_ocr/recognition/model.py", line 121, in forward
    batch_max_length=self.opt.batch_max_length
  File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/cgarriga/projects/datascience-common/ocr_poc/mlpococr/poc_ocr/recognition/modules/prediction.py", line 55, in forward
    hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots)
  File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/cgarriga/projects/datascience-common/ocr_poc/mlpococr/poc_ocr/recognition/modules/prediction.py", line 91, in forward
    cur_hidden = self.rnn(concat_context, prev_hidden)
  File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 893, in _call_impl
    hook_result = hook(self, input, result)
  File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/quantization/quantize.py", line 83, in _observer_forward_hook
    return self.activation_post_process(output)
  File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/cgarriga/projects/datascience-common/ocr_poc/venv/lib/python3.6/site-packages/torch/quantization/observer.py", line 900, in forward
    if x_orig.numel() == 0:
AttributeError: 'tuple' object has no attribute 'numel'
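
If I read the traceback correctly, the observer hook that prepare() attaches to self.rnn receives the LSTMCell output, which is a tuple (h_1, c_1) rather than a Tensor, so the observer's .numel() call fails. A minimal illustration of the output type (standalone, not my actual model):

import torch

cell = torch.nn.LSTMCell(4, 8)
out = cell(torch.randn(2, 4))  # LSTMCell returns a tuple (h_1, c_1), not a Tensor
print(type(out))  # <class 'tuple'>, so the observer's x_orig.numel() call fails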

The forward method for the AttentionCell is the following:

    def forward(self, prev_hidden, batch_H, char_onehots):
        # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
        batch_H_proj = self.i2h(batch_H)
        prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)
        e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))  # batch_size x num_encoder_step x 1

        alpha = F.softmax(e, dim=1)
        context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel
        concat_context = torch.cat([context, char_onehots], 1)  # batch_size x (num_channel + num_embedding)
        cur_hidden = self.rnn(concat_context, prev_hidden)
        return cur_hidden, alpha

How could I solve this issue? Apparently the problem comes from the LSTMCell inside the AttentionCell.

Thank you in advance!

Thanks for pointing this issue out. There is an error in rnn.py in the pytorch/pytorch repo. Filing an issue to track this: "Quantizable LSTMCell does not work correctly" (pytorch/pytorch#55945).

On it – will send a fix asap
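
In the meantime, if you need to unblock yourself, one possible workaround (just a sketch, not tested on your model) is to keep the LSTMCell in fp32 by setting its qconfig to None before prepare, so no observer hook gets attached to it:

# Sketch: leave the LSTMCell unquantized; everything else is observed as usual
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model.attention_cell.rnn.qconfig = None  # hypothetical attribute path to the cell
prepared = torch.quantization.prepare(model, prepare_custom_config_dict=custom_module_config)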