I followed the QAT tutorial to train a model. When I don't use DataParallel, or use DataParallel with a single GPU, everything works. But when I train the model with DataParallel on multiple GPUs, it raises an error right after disable_observer is applied.
The output with the error (the two lists are the converted model's outputs from epochs 0 and 1; the crash comes at epoch 2, right after the observers are disabled):
/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/quantization/observer.py:208: UserWarning: Must run observer before calling calculate_qparams. Returning default scale and zero point.
[0.5, 0.5, 0.73046875, 0.99609375, 0.953125, 0.5, 0.5, 0.5, 0.5, 0.5]
[0.5, 0.5, 0.73046875, 0.99609375, 0.99609375, 0.5, 0.5, 0.5, 0.5, 0.5]
Traceback (most recent call last):
  File "quant_test.py", line 50, in <module>
    quant_model = torch.quantization.convert(quant_model.eval().cpu(), inplace=False)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/quantization/quantize.py", line 313, in convert
    reassign[name] = swap_module(mod, mapping)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/quantization/quantize.py", line 335, in swap_module
    new_mod = mapping[type(mod)].from_float(mod)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py", line 49, in from_float
    return super(ConvReLU2d, cls).from_float(mod)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py", line 250, in from_float
    weight_post_process(mod.weight)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/eleflea/.pyenv/versions/3.7.4/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 86, in forward
    self.ch_axis, self.quant_min, self.quant_max)
RuntimeError: dimensions of scale and zero-point are not consistent with input tensor
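
If I read the last frame right, the per-channel scale and zero_point on the weight fake-quant module are still at their defaults when convert runs. A quick check I would try (a minimal diagnostic sketch; it assumes the model from the test code below, where index 1 is the fused ConvBnReLU2d after prepare_qat):

# Diagnostic sketch: inspect the weight fake-quant state of the first conv.
# On multi-GPU I'd expect the scale to still be the 1-element default
# instead of one value per output channel, matching the RuntimeError.
conv = model.module[1]                 # fused ConvBnReLU2d (model is wrapped in DataParallel)
print(conv.weight_fake_quant.scale)    # never updated on the source model?
print(conv.out_channels)               # 16 values expected for per-channel qparams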
My test code:
import torch
from torch import nn, optim
from torch.quantization import QuantStub, DeQuantStub
from copy import deepcopy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Sequential(
    QuantStub(),
    nn.Conv2d(3, 16, 1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 10, 3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(10),
    nn.AvgPool2d(14),
    nn.Sigmoid(),
    DeQuantStub(),
)

# Fuse Conv+BN+ReLU and Conv+BN, then insert fake-quant modules for QAT.
torch.quantization.fuse_modules(model, [['1', '2', '3'], ['4', '5']], inplace=True)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

optimizer = optim.Adam(model.parameters(), lr=1)
model = nn.DataParallel(model)  # fine without this, or with a single GPU
model.to(device)
# print(model)
criterion = nn.BCELoss()
for epoch in range(10):
    model.train()
    inputs = torch.rand(2, 3, 28, 28)
    labels = torch.FloatTensor([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
    inputs = inputs.to(device)
    labels = labels.to(device)
    loss = criterion(model(inputs).view(2, 10), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch >= 2:
        model.apply(torch.quantization.disable_observer)
    if epoch >= 3:
        model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
    # Convert a CPU copy of the underlying module each epoch and check its output.
    quant_model = deepcopy(model.module)
    quant_model = torch.quantization.convert(quant_model.eval().cpu(), inplace=False)
    with torch.no_grad():
        out = quant_model(torch.rand(1, 3, 28, 28))
    print(out.view(10).tolist())
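
My guess is that this is the documented DataParallel caveat: the module is replicated onto each GPU at every forward pass, so the observer min/max buffers updated inside the replicas are discarded along with them, and the fake-quant modules on the source model keep their default qparams. Once the observers are disabled, nothing can fill them in, and convert trips over the 1-element default scale. As a sanity check, something like this (my own workaround sketch, not from the tutorial) should sidestep the crash by running the observers once on the unwrapped module before converting:

# Workaround sketch (my own attempt, not from the tutorial): run one
# calibration pass on the unwrapped module so its own observers execute
# and the per-channel qparams actually get computed.
calib = deepcopy(model.module).cpu()
calib.apply(torch.quantization.enable_observer)
with torch.no_grad():
    calib(torch.rand(2, 3, 28, 28))  # updates min/max and qparams on this copy
quant_model = torch.quantization.convert(calib.eval(), inplace=False)

Is there a supported way to make the observers work through DataParallel, or do I have to calibrate on a single device?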