Can a QAT model run inference on CUDA?

Hi,

I know that static & dynamic quantization cannot run inference on CUDA,

but I am wondering whether a QAT model can run inference on CUDA.

Thanks.

From the PyTorch Quantization docs:

Quantization-aware training (through FakeQuantize) supports both CPU and CUDA

The quantization doc says it supports both CPU and GPU, but I tried the tutorial and it didn't work.
I am still confused because some users say it does not support GPU yet, in Is QAT Inference not support GPU?

You can run a QAT model prior to convert on GPU. Please look at the example in torchvision: vision/train_quantization.py at master · pytorch/vision · GitHub
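In other words, the fake-quantized model produced by prepare_qat runs on CUDA; only the INT8 model produced by convert is CPU-only. A rough sketch of that split, assuming the eager-mode setup from the torchvision example (model here is a placeholder for a float model that defines fuse_model(), and the training loop is elided):

import torch

model.eval()
model.fuse_model()  # eager-mode fusion is done in eval mode
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model.train()  # QAT itself runs with the model in train mode
torch.quantization.prepare_qat(model, inplace=True)

model.cuda()  # QAT fine-tuning and fake-quant inference both work on CUDA
# ... fine-tune on GPU here ...

# convert() swaps in real INT8 modules whose kernels only exist on CPU,
# so move the model back to CPU first.
model.cpu().eval()
quantized_model = torch.quantization.convert(model)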

After applying QAT, I tried to run inference with the model on GPU, but I got the error below. On CPU it works fine.

File "quantize_model.py", line 359, in <module>
model = quantization_aware_training(model, device)
File "quantize_model.py", line 120, in quantization_aware_training
torch.quantization.convert(quantized_eval_model, inplace=True)
File "/opt/conda/lib/python3.8/site-packages/torch/quantization/quantize.py", line 471, in convert
_convert(
File "/opt/conda/lib/python3.8/site-packages/torch/quantization/quantize.py", line 507, in _convert
_convert(mod, mapping, True, # inplace
File "/opt/conda/lib/python3.8/site-packages/torch/quantization/quantize.py", line 509, in _convert
reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
File "/opt/conda/lib/python3.8/site-packages/torch/quantization/quantize.py", line 534, in swap_module
new_mod = mapping[type(mod)].from_float(mod)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py", line 97, in from_float
return super(ConvReLU2d, cls).from_float(mod)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/quantized/modules/conv.py", line 418, in from_float
return _ConvNd.from_float(cls, mod)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/quantized/modules/conv.py", line 220, in from_float
return cls.get_qconv(mod, activation_post_process, weight_post_process)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/quantized/modules/conv.py", line 187, in get_qconv
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/quantized/modules/utils.py", line 14, in _quantize_weight
qweight = torch.quantize_per_channel(
RuntimeError: Could not run 'aten::quantize_per_channel' with arguments from the 'CUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::quantize_per_channel' is only available for these backends: [CPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].
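The last frames show the root cause: convert() quantizes the weights with torch.quantize_per_channel, and that op is only implemented for the CPU backend. A minimal repro of just that behavior (the tensor shapes here are arbitrary):

import torch

w = torch.randn(4, 3)
scales = torch.ones(4)
zero_points = torch.zeros(4, dtype=torch.int64)

# Works: per-channel quantization of a CPU tensor.
qw = torch.quantize_per_channel(w, scales, zero_points, axis=0, dtype=torch.qint8)

# Raises the RuntimeError above: the op has no CUDA kernel.
# torch.quantize_per_channel(w.cuda(), scales.cuda(), zero_points.cuda(),
#                            axis=0, dtype=torch.qint8)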

Quantization
model.to(device)
model.eval()  # OK for fusion; note that prepare_qat below normally expects train mode
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

Evaluation
with torch.no_grad():
    model.eval()
    epoch_psnr = AverageMeter()
    quantized_eval_model = copy.deepcopy(model_without_ddp)
    quantized_eval_model.eval()
    quantized_eval_model.to(device)
    torch.quantization.convert(quantized_eval_model, inplace=True)

    for data in eval_dataloader:
        inputs, labels = data

        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        preds = quantized_eval_model(inputs).clamp(0.0, 1.0)
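Given the backend list in the error, the fix is to keep quantized_eval_model (and its inputs) on CPU from convert() onward; only the fake-quant model before convert belongs on the GPU. A sketch of the corrected evaluation, reusing the placeholder names from the snippet above:

import copy
import torch

with torch.no_grad():
    quantized_eval_model = copy.deepcopy(model_without_ddp)
    quantized_eval_model.eval()
    quantized_eval_model.cpu()  # convert() and the INT8 kernels are CPU-only
    torch.quantization.convert(quantized_eval_model, inplace=True)

    for data in eval_dataloader:
        inputs, labels = data
        inputs = inputs.cpu()  # the converted model must see CPU tensors
        preds = quantized_eval_model(inputs).clamp(0.0, 1.0)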

Model
self.quant = torch.quantization.QuantStub()
self.conv_relu1 = ConvReLu(1, 64, _kernel_size=5, _padding=5//2)
self.conv_relu2 = ConvReLu(64, 32, _kernel_size=3, _padding=3//2)
self.sub_pixel = nn.Sequential(
nn.Conv2d(32, 1 * (scale_factor ** 2), kernel_size=3, stride=1, padding=1),
nn.PixelShuffle(scale_factor),
nn.Sigmoid()
)
self.dequant = torch.quantization.DeQuantStub()

def forward(self, x):
    x = self.quant(x)
    x = self.conv_relu1(x)
    x = self.conv_relu2(x)
    x = self.sub_pixel(x)
    x = self.dequant(x)
    return x

class ConvReLu(nn.Sequential):
    def __init__(self, _in_channels, _out_channels, _kernel_size, _padding=0):
        super(ConvReLu, self).__init__(
            nn.Conv2d(_in_channels, _out_channels, kernel_size=_kernel_size, padding=_padding),
            nn.ReLU()  # nn.ReLU takes an inplace flag, not a channel count
        )
