Error Static Quantizing CRNN Model

Hello! I am a beginner at quantizing PyTorch models, so please forgive the noob question. I am trying to apply this static quantization example. I was able to run the example's code successfully, but I am running into problems applying the same approach to my own model.

This is the model (CNN + LSTM interleaved) that I have declared:

import numpy as np
import torch
from torch import nn


class CRNN_Net(nn.Module):
    def __init__(self, input_dim, conv_out, kernel, stride, out_dim, hidden_dim, num_hidden):
        super(CRNN_Net, self).__init__()
        self.conv1d_1 = nn.Conv1d(input_dim, conv_out, kernel, stride, padding=int(np.floor(kernel / 2)))
        self.bn1d_1 = nn.BatchNorm1d(conv_out)
        self.relu_1 = nn.ReLU()

        self.conv1d_2 = nn.Conv1d(hidden_dim, conv_out, kernel, stride, padding=int(np.floor(kernel / 2)))
        self.bn1d_2 = nn.BatchNorm1d(conv_out)
        self.relu_2 = nn.ReLU()

        self.lstm_1 = nn.LSTM(conv_out, hidden_dim, num_hidden, batch_first=True)
        self.lstm_2 = nn.LSTM(conv_out, hidden_dim, num_hidden, batch_first=True)
        self.linear = nn.Linear(hidden_dim, out_dim)

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)

        # (batch, time, features) -> (batch, features, time) for Conv1d
        x = x.permute(0, 2, 1)
        first = self.conv1d_1(x)
        first = self.bn1d_1(first)
        first = self.relu_1(first)

        # Back to (batch, time, features) for the batch_first LSTM
        first_out, _ = self.lstm_1(first.permute(0, 2, 1))

        second = self.conv1d_2(first_out.permute(0, 2, 1))
        second = self.bn1d_2(second)
        second = self.relu_2(second)

        BN_feat, _ = self.lstm_2(second.permute(0, 2, 1))
        output = self.linear(BN_feat)

        output = self.dequant(output)
        return output

And it runs fine when I do a regular forward pass:

model_fp32 = CRNN_Net(input_dim=80, conv_out=512, kernel=5, stride=1, out_dim=5768, hidden_dim=256, num_hidden=1)
model_fp32.eval()
model_fp32(torch.rand([1, 300, 80]))

Now I am trying to quantize this model. Here is my code attempting to quantize it:

model_fp32 = CRNN_Net(input_dim=80, conv_out=512, kernel=5, stride=1, out_dim=5768, hidden_dim=256, num_hidden=1)
model_fp32.eval()

model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# I'm unsure if this layer fusion is what I'm supposed to do.
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv1d_1', 'bn1d_1', 'relu_1'], 
                                                                ['conv1d_2', 'bn1d_2', 'relu_2']])  
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)

input_fp32 = torch.rand([1, 300, 80])
model_fp32_prepared(input_fp32)  # The error occurs here.

model_int8 = torch.quantization.convert(model_fp32_prepared)
res = model_int8(input_fp32)

The error I get is AttributeError: 'tuple' object has no attribute 'numel', which is raised inside the forward function of the prepared model (i.e. model_fp32_prepared). It seems to happen when feeding into the LSTM layer. Why is this occurring?

Additional questions: Does my current setup for static quantization in my model's forward function look correct? I'm unsure whether I'm supposed to use a single overall QuantStub()/DeQuantStub() pair or one per layer. Is my layer fusion done correctly?

Any insights would be a great help to me. Thank you for your time!

LSTM has multiple outputs, which prevents the observers from attaching correctly. You might want to try the "Quantizable LSTM". Take a look at this test for an example: pytorch/test_quantized_op.py at bb21aea37add0400eaa4ea8317656b7469b38a94 · pytorch/pytorch · GitHub. Let me know if it's confusing – I'll make a simple writeup.
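
Roughly, the custom-module route looks like this – a sketch only, assuming your PyTorch build ships torch.nn.quantizable.LSTM and torch.nn.quantized.LSTM (newer releases register this mapping by default):

import torch
from torch import nn

# Swap nn.LSTM for its quantizable/quantized counterparts during
# prepare()/convert(), so observers are placed inside the LSTM instead
# of on its (output, (h_n, c_n)) output tuple.
prepare_custom_config = {
    'float_to_observed_custom_module_class': {
        nn.LSTM: torch.nn.quantizable.LSTM,
    }
}
convert_custom_config = {
    'observed_to_quantized_custom_module_class': {
        torch.nn.quantizable.LSTM: torch.nn.quantized.LSTM,
    }
}

model_fp32_prepared = torch.quantization.prepare(
    model_fp32_fused, prepare_custom_config_dict=prepare_custom_config)
model_fp32_prepared(input_fp32)  # calibration pass
model_int8 = torch.quantization.convert(
    model_fp32_prepared, convert_custom_config_dict=convert_custom_config)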


Hi, I wonder if you have found a solution to this. I am having the same issue. Thanks!

I had the same problem with GRU static quantization:

import torch
from torch import nn


class GRUModel(nn.Module):
    def __init__(self, inputs_size, hidden_size, num_layer):
        super(GRUModel, self).__init__()
        self.lstm = nn.GRU(inputs_size, hidden_size, num_layer, batch_first=True)
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.hidden_size = hidden_size
        self.num_layer = num_layer

    def forward(self, inputs, hidden):
        output, hidden = self.lstm(inputs, hidden)
        output = self.linear(output)
        return output, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return weight.new_zeros(self.num_layer, bsz, self.hidden_size)


model = GRUModel(inputs_size=22, hidden_size=22, num_layer=1)

batch_size = 1
input = torch.randn(batch_size, 1, 22)
hidden = model.init_hidden(batch_size)
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
model_fp32_prepared = torch.quantization.prepare(model)

input_fp32 = torch.randn(1, 1, 22)  # (batch_size, time_step, input_size)
output, hidden = model_fp32_prepared(input_fp32, hidden)

The last line raises the same error: AttributeError: 'tuple' object has no attribute 'numel'

I found that PyTorch does not currently support static quantization of GRUs.
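
If dynamic quantization works for your use case, it does support nn.GRU (and nn.Linear), so you can at least quantize the weights that way – a sketch, reusing the model above:

import torch
from torch import nn

# Dynamic quantization: weights are converted to int8 ahead of time and
# activations are quantized on the fly, so no QuantStub/observers are needed.
model_int8 = torch.quantization.quantize_dynamic(
    model, {nn.GRU, nn.Linear}, dtype=torch.qint8)

output, hidden = model_int8(input_fp32, hidden)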