Quantisation aware training LSTM with pack_padded_sequences?

Using pytorch 1.6.
I’m trying to implement qat on a lstm based model I have.

class Net(torch.nn.Module):
    def __init__(self, seq_length):
        super(Net, self).__init__()

        self.hidden_size = 16
        self.input_size = 18

        self.seq_length = seq_length

        self.relu1 = torch.nn.ReLU()
        # Need to specify input sizes up front

        # batch_first specifies an input shape of (nBatches, nSeq, nFeatures),
        # otherwise this is (nSeq, nBatch, nFeatures)
        self.lstm = torch.nn.LSTM(input_size = self.input_size, hidden_size = self.hidden_size, batch_first = True)
        self.linear1 = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = torch.nn.Dropout(0.5)        #self.squeeze = torch.squeeze
        self.linearOut = torch.nn.Linear(self.hidden_size, 1)
        self.sigmoidOut = torch.nn.Sigmoid()
        self.sqeeze1 = torch.Tensor.squeeze
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
    def forward(self, x):

        # Can pass the initial hidden state, but not necessary here
        #x, h = self.gru1(x)#, self.h0)
        #x, h  = self.gru(x)#, self.h0)
        x, (h,c) = self.lstm(x)#, self.h0)

        # Get last output, x[:,l - 1,:], equivalent to (last) hidden state
        # Squeeze to remove length 1 dim
        x = self.sqeeze1(h)

        x = self.dropout(x)

        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linearOut(x)

        # Apply sigmoid either in the loss function, or in eval(...)
        return x
    def evaluate(self,x):
        return self.sigmoidOut(self.forward(x))

Standard training works fine, and after preparing the model for qat with,

qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
qat_model = torch.quantization.prepare_qat(qat_model)

I’m running the same training loop, just with different learning rates.

for epoch in tqdm(range(qat_epochs)):
    for batch in range(qat_nBatches):
        start_time = time.time()
        batch_data = data[batch * batch_size : (batch + 1) * batch_size]
        batch_seq_lens = seq_lens[batch * batch_size : (batch + 1) * batch_size]
        batch_labels = labels[batch * batch_size : (batch + 1) * batch_size]
        packedData = pack_padded_sequence(batch_data,
                                          batch_first = True,
                                          enforce_sorted = False)
        output = qat_model(packedData)
        loss = lossF(output, batch_labels)
        pred = qat_model.evaluate(packedData).detach().cpu().numpy().flatten()
    predClasses = np.zeros(pred.shape)
    predClasses[pred > 0.5] = 1


    accuracy.append(accuracy_score(batch_labels.detach().cpu().numpy().flatten(), predClasses))

    packedDataTest = pack_padded_sequence(data[data.shape[0] // 2:],
                                          seq_lens[data.shape[0] // 2:],
                                          batch_first = True,
                                          enforce_sorted = False)

    labelsTest = labels[data.shape[0] // 2:]

    quantised_model = torch.quantization.convert(qat_model.eval(), inplace = False)
    predTestT = qat_model.evaluate(packedDataTest)
    predTest = predTestT.detach().cpu().numpy().flatten()
    predClassesTest = np.zeros(predTest.shape)
    predClassesTest[predTest > 0.5] = 1

    lossesTestQAT.append(lossF(predTestT, labelsTest).detach().cpu().numpy().flatten())
    accuracyTestQAT.append(accuracy_score(labelsTest.detach().cpu().numpy().flatten(), predClassesTest))

However, I get this error

Traceback (most recent call last):
  File "rnn_qat.py", line 307, in <module>
    output = qat_model(packedData)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "rnn_qat.py", line 59, in forward
    x, (h,c) = self.lstm(x)#, self.h0)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 726, in _call_impl
    hook_result = hook(self, input, result)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/quantization/quantize.py", line 74, in _observer_forward_hook
    return self.activation_post_process(output)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 91, in forward
AttributeError: 'tuple' object has no attribute 'detach'

At first I thought this was to do with the LSTM outputting tuples, as I had to change from GRU to LSTM for quantisation. But, if that was the problem, surely the normal training loop would fail in the same way. Any help is appreciated

I don’t think qat is supported for LSTM. cc @raghuramank100 @supriyar to confirm

That’s right, we currently do not support QAT for nn.LSTM. We support dynamic quantization of LSTM modules currently

Do you plan on supporting QAT for nn.LSTM?

@asorin, not at the moment. Please submit a feature request if you have a strong use-case for it and we can take a look.