Hi all,
I’m trying to run post-training quantization (PTQ) on a super-resolution model called SPAN, but I’m running into an error when I forward a quantized input.
I wrapped the model like this:
```python
@ARCH_REGISTRY.register()
class QuantizedSPAN(SPAN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # original SPAN
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        print(f"quantized x --> \n {x} \n Type: {x.dtype}\n")
        x = super().forward(x)
        x = self.dequant(x)
        return x
```
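For context, the eager-mode PTQ flow I’m following looks roughly like this — a minimal sketch with a tiny stand-in module instead of the real SPAN, and `default_qconfig` chosen just for illustration:

```python
import torch

# Stand-in model: QuantizedSPAN goes through the same prepare/convert steps.
class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))

m = Tiny().eval()
m.qconfig = torch.ao.quantization.default_qconfig   # per-tensor affine observers
torch.ao.quantization.prepare(m, inplace=True)      # insert observers
m(torch.rand(1, 3, 8, 8))                           # calibration pass
torch.ao.quantization.convert(m, inplace=True)      # swap in quantized modules

out = m(torch.rand(1, 3, 8, 8))                     # dequant stub returns float
```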
The parent `SPAN.forward()` starts with:

```python
def forward(self, x):
    self.mean = self.mean.type_as(x)  # <-- this line crashes
    x = (x - self.mean) * self.img_range
    ...
```
Before quantization, the input looks like this (`torch.float32`):

```
tensor([[[[0.1137, 0.1137, 0.1137, ..., 0.1098, 0.1059, 0.1059],
          ...
        ]]], dtype=torch.float32)
```
After quantization, it becomes (`torch.quint8`):

```
tensor([[[[0.1124, 0.1124, 0.1124, ..., 0.1092, 0.1059, 0.1059],
          ...
        ]]], device='cuda:0', size=(1, 3, 40, 800), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.003210579277947545,
       zero_point=0)
```
And then I get this error:

```
File "SPAN_arch.py", line 263, in forward
    x = super().forward(x)
File "SPAN_arch.py", line 232, in forward
    self.mean = self.mean.type_as(x)
RuntimeError: empty_strided not supported on quantized tensors yet
see https://github.com/pytorch/pytorch/issues/74540
```
My questions:

- Is there a recommended way to handle situations where `.type_as(x)` is called but the tensor is quantized (`torch.quint8`)?
- Should I just keep `self.mean` in float32 and avoid casting it to the quantized dtype?
- Is this a known limitation of PTQ in PyTorch, especially on CUDA?
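One workaround I’m considering (not sure it’s the right approach) is to keep `self.mean` in float32 and dequantize before the normalization, so `.type_as` never sees a quantized tensor. A rough sketch of the idea — `normalize` is my own hypothetical helper, not part of SPAN:

```python
import torch

def normalize(x, mean, img_range):
    # If x is quantized (torch.quint8), type_as would fail, so go back
    # to float32 first; mean stays float32 throughout.
    if x.is_quantized:
        x = x.dequantize()
    return (x - mean.type_as(x)) * img_range

x = torch.rand(1, 3, 4, 4)
xq = torch.quantize_per_tensor(x, scale=0.003, zero_point=0, dtype=torch.quint8)
mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)
out = normalize(xq, mean, 255.0)  # float32 result, no empty_strided error
```

The downside is that the mean subtraction then runs in float, which presumably defeats part of the point of quantizing — hence my question about the recommended pattern.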
Any advice would be appreciated!