Error during QAT with DataParallel on multiple GPUs

I followed the QAT tutorial to train a model. When I don't use DataParallel, or use DataParallel with a single GPU, everything works fine. But when I train the model with DataParallel on multiple GPUs, it raises an error right after disable_observer is applied.
The output with the error:

/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/quantization/observer.py:208: UserWarning: Must run observer before calling calculate_qparams. Returning default scale and zero point.
[0.5, 0.5, 0.73046875, 0.99609375, 0.953125, 0.5, 0.5, 0.5, 0.5, 0.5]
[0.5, 0.5, 0.73046875, 0.99609375, 0.99609375, 0.5, 0.5, 0.5, 0.5, 0.5]
Traceback (most recent call last):
  File "quant_test.py", line 50, in <module>
    quant_model = torch.quantization.convert(quant_model.eval().cpu(), inplace=False)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/quantization/quantize.py", line 313, in convert
    reassign[name] = swap_module(mod, mapping)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/quantization/quantize.py", line 335, in swap_module
    new_mod = mapping[type(mod)].from_float(mod)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py", line 49, in from_float
    return super(ConvReLU2d, cls).from_float(mod)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py", line 250, in from_float
    weight_post_process(mod.weight)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 86, in forward
    self.ch_axis, self.quant_min, self.quant_max)
RuntimeError: dimensions of scale and zero-point are not consistent with input tensor

My test code:

import torch
from torch import nn, optim
from torch.quantization import QuantStub, DeQuantStub
from copy import deepcopy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Sequential(
    QuantStub(),
    nn.Conv2d(3, 16, 1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 10, 3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(10),
    nn.AvgPool2d(14),
    nn.Sigmoid(),
    DeQuantStub(),
)

# Fuse Conv+BN+ReLU and Conv+BN before preparing for QAT
torch.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)

# Attach the default fbgemm QAT qconfig, insert fake-quant/observer modules,
# then wrap with DataParallel and move to the GPU(s)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)
optimizer = optim.Adam(model.parameters(), lr=1)
model = nn.DataParallel(model)
model.to(device)
# print(model)

criterion = nn.BCELoss()

for epoch in range(10):
    model.train()

    inputs = torch.rand(2, 3, 28, 28)
    labels = torch.FloatTensor([[1,1,1,1,1,0,0,0,0,0], [1,1,1,1,1,0,0,0,0,0]])

    inputs = inputs.to(device)
    labels = labels.to(device)
    loss = criterion(model(inputs).view(2, 10), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Freeze the observers after a few epochs, then freeze BN statistics
    if epoch >= 2:
        model.apply(torch.quantization.disable_observer)
    if epoch >= 3:
        model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    # Convert a CPU copy of the underlying (unwrapped) model and test it
    quant_model = deepcopy(model.module)
    quant_model = torch.quantization.convert(quant_model.eval().cpu(), inplace=False)
    with torch.no_grad():
        out = quant_model(torch.rand(1, 3, 28, 28))
        print(out.view(10).tolist())
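
The UserWarning ("Must run observer before calling calculate_qparams") together with the convert-time RuntimeError suggests that the per-channel weight observer never records any data in the multi-GPU run, so a default scale is returned and its shape no longer matches the conv weight. As a diagnostic (a minimal sketch; index 1 is the fused ConvBnReLU2d in this Sequential), the fake-quant state can be dumped between the deepcopy and the convert call:

# Hypothetical diagnostic, placed right before torch.quantization.convert:
# with the fbgemm QAT qconfig the weight scale is per-channel, so its length
# should equal the conv's number of output channels once the observer has run;
# a single default entry means the observer never saw any data.
qat_conv = quant_model[1]                  # fused ConvBnReLU2d
wfq = qat_conv.weight_fake_quant           # FakeQuantize module for the weight
print('weight scale:', wfq.scale, 'expected channels:', qat_conv.weight.shape[0])
print('observer state:', wfq.activation_post_process.state_dict())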

We recently fixed a bug in fake_quant (https://github.com/pytorch/pytorch/commit/e236e1593468a68e47c5bcafd7272eca01684294). Are you running on master?

My PyTorch version is 1.5.0.dev20200315, which was built after that commit. I will try the latest nightly later.

I have tested again with PyTorch 1.5.0, and the error persists. It only occurs when I use DataParallel.

@supriyar can you take a look?

@eleflea this seems to be a genuine bug. We are looking into it and will post an update when it is fixed.
Do you also see this error if you load pretrained fp32 weights and then run QAT for a few iterations (using the same setup)?
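
For reference, a minimal sketch of that setup (the checkpoint path is hypothetical; everything else mirrors the test script above):

# Load pretrained fp32 weights into the float model first, then fuse,
# prepare for QAT, and wrap with DataParallel exactly as in the test script.
state_dict = torch.load('fp32_checkpoint.pth', map_location='cpu')  # hypothetical checkpoint
model.load_state_dict(state_dict)
torch.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)
model = nn.DataParallel(model).to(device)
# ...then run the QAT training loop for a few iterations as before.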

@supriyar I tried loading a normally trained fp32 model and continuing with QAT. I disable the observers after 2 epochs and freeze BN after 2 more epochs.
When I train on a single GPU, the loss stays around 13 (it increases slightly after disabling the observers) and everything goes smoothly.
When I train on 2 GPUs, the loss jumps to ~1200 after disabling the observers, and more iterations do not help.
I also printed the first layer during QAT and found that the quantization parameters are not updated (compared to the single-GPU run): scale and zero_point are stuck at 1 and 0, and the observer's min_val and max_val stay empty.
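
A sketch of the kind of per-step check used here (minimal, with a hypothetical helper name; any of the model's FakeQuantize modules would do):

from torch.quantization.fake_quantize import FakeQuantize

# Hypothetical helper: print the first FakeQuantize module's scale, zero_point
# and observer buffers to see whether the observers ever update during QAT.
def dump_first_fake_quant(dp_model):
    for name, child in dp_model.module.named_modules():  # unwrap DataParallel
        if isinstance(child, FakeQuantize):
            print(name, child.scale, child.zero_point,
                  child.activation_post_process.state_dict())
            break

dump_first_fake_quant(model)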
Part of the output is as follows:

Sequential(
  (conv): ConvBnReLU2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=True, observer_enabled=True,            scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0')
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([], device='cuda:0'), max_val=tensor([], device='cuda:0'))
    )
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=True, observer_enabled=True,            scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0')
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([], device='cuda:0'), max_val=tensor([], device='cuda:0'))
    )
  )
  (bn): Identity()
  (act): Identity()
)
lr: 0.000002    epoch: 0/80     step: 13        train_loss: 13.425(xy: 2.847, conf: 6.531, cls: 4.047)   Sequential(
  (conv): ConvBnReLU2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=True, observer_enabled=True,            scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0')
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([], device='cuda:0'), max_val=tensor([], device='cuda:0'))
    )
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=True, observer_enabled=True,            scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0')
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([], device='cuda:0'), max_val=tensor([], device='cuda:0'))
    )
  )
  (bn): Identity()
  (act): Identity()
)
lr: 0.000004    epoch: 0/80     step: 26        train_loss: 13.825(xy: 2.980, conf: 6.424, cls: 4.421)   Sequential(
  (conv): ConvBnReLU2d(
    3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=True, observer_enabled=True,            scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0')
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([], device='cuda:0'), max_val=tensor([], device='cuda:0'))
    )
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=True, observer_enabled=True,            scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0')
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([], device='cuda:0'), max_val=tensor([], device='cuda:0'))
    )
  )
  (bn): Identity()
  (act): Identity()
)
lr: 0.000006    epoch: 0/80     step: 39        train_loss: 13.173(xy: 3.110, conf: 6.385, cls: 3.678)