Loss is stuck in Quantization Aware Training


I am trying to do QAT for SRCNN following this tutorial

My code does not throw any error but when I train, the loss is always constant and there are lots of 0s in the output. Here is the code of my model:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.name = "QSRCNN"
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=5, padding=5 // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # Convert to NHWC format from NCHW
        x = x.contiguous(memory_format=torch.channels_last)
        x = self.quant(x)
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.dequant(x)
        # Convert back to NCHW format
        x = x.contiguous(memory_format=torch.contiguous_format)

        return x

    def q_config(self, model):
        model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
        model = torch.quantization.fuse_modules(model, [['conv1', 'relu1'], ["conv2", "relu2"]])
        model = torch.quantization.prepare_qat(model, inplace=False)

        return model

And the training loop:

    loss_fn = nn.L1Loss()
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=10e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, verbose=True)
    # Detect the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Init the model and do the quantization
    model = model.q_config(model)
    # Training loop
    for epoch in range(MAX_EPOCH):
        epoch_start_time = time.time()
        running_loss = 0.0
        for i, sample in enumerate(dataloader, 0):
            lr, hr = sample["LR"], sample["HR"]
            lr = lr.to(device)
            hr = hr.to(device)
            # Forward Pass
            sr = model(lr)
            # Loss Calculation
            loss = loss_fn(sr, hr).to(device)
            running_loss += sr.shape[0] * loss.item()
            # Backward pass
        # Epoch loss
        epoch_loss = running_loss / len(dataset)
        epoch_end_time = time.time()
        print("Epoch [{} / {}] || Loss:{:.4f} || Epoch Time:{:.4f}".format(
            epoch + 1, MAX_EPOCH, epoch_loss, epoch_end_time - epoch_start_time))

When I comment out the line model = model.q_config(model) everything works fine and the model is trained properly. Any idea about what is wrong ?


Okay found the mistake, I defined the optimizer before the quantized model. Moving the optimizer line after the model declaration fixed the issue

1 Like

Yes, since we change the model, it is important to call the optimizer after creating the quantized model

1 Like