QAT: precision drops to 0%

Hi,

I have a PyTorch-trained model that reached a test precision of 90+%.
Then I ran QAT on this model and converted it to a quantized model. When I tested the quantized model with the same test dataset, the precision dropped to 0%.

The framework for my QAT is as follows:

    import copy
    import os

    import torch

    model = create_model(num_classes=num_classes)

    train_loader, test_loader = prepare_dataloader(config, num_workers=8, train_batch_size=128, eval_batch_size=256)

    # Load a pretrained model.
    model = load_model(model=model, model_filepath=model_filepath, device=cuda_device)

    # Move the model to CPU since static quantization does not support CUDA currently.
    model.to(cpu_device)

    # Make a copy of the model for layer fusion
    fused_model = copy.deepcopy(model)

    model.train()
    fused_model.train()

    # Fuse the model in place rather than manually.
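    # (Each ["0", "1", "2"] triple below names three consecutive child
    # modules, presumably Conv-BN-ReLU, that fuse_modules merges into one.)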
    
    for module_name, module in fused_model.named_children():
        if "extras" in module_name:
            for basic_block_name, basic_block in module.named_children():
                torch.quantization.fuse_modules(basic_block, [["0", "1", "2"]], inplace=True)
                for sub_block_name, sub_block in basic_block.named_children():
                    if sub_block_name == "Sequential":
                        torch.quantization.fuse_modules(sub_block, [["0", "1", "2"]], inplace=True)


    # Model and fused model should be equivalent.
    model.eval()
    fused_model.eval()
    assert model_equivalence(model_1=model, model_2=fused_model, device=cpu_device,
                             rtol=1e-03, atol=1e-06, num_tests=100,
                             input_size=(1, 3, 32, 32)), "Fused model is not equivalent to the original model!"

    quantized_model = QuantizedRSSD(model_fp32=fused_model)

    quantization_config = torch.quantization.get_default_qconfig("fbgemm")
    quantized_model.qconfig = quantization_config

    prepared = torch.quantization.prepare_qat(quantized_model, inplace=True)
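    # Note: with inplace=True, prepare_qat modifies quantized_model in place
    # and returns it, so `prepared` and `quantized_model` are the same object
    # from here on.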

    prepared.apply(torch.quantization.enable_observer)

    # Use training data for calibration.
    if not os.path.exists(quantized_model_filepath):
        print("Training QAT Model...")
        quantized_model.train()
        prepared.apply(torch.quantization.enable_fake_quant)

        train_model(model=prepared, config=config, model_filename=quantized_model_filename,
                    train_loader=train_loader, test_loader=test_loader, device=cuda_device,
                    learning_rate=1e-2, num_epochs=3)
    
    prepared.to(cpu_device)

    quantized_model = torch.quantization.convert(prepared, inplace=True)
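    # convert() swaps the fake-quantize/observed modules for real int8 ones;
    # the resulting model runs on CPU with the fbgemm (x86) backend, hence
    # the move to cpu_device above.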
    
    quantized_model.eval()
   
    if not os.path.exists(quantized_model_filepath):
        save_model(model=quantized_model, model_dir=model_dir, model_filename=quantized_model_filename)

What am I missing in my QAT pipeline that causes the precision to drop to 0%? Please help. Thanks.

QAT needs careful tuning; my guess is that your learning rate might be too large here.
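
QAT is normally a fine-tuning step on an already-converged model, so the learning_rate=1e-2 in your train_model call may be wiping out the pretrained weights within a few iterations. Below is roughly the kind of loop I would try instead, reusing your `prepared`, `train_loader`, and `cuda_device` from above; `criterion` is a hypothetical loss function (e.g. your SSD multibox loss), and the 1e-4 learning rate and the freezing epoch are illustrative starting points, not tuned values:

    import torch

    # Gentler QAT fine-tuning sketch: small learning rate, then freeze the
    # quantization observers and batch-norm statistics before the last epoch
    # so the model fine-tunes against fixed quantization parameters.
    # `criterion` is a placeholder for your actual loss function.
    prepared.to(cuda_device)
    prepared.train()
    optimizer = torch.optim.SGD(prepared.parameters(), lr=1e-4, momentum=0.9)

    num_epochs = 3
    for epoch in range(num_epochs):
        for images, targets in train_loader:
            images, targets = images.to(cuda_device), targets.to(cuda_device)
            optimizer.zero_grad()
            loss = criterion(prepared(images), targets)
            loss.backward()
            optimizer.step()

        if epoch == num_epochs - 2:
            # Freeze observers and BN stats for the final epoch, similar to
            # the schedule in the official PyTorch QAT tutorial.
            prepared.apply(torch.quantization.disable_observer)
            prepared.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

If the precision is still 0% even with a tiny learning rate, I would look at the fuse/convert steps rather than the training itself.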