Loss is stuck in Quantization Aware Training


I am trying to do QAT for SRCNN following this tutorial

My code does not throw any error but when I train, the loss is always constant and there are lots of 0s in the output. Here is the code of my model:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.name = "QSRCNN"
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, padding=5 // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # Convert to NHWC format from NCHW
        x = x.contiguous(memory_format=torch.channels_last)
        x = self.quant(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.dequant(x)
        # Convert back to NCHW format
        x = x.contiguous(memory_format=torch.contiguous_format)

        return x

    def q_config(self, model):
        model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
        model = torch.quantization.fuse_modules(model, [['conv1', 'relu1'], ["conv2", "relu2"]])
        model = torch.quantization.prepare_qat(model, inplace=False)

        return model

And the training loop:

    loss_fn = nn.L1Loss()
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=10e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, verbose=True)
    # Detect the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Init the model and do the quantization
    model = model.q_config(model)
    # Training loop
    for epoch in range(MAX_EPOCH):
        epoch_start_time = time.time()
        running_loss = 0.0
        for i, sample in enumerate(dataloader, 0):
            lr, hr = sample["LR"], sample["HR"]
            lr = lr.to(device)
            hr = hr.to(device)
            # Forward Pass
            sr = model(lr)
            # Loss Calculation
            loss = loss_fn(sr, hr).to(device)
            running_loss += sr.shape[0] * loss.item()
            # Backward pass
        # Epoch loss
        epoch_loss = running_loss / len(dataset)
        epoch_end_time = time.time()
        print("Epoch [{} / {}] || Loss:{:.4f} || Epoch Time:{:.4f}".format(
            epoch + 1, MAX_EPOCH, epoch_loss, epoch_end_time - epoch_start_time))

When I comment out the line model = model.q_config(model) everything works fine and the model is trained properly. Any idea about what is wrong ?


Okay found the mistake, I defined the optimizer before the quantized model. Moving the optimizer line after the model declaration fixed the issue

Yes, since we change the model, it is important to call the optimizer after creating the quantized model

