Quantized object detection backbone gives inconsistent evaluation results after reloading weights

I am building a Faster R-CNN object detection model that uses the quantized mobilenet v as its backbone. So far I have managed to run quantization-aware training by calling the following before the training loop:

model.backbone.fuse_model()  # (1)
model.backbone.qconfig = torch.quantization.default_qat_qconfig  # (2)
model = torch.quantization.prepare_qat(model, inplace=True)  # (3)
model.train()

During training, I run the model on CUDA. After training, I run the following to move the model back to the CPU, quantize it, evaluate it, and save the model weights:

trained_model.cpu()
trained_model_quantized = convert(trained_model, inplace=False)  # (4)
trained_model_quantized.eval()
test_model(trained_model_quantized, dataloader_test)
torch.save(trained_model_quantized.state_dict(), "trained_model_quantized.pt")

In another session, I loaded the weights back by creating a new model, running quantization steps (1), (2), (3) and (4) on it, and calling model.load_state_dict(torch.load(PATH_TO_SAVE_WEIGHTS_DICT)). When I test this model on the same dataset as before, I get nonsensical results, as if the model had never been trained.

One example output:

[{'boxes': tensor([[2.2284e+02, 6.9949e-01, 2.5754e+02, 3.5039e+01],
        [1.9123e+02, 1.4250e+01, 2.4345e+02, 6.1587e+01],
        [1.8320e+01, 6.3590e-01, 5.3564e+01, 3.5422e+01],
        [2.9484e+02, 2.3753e-01, 3.2000e+02, 1.2851e+01],
        [2.0365e+02, 5.0668e+00, 2.5801e+02, 5.1970e+01],
        [3.0826e+01, 7.2213e-01, 6.6313e+01, 3.4504e+01],
        [1.1139e+01, 2.5874e+02, 6.4464e+01, 3.0664e+02],
        [2.2778e+02, 4.8784e+00, 2.8297e+02, 5.0020e+01],
        [1.1499e+02, 1.8001e+02, 1.6611e+02, 2.2816e+02],
        [1.8020e+02, 2.9482e+02, 2.0485e+02, 3.2000e+02],
        [1.4158e+02, 4.6937e-01, 1.6599e+02, 2.5700e+01],
        [1.6759e+02, 4.9485e-01, 2.1854e+02, 2.4960e+01],
        [0.0000e+00, 2.7912e+02, 9.5221e+00, 3.1460e+02],
        [1.2910e+02, 4.4345e-01, 1.7992e+02, 2.4563e+01],
        [9.0239e+01, 5.0788e-01, 1.4100e+02, 2.3526e+01],
        [5.1181e+01, 2.8405e+02, 1.0670e+02, 3.2000e+02],
        [3.7199e+01, 1.5634e+02, 9.1446e+01, 2.0341e+02],
        [5.2274e+01, 9.1355e-01, 1.0492e+02, 3.7529e+01],
        [2.4848e+02, 6.8777e-01, 2.8421e+02, 3.4754e+01],
        [2.6224e+01, 2.4715e-01, 5.1620e+01, 1.3250e+01],
        [0.0000e+00, 2.1927e+02, 7.0992e+01, 2.9383e+02],
        [6.2371e+01, 6.7974e-01, 2.1039e+02, 3.4605e+01],
        [1.4141e+02, 2.2420e-01, 1.6636e+02, 1.2484e+01],
        [0.0000e+00, 2.6777e-01, 1.2510e+01, 1.2869e+01],
        [2.6803e+02, 3.6809e-01, 3.2000e+02, 1.8864e+01],
        [1.2125e+02, 6.3132e-01, 1.5740e+02, 3.4166e+01],
        [3.7071e+01, 1.8162e+02, 9.0445e+01, 2.2925e+02],
        [1.0853e+02, 7.0764e-01, 1.4476e+02, 3.4172e+01],
        [0.0000e+00, 6.8656e-01, 7.1893e+01, 3.4567e+01],
        [0.0000e+00, 2.9469e+02, 1.2800e+01, 3.2000e+02],
        [0.0000e+00, 2.4723e+02, 1.7331e+01, 3.1682e+02],
        [0.0000e+00, 2.4664e+02, 1.6348e+01, 2.6562e+02],
        [2.3170e+02, 4.9811e-01, 2.8183e+02, 2.4400e+01],
        [0.0000e+00, 2.9816e+02, 4.5395e+00, 3.1654e+02],
        [1.2854e+02, 2.3414e-01, 1.5374e+02, 1.2642e+01],
        [1.3461e+01, 2.5192e-01, 3.9521e+01, 1.2805e+01],
        [1.7189e+02, 2.7412e+02, 2.0622e+02, 3.2000e+02],
        [0.0000e+00, 2.4050e+02, 1.0271e+01, 2.7527e+02],
        [2.9492e+02, 4.7859e-01, 3.2000e+02, 2.5724e+01],
        [0.0000e+00, 1.0180e+02, 2.5666e+01, 1.4887e+02],
        [7.5661e+01, 2.0977e+02, 1.2900e+02, 2.5681e+02],
        [2.5359e+02, 9.9282e-01, 3.2000e+02, 5.0407e+01],
        [0.0000e+00, 2.5901e+02, 4.0501e+01, 3.0663e+02],
        [2.0364e+02, 6.6663e+01, 2.5525e+02, 1.1442e+02],
        [2.0547e+02, 4.2242e+01, 2.5756e+02, 8.8740e+01],
        [0.0000e+00, 4.0732e-01, 3.8736e+01, 1.9843e+01],
        [1.1571e+02, 2.8211e-01, 1.4147e+02, 1.3143e+01],
        [0.0000e+00, 2.2975e+00, 1.5405e+02, 1.0976e+02],
        [2.9816e+02, 2.2668e+01, 3.1766e+02, 5.5445e+01],
        [5.8469e-01, 2.7204e-01, 2.5698e+01, 1.3041e+01],
        [0.0000e+00, 5.1108e+01, 1.3950e+01, 7.7417e+01],
        [1.4076e+02, 2.8269e+02, 1.9360e+02, 3.2000e+02],
        [2.3352e+01, 1.3232e+02, 7.7268e+01, 1.7938e+02],
        [2.2999e+02, 1.6494e+00, 3.2000e+02, 7.5008e+01],
        [0.0000e+00, 6.3699e+01, 1.3418e+01, 8.9596e+01],
        [6.2804e+01, 2.1912e+02, 1.1618e+02, 2.6685e+02],
        [0.0000e+00, 2.4379e+02, 2.6497e+01, 2.9244e+02],
        [1.7839e+02, 2.8245e+02, 2.3241e+02, 3.2000e+02],
        [8.9713e+01, 3.6764e+00, 1.4348e+02, 5.1272e+01],
        [5.8155e+00, 6.0722e-01, 4.1412e+01, 3.4975e+01],
        [0.0000e+00, 1.6554e+02, 2.6498e+01, 2.1379e+02],
        [2.5622e+02, 5.2321e-01, 3.0743e+02, 2.5142e+01],
        [2.6498e+02, 2.0179e+01, 3.2000e+02, 5.5930e+01],
        [7.4018e+01, 1.1732e+02, 1.3037e+02, 1.6463e+02],
        [2.7450e+02, 7.7239e-01, 3.1118e+02, 3.5411e+01],
        [2.1570e+02, 2.7996e+01, 2.6940e+02, 7.5383e+01],
        [0.0000e+00, 2.5685e+02, 1.3539e+01, 2.8232e+02],
        [1.0324e+02, 2.8183e+02, 1.5529e+02, 3.2000e+02],
        [5.7554e+01, 9.1042e+00, 2.8737e+02, 2.3098e+02],
        [2.4072e+01, 2.4965e+02, 7.9286e+01, 2.9569e+02],
        [1.0386e+02, 2.5957e+01, 1.2860e+02, 5.1814e+01],
        [0.0000e+00, 2.1812e+02, 2.7382e+01, 2.6644e+02],
        [0.0000e+00, 2.6893e+02, 1.3575e+01, 2.9465e+02],
        [2.4422e+02, 2.8380e+02, 2.9725e+02, 3.2000e+02],
        [0.0000e+00, 2.6925e+02, 7.4092e+01, 3.2000e+02],
        [0.0000e+00, 5.4974e+01, 3.8673e+01, 1.0265e+02],
        [0.0000e+00, 3.8855e+01, 1.3879e+01, 6.4523e+01],
        [2.1896e+02, 1.5610e+02, 2.6900e+02, 2.0563e+02],
        [1.3965e+02, 9.2438e+01, 1.9291e+02, 1.4087e+02],
        [0.0000e+00, 1.4937e+02, 5.0214e+01, 2.4080e+02],
        [1.1467e+02, 2.3391e+00, 1.6681e+02, 4.8360e+01],
        [0.0000e+00, 2.5561e+01, 2.7101e+01, 7.3610e+01],
        [1.3790e+02, 2.0647e+02, 1.9230e+02, 2.5392e+02],
        [9.0749e+01, 1.6114e+02, 3.2000e+02, 3.1993e+02],
        [1.3997e+02, 7.7576e-01, 1.9317e+02, 3.8217e+01],
        [2.2809e+02, 1.0366e+02, 2.8295e+02, 1.5036e+02],
        [0.0000e+00, 7.9377e+00, 4.7207e+01, 1.0263e+02],
        [2.4093e+02, 1.1915e+02, 2.9302e+02, 1.6518e+02],
        [1.0355e+02, 2.9433e+02, 1.2756e+02, 3.1988e+02],
        [0.0000e+00, 1.6336e+02, 1.4555e+02, 3.2000e+02],
        [2.8921e+02, 2.8545e+01, 3.2000e+02, 4.7252e+01],
        [2.0485e+02, 2.8252e+02, 2.5715e+02, 3.2000e+02],
        [0.0000e+00, 1.7163e+01, 1.8367e+02, 2.4122e+02],
        [1.1222e+02, 4.1794e+01, 1.6605e+02, 8.9092e+01],
        [0.0000e+00, 2.2235e+02, 1.7401e+01, 2.9148e+02],
        [1.7722e+02, 2.7057e+01, 2.3399e+02, 7.3110e+01],
        [2.4246e+02, 1.7376e+01, 2.9653e+02, 6.3945e+01],
        [4.5475e+01, 2.7231e+02, 1.8952e+02, 3.2000e+02],
        [1.8212e+02, 2.6980e+02, 3.2000e+02, 3.2000e+02],
        [1.7683e+02, 6.9158e-01, 2.3024e+02, 3.7515e+01]],
       grad_fn=<StackBackward>), 'labels': tensor([14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 14, 14]), 'scores': tensor([0.0700, 0.0661, 0.0660, 0.0657, 0.0652, 0.0650, 0.0650, 0.0649, 0.0646,
        0.0645, 0.0645, 0.0645, 0.0644, 0.0644, 0.0643, 0.0643, 0.0642, 0.0642,
        0.0642, 0.0642, 0.0641, 0.0640, 0.0640, 0.0640, 0.0638, 0.0637, 0.0636,
        0.0635, 0.0635, 0.0635, 0.0633, 0.0633, 0.0633, 0.0632, 0.0631, 0.0630,
        0.0629, 0.0629, 0.0628, 0.0628, 0.0628, 0.0627, 0.0626, 0.0625, 0.0625,
        0.0625, 0.0623, 0.0623, 0.0621, 0.0621, 0.0620, 0.0619, 0.0619, 0.0618,
        0.0618, 0.0618, 0.0617, 0.0616, 0.0616, 0.0616, 0.0616, 0.0616, 0.0615,
        0.0615, 0.0614, 0.0614, 0.0613, 0.0613, 0.0612, 0.0612, 0.0612, 0.0612,
        0.0611, 0.0611, 0.0610, 0.0610, 0.0610, 0.0610, 0.0609, 0.0608, 0.0607,
        0.0607, 0.0607, 0.0606, 0.0606, 0.0605, 0.0604, 0.0604, 0.0604, 0.0603,
        0.0603, 0.0602, 0.0602, 0.0602, 0.0602, 0.0601, 0.0601, 0.0601, 0.0600,
        0.0599], grad_fn=<IndexBackward>)}]

Sometimes the model returns no detections at all, which is odd:

[{'boxes': tensor([], size=(0, 4), grad_fn=), 'labels': tensor([], dtype=torch.int64), 'scores': tensor([], grad_fn=)}]

I am not sure why there is a discrepancy when I am loading the same weights I trained earlier. I would appreciate any advice and guidance.
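For reference, the workflow described in this question (steps (1) to (4), save, then rebuild and reload) can be sketched on a toy model. This is only a minimal sketch under assumptions: `TinyBackbone`, `build_prepared` and `to_quantized` are hypothetical names, the single forward pass stands in for real training, and the state_dict is handed over in memory where `torch.save` / `torch.load` would normally sit. It uses the same `torch.quantization` namespace as the snippets above.

```python
import torch
import torch.nn as nn
import torch.quantization as tq


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for the real backbone (not the poster's model)."""

    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()      # float -> quantized at the input
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.dequant = tq.DeQuantStub()  # quantized -> float at the output

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

    def fuse_model(self):
        # Step (1): fuse conv + relu into one module.
        tq.fuse_modules(self, [["conv", "relu"]], inplace=True)


def build_prepared():
    """Steps (1)-(3): fuse, attach a QAT qconfig, insert fake-quant modules."""
    model = TinyBackbone()
    model.train()  # prepare_qat expects a model in training mode
    model.fuse_model()
    model.qconfig = tq.default_qat_qconfig
    return tq.prepare_qat(model, inplace=True)


def to_quantized(prepared):
    """Step (4): move to CPU, switch to eval, convert to quantized modules."""
    prepared.cpu().eval()
    return tq.convert(prepared, inplace=False)


# "Session 1": a forward pass in training mode stands in for QAT training.
torch.manual_seed(0)
trained = build_prepared()
trained(torch.rand(4, 3, 16, 16))
quantized = to_quantized(trained)
state = quantized.state_dict()  # what torch.save(...) would write to disk

# "Session 2": rebuild the same structure with steps (1)-(4), then load.
restored = to_quantized(build_prepared())
restored.load_state_dict(state)

x = torch.rand(1, 3, 16, 16)
same_output = torch.allclose(quantized(x), restored(x))
```

In this toy setting the rebuilt model reproduces the original outputs only because both sessions construct exactly the same module structure before loading, so every key in the saved state_dict has a matching quantized module to land in.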

Hi @yichong96,
To clarify, does the first evaluation (right after running convert) work fine?

I would try inspecting the model parameters and state_dict under both approaches (i.e. running eval directly vs. re-creating the quantized model and loading the weights).
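One way to do that inspection is to diff the two state_dicts key by key. The helper below is a rough sketch, not an official PyTorch utility: `diff_state_dicts` and `_same` are hypothetical names, and quantized tensors are dequantized before comparison on the assumption that comparing their float values is good enough for debugging.

```python
import torch
import torch.nn as nn


def _same(va, vb):
    """Compare two state_dict entries, dequantizing quantized tensors first."""
    if torch.is_tensor(va) and torch.is_tensor(vb):
        fa = va.dequantize() if va.is_quantized else va
        fb = vb.dequantize() if vb.is_quantized else vb
        return fa.shape == fb.shape and torch.equal(fa.float(), fb.float())
    if torch.is_tensor(va) or torch.is_tensor(vb):
        return False  # one side is a tensor, the other is not
    return va == vb   # non-tensor entries such as None or a dtype


def diff_state_dicts(a, b):
    """Return (keys only in a, keys only in b, shared keys whose values differ)."""
    only_a = sorted(set(a) - set(b))
    only_b = sorted(set(b) - set(a))
    changed = [k for k in sorted(set(a) & set(b)) if not _same(a[k], b[k])]
    return only_a, only_b, changed


# Toy float models stand in for the two quantized models being compared.
torch.manual_seed(0)
reference = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
reloaded = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

before = diff_state_dicts(reference.state_dict(), reloaded.state_dict())
reloaded.load_state_dict(reference.state_dict())
after = diff_state_dicts(reference.state_dict(), reloaded.state_dict())
```

Keys that show up in only one of the two dicts point at a structural mismatch between the models, while keys listed as changed point at weights that were not restored.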

If you have a minimal example that reproduces this issue, we are happy to take a look.

Hey, thanks for your reply. I realised that the backbone had extra classifier layers that were never used, because I had taken the whole mobilenet model. I fixed it by using only the .features portion of the original backbone network.
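The fix described here amounts to handing the detector only the convolutional feature extractor. The sketch below uses a hypothetical `TinyClassifierNet` in place of the real mobilenet to show the difference in state_dict contents; the layer sizes are made up for illustration. The `out_channels` attribute is what torchvision's `FasterRCNN` reads from a custom backbone, and the value 8 applies only to this toy model.

```python
import torch
import torch.nn as nn


class TinyClassifierNet(nn.Module):
    """Hypothetical classification network with a features + classifier split."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(8, 10)

    def forward(self, x):
        x = self.features(x).mean(dim=(2, 3))  # global average pool
        return self.classifier(x)


full = TinyClassifierNet()

# Keep only the feature extractor as the detection backbone.
backbone = full.features
backbone.out_channels = 8  # channel count of the last feature map (toy value)

full_keys = list(full.state_dict())
backbone_keys = list(backbone.state_dict())

# Entries that exist only because the whole classification network was kept.
unused_keys = [k for k in full_keys if k.startswith("classifier.")]

# The backbone returns a 4D feature map, not class logits.
feature_map_shape = tuple(backbone(torch.rand(1, 3, 16, 16)).shape)
```

With the whole network as backbone, the `classifier.*` entries travel along in the state_dict and go through the quantization steps even though the detector never calls them; with `.features` alone they are absent.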