Error during QAT training of ResNet50

Hello,

Could someone please help me understand how I can fix my code? I'm doing QAT training of ResNet50 and encountering an error in one of the convolutional layers:

RuntimeError: zero_point must be between quant_min and quant_max

The code is relatively simple.

import torch
from torchvision.models import resnet50, ResNet50_Weights

from torch.ao.quantization import (
    fuse_modules,
    QuantWrapper,
    get_default_qat_qconfig,
)

qconfig = get_default_qat_qconfig('fbgemm')


def apply_recursive(model, qconfig):
  model.qconfig = qconfig
  # qconfig should be applied to all children
  for name, module in model.named_children():
      apply_recursive(module, qconfig)
  return model


def get_qat_fused_manual():

  model_fp32 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
  model_fp32.eval()

  model_fp32 = apply_recursive(model_fp32, qconfig)
  
  # fuse layers
  fuse_modules(model_fp32, [['conv1', 'bn1', 'relu']], inplace=True)

  for layer in ['layer1', 'layer2', 'layer3', 'layer4']:
    module = model_fp32.get_submodule(layer)
    for i in range(len(module)):
      submodule = module.get_submodule(str(i))
      fuse_modules(submodule, [['conv1', 'bn1'], ['conv2', 'bn2'], ['conv3', 'bn3', 'relu']], inplace=True)

      # fuse the downsample conv/bn pair where the block has one
      if getattr(submodule, 'downsample', None) is not None:
        fuse_modules(submodule, [['downsample.0', 'downsample.1']], inplace=True)

  model_fused = QuantWrapper(model_fp32)
  model_fused = torch.ao.quantization.prepare_qat(model_fused.train())
  return model_fused

qat_model = get_qat_fused_manual()

data_loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size = 32,
    sampler=val_sampler,
    num_workers=args.workers
)


device = 'cuda'
parameters = qat_model.parameters()
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0)
optimizer = torch.optim.SGD(
    parameters,
    lr=1e-5,
    momentum=0.9,
    # weight_decay=args.weight_decay,
    nesterov=True,
)

qat_model.to(device)
qat_model.train()
for i, (image, target) in enumerate(data_loader_train):
    print(i)
    image, target = image.to(device), target.to(device)
    output = qat_model(image)
    loss = criterion(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

From what I see, training runs for several iterations and then the model's forward method triggers the error:


/usr/local/lib/python3.10/dist-packages/torch/ao/quantization/stubs.py in forward(self, X)
     61     def forward(self, X):
     62         X = self.quant(X)
---> 63         X = self.module(X)
     64         return self.dequant(X)

....

/usr/local/lib/python3.10/dist-packages/torchvision/models/resnet.py in _forward_impl(self, x)
    266     def _forward_impl(self, x: Tensor) -> Tensor:
    267         # See note [TorchScript super()]
--> 268         x = self.conv1(x)
    269         x = self.bn1(x)
    270         x = self.relu(x)


...

/usr/local/lib/python3.10/dist-packages/torch/ao/quantization/fake_quantize.py in forward(self, X)
    351 
    352     def forward(self, X: torch.Tensor) -> torch.Tensor:
--> 353         return torch.fused_moving_avg_obs_fake_quant(
    354             X,
    355             self.observer_enabled,

RuntimeError: `zero_point` must be between `quant_min` and `quant_max`.

So could anyone point out what the problem is here? I was unable to find anything related on the internet.
Thank you.

Can you print what the zero_point is? I'd probably edit fake_quantize.py above the return line that's throwing the error to print the zero_point.
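For example (a rough sketch; dump_zero_points is just a hypothetical helper, and it assumes the fake-quant modules subclass torch.ao.quantization.FakeQuantizeBase):

import torch
from torch.ao.quantization import FakeQuantizeBase

def dump_zero_points(model: torch.nn.Module) -> None:
    # hypothetical helper: walk the prepared QAT model and print each
    # fake-quantize module's zero_point buffer, so the out-of-range
    # value can be spotted without editing the installed sources
    for name, module in model.named_modules():
        if isinstance(module, FakeQuantizeBase):
            print(f"{name}: zero_point={getattr(module, 'zero_point', None)}")

Calling dump_zero_points(qat_model) every few iterations should show which layer's zero_point drifts out of range.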

If I were to guess: I don't see any quant stubs being inserted, so I would guess it has something to do with that.

Otherwise, does it work if you do QAT normally, without fusion?
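For example, something like this minimal flow (a sketch of the eager-mode recipe from the quantization docs, skipping the manual fusion entirely):

import torch
from torchvision.models import resnet50, ResNet50_Weights
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat, QuantWrapper

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
model.qconfig = get_default_qat_qconfig('fbgemm')  # prepare_qat propagates this to children
qat_model = prepare_qat(QuantWrapper(model).train())
# run the same training loop; if this trains cleanly, the fusion step is the suspect

If the unfused run is stable, the problem is most likely in how the conv/bn modules are being fused.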


OK, so before the last forward pass the observers looked like this: [screenshot omitted]

After making one more forward pass: [screenshot omitted]

After playing a bit more with this, I found out that replacing

fuse_modules -> fuse_modules_qat

fixes my issue with the observers. Unfortunately this is not mentioned in the link you shared (basically the main manual on QAT) or in any other manuals:
https://pytorch.org/docs/stable/quantization.html#quantization-aware-training-for-static-quantization
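For reference, the change boils down to this (a sketch of just the first fusion call; as far as I understand, the QAT variant expects the model to be in train() mode rather than eval()):

from torch.ao.quantization import fuse_modules_qat

model_fp32.train()
fuse_modules_qat(model_fp32, [['conv1', 'bn1', 'relu']], inplace=True)
# same substitution for the per-block conv/bn(/relu) fusions and the
# downsample fusion inside get_qat_fused_manual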

In any case, I appreciate the fast response, HDCharles.

Thanks.