Using PyTorch 1.6.
I'm trying to implement QAT on an LSTM-based model I have.
class Net(torch.nn.Module):
    def __init__(self, seq_length):
        super(Net, self).__init__()
        self.hidden_size = 16
        self.input_size = 18
        self.seq_length = seq_length
        self.relu1 = torch.nn.ReLU()
        # Need to specify input sizes up front
        # batch_first specifies an input shape of (nBatches, nSeq, nFeatures),
        # otherwise this is (nSeq, nBatch, nFeatures)
        self.lstm = torch.nn.LSTM(input_size = self.input_size, hidden_size = self.hidden_size, batch_first = True)
        self.linear1 = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = torch.nn.Dropout(0.5) #self.squeeze = torch.squeeze
        self.linearOut = torch.nn.Linear(self.hidden_size, 1)
        self.sigmoidOut = torch.nn.Sigmoid()
        self.sqeeze1 = torch.Tensor.squeeze
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # Can pass the initial hidden state, but not necessary here
        #x, h = self.gru1(x)#, self.h0)
        #x, h = self.gru(x)#, self.h0)
        x, (h, c) = self.lstm(x)#, self.h0)
        # Get last output, x[:, l - 1, :], equivalent to (last) hidden state
        # Squeeze to remove length 1 dim
        x = self.sqeeze1(h)
        x = self.dropout(x)
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linearOut(x)
        # Apply sigmoid either in the loss function, or in eval(...)
        return x

    def evaluate(self, x):
        return self.sigmoidOut(self.forward(x))
Standard training works fine, and after preparing the model for QAT with
qat_model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
print(qat_model.qconfig)
qat_model = torch.quantization.prepare_qat(qat_model)
print(qat_model)
I run the same training loop, just with a different learning rate.
for epoch in tqdm(range(qat_epochs)):
    #model.train()
    for batch in range(qat_nBatches):
        start_time = time.time()
        batch_data = data[batch * batch_size : (batch + 1) * batch_size]
        batch_seq_lens = seq_lens[batch * batch_size : (batch + 1) * batch_size]
        batch_labels = labels[batch * batch_size : (batch + 1) * batch_size]
        packedData = pack_padded_sequence(batch_data,
                                          batch_seq_lens,
                                          batch_first = True,
                                          enforce_sorted = False)
        output = qat_model(packedData)
        loss = lossF(output, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = qat_model.evaluate(packedData).detach().cpu().numpy().flatten()
        predClasses = np.zeros(pred.shape)
        predClasses[pred > 0.5] = 1
        losses.append(loss.detach().cpu().numpy())
        accuracy.append(accuracy_score(batch_labels.detach().cpu().numpy().flatten(), predClasses))

    packedDataTest = pack_padded_sequence(data[data.shape[0] // 2:],
                                          seq_lens[data.shape[0] // 2:],
                                          batch_first = True,
                                          enforce_sorted = False)
    labelsTest = labels[data.shape[0] // 2:]
    quantised_model = torch.quantization.convert(qat_model.eval(), inplace = False)
    predTestT = qat_model.evaluate(packedDataTest)
    predTest = predTestT.detach().cpu().numpy().flatten()
    predClassesTest = np.zeros(predTest.shape)
    predClassesTest[predTest > 0.5] = 1
    lossesTestQAT.append(lossF(predTestT, labelsTest).detach().cpu().numpy().flatten())
    accuracyTestQAT.append(accuracy_score(labelsTest.detach().cpu().numpy().flatten(), predClassesTest))
However, I get this error:
Traceback (most recent call last):
  File "rnn_qat.py", line 307, in <module>
    output = qat_model(packedData)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "rnn_qat.py", line 59, in forward
    x, (h,c) = self.lstm(x)#, self.h0)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 726, in _call_impl
    hook_result = hook(self, input, result)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/quantization/quantize.py", line 74, in _observer_forward_hook
    return self.activation_post_process(output)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 91, in forward
    self.activation_post_process(X.detach())
AttributeError: 'tuple' object has no attribute 'detach'
At first I thought this had to do with the LSTM outputting tuples, as I had to change from GRU to LSTM for quantisation. But if that were the problem, surely the normal training loop would fail in the same way. Any help is appreciated.
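For reference, here is a minimal sketch of what I suspect is happening (the Tiny module and its sizes are just made up for illustration): the observer hook that prepare_qat attaches to the LSTM receives the LSTM's (output, (h, c)) tuple, whereas the un-prepared model runs fine because no hook is attached before preparation.

import torch

class Tiny(torch.nn.Module):
    def __init__(self):
        super(Tiny, self).__init__()
        self.lstm = torch.nn.LSTM(input_size = 4, hidden_size = 4, batch_first = True)

    def forward(self, x):
        x, (h, c) = self.lstm(x)
        return h

tiny = Tiny()
tiny.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# prepare_qat attaches FakeQuantize observers (via forward hooks) to leaf modules,
# including the nn.LSTM, which returns a tuple rather than a single tensor
tiny_prepared = torch.quantization.prepare_qat(tiny, inplace = False)

tiny(torch.randn(2, 3, 4))           # runs fine: no observer hooks on the plain model
tiny_prepared(torch.randn(2, 3, 4))  # appears to raise the same AttributeError from the LSTM's observer hook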