QAT: precision drops to 0


I have a trained PyTorch model that reached a test precision of over 90%.
I then ran QAT on this model and converted it to quantized_model. When I tested quantized_model on the same test dataset, the precision dropped to 0%.

The framework for my QAT is as below:

    model = create_model(num_classes=num_classes)

    train_loader, test_loader = prepare_dataloader(config, num_workers=8, train_batch_size=128, eval_batch_size=256)

    # Load a pretrained model.
    model = load_model(model=model, model_filepath=model_filepath, device=cuda_device)

    # Move the model to CPU since static quantization does not support CUDA currently.

    # Make a copy of the model for layer fusion
    fused_model = copy.deepcopy(model)



    # Fuse the model in place, block by block.
    for module_name, module in fused_model.named_children():
        if "extras" in module_name:
            for basic_block_name, basic_block in module.named_children():
                torch.quantization.fuse_modules(basic_block, [["0", "1", "2"]], inplace=True)
                for sub_block_name, sub_block in basic_block.named_children():
                    if sub_block_name == "Sequential":
                        torch.quantization.fuse_modules(sub_block, [["0", "1", "2"]], inplace=True)

    # Model and fused model should be equivalent.
    assert model_equivalence(model_1=model, model_2=fused_model, device=cpu_device, rtol=1e-03, atol=1e-06, num_tests=100, input_size=(1,3,32,32)), "Fused model is not equivalent to the original model!"

    quantized_model = QuantizedRSSD(model_fp32=fused_model)
    quantization_config = torch.quantization.get_default_qconfig("fbgemm")
    quantized_model.qconfig = quantization_config
    prepared = torch.quantization.prepare_qat(quantized_model, inplace=True)

    # Use training data for calibration.
    if not os.path.exists(quantized_model_filepath):
        print("Training QAT Model...")
        train_model(model=prepared, config=config, model_filename=quantized_model_filename, 
                    train_loader=train_loader, test_loader=test_loader, device=cuda_device, 
                    learning_rate=1e-2, num_epochs=3)

    quantized_model = torch.quantization.convert(prepared, inplace=True)
    if not os.path.exists(quantized_model_filepath):
        save_model(model=quantized_model, model_dir=model_dir, model_filename=quantized_model_filename)

What am I missing in my QAT setup that causes the precision to drop to 0? Please help. Thanks.

QAT needs careful tuning. My guess is that your learning rate might be too large here.
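
To make that concrete, here is a minimal sketch of a QAT run with a smaller learning rate. It is not your model: `TinyNet`, the random data, and `lr=1e-4` are all assumptions for illustration, and the learning rate is only an example of "smaller than 1e-2", not a tuned value. The sketch also uses `get_default_qat_qconfig`, the QAT counterpart of the `get_default_qconfig` call in the question, and it compares the fake-quantized output with the converted model's output on CPU, which can help tell whether accuracy is lost during fine-tuning or during conversion.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real network (assumption, not the asker's model)."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(8 * 8 * 8, num_classes)
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return self.dequant(x)


torch.manual_seed(0)

model = TinyNet()
model.train()  # prepare_qat expects a model in training mode
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
prepared = torch.quantization.prepare_qat(model, inplace=False)

# Illustrative learning rate, two orders of magnitude below the 1e-2 in the question.
optimizer = torch.optim.SGD(prepared.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

# A few fine-tuning steps on random data (placeholder for the real train_loader).
for _ in range(3):
    inputs = torch.randn(16, 3, 8, 8)
    targets = torch.randint(0, 10, (16,))
    optimizer.zero_grad()
    loss = loss_fn(prepared(inputs), targets)
    loss.backward()
    optimizer.step()

# Switch to eval mode on CPU before converting to a real int8 model.
prepared.eval()
prepared.to("cpu")

sample = torch.randn(4, 3, 8, 8)
with torch.no_grad():
    fake_quant_out = prepared(sample)  # simulated quantization, still float kernels

quantized = torch.quantization.convert(prepared, inplace=False)
with torch.no_grad():
    int8_out = quantized(sample)       # actual quantized kernels

max_diff = (fake_quant_out - int8_out).abs().max().item()
print("max abs difference between fake-quant and int8 outputs:", max_diff)
```

If the fake-quantized model already scores near 0 on your test set before `convert`, the problem is in the fine-tuning stage (learning rate, qconfig, number of epochs); if it only collapses after `convert`, the conversion step or the device and mode the model was in are the first things to check.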