Post Training Quantization fails on SPAN model with type_as

Hi all,

I’m trying to run post-training quantization (PTQ) on a super-resolution model called SPAN, but I hit an error as soon as the forward pass receives a quantized input.

I wrapped the model like this:

@ARCH_REGISTRY.register()
class QuantizedSPAN(SPAN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # original SPAN
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        print(f"quantized x --> \n {x} \n  Type: {x.dtype}\n")

        x = super().forward(x)
        x = self.dequant(x)
        return x
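For context, I’m following the eager-mode PTQ recipe (set a qconfig, prepare, run a calibration pass, convert). A minimal, self-contained sketch with a toy stand-in model (names like `TinyModel` are mine, not from SPAN; `fbgemm` assumes an x86 backend):

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq

class TinyModel(nn.Module):
    """Toy stand-in for the QuantizedSPAN wrapper (assumed names)."""
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)       # identity before convert, quint8 after
        x = self.conv(x)
        return self.dequant(x)  # back to float32

torch.backends.quantized.engine = "fbgemm"  # x86 backend (assumption)
m = TinyModel().eval()
m.qconfig = tq.get_default_qconfig("fbgemm")
tq.prepare(m, inplace=True)
m(torch.rand(1, 3, 8, 8))       # calibration pass to fill observers
tq.convert(m, inplace=True)     # swap in quantized modules
out = m(torch.rand(1, 3, 8, 8))
```

After `convert`, the `QuantStub` actually quantizes its input, which is exactly when the `type_as` crash below appears.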

The parent SPAN.forward() starts with:

def forward(self, x):
    self.mean = self.mean.type_as(x)  # <-- this line crashes
    x = (x - self.mean) * self.img_range
    ...
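One workaround I’m considering (a sketch, not an official fix; `shift_mean` and the default `img_range` are my own names) is to keep the mean buffer in float32 and dequantize the input before the mean shift, instead of calling `.type_as()` on a quantized tensor:

```python
import torch

# hypothetical mean buffer shaped like SPAN's (names assumed)
mean = torch.zeros(1, 3, 1, 1)

def shift_mean(x, mean, img_range=1.0):
    """Mean-shift that tolerates quantized inputs (sketch)."""
    if x.is_quantized:
        # .type_as() on a quint8 tensor hits "empty_strided not
        # supported on quantized tensors"; fall back to float32
        x = x.dequantize()
    # only match device/dtype of the float tensor, never quantize mean
    mean = mean.to(device=x.device, dtype=x.dtype)
    return (x - mean) * img_range

x = torch.quantize_per_tensor(torch.rand(1, 3, 4, 4),
                              scale=0.0032, zero_point=0,
                              dtype=torch.quint8)
y = shift_mean(x, mean)
```

This keeps the arithmetic in float32, which sidesteps the crash, though it obviously gives up quantized compute for that step.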

Before quantization, the input looks like this (torch.float32):

tensor([[[[0.1137, 0.1137, 0.1137,  ..., 0.1098, 0.1059, 0.1059],
          ...
]]], dtype=torch.float32)

After quantization, it becomes (torch.quint8):

tensor([[[[0.1124, 0.1124, 0.1124,  ..., 0.1092, 0.1059, 0.1059],
          ...
]]], device='cuda:0', size=(1, 3, 40, 800), dtype=torch.quint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.003210579277947545,
       zero_point=0)

And then I get this error:

File "SPAN_arch.py", line 263, in forward
    x = super().forward(x)
File "SPAN_arch.py", line 232, in forward
    self.mean = self.mean.type_as(x)
RuntimeError: empty_strided not supported on quantized tensors yet
see https://github.com/pytorch/pytorch/issues/74540


My questions:

  1. Is there a recommended way to handle situations where .type_as(x) is called but the tensor is quantized (torch.quint8)?

  2. Should I just keep self.mean in float32 and avoid casting it to the quantized dtype?

  3. Is this a known limitation of PTQ in PyTorch, especially on CUDA?

Any advice would be appreciated!

This is an old, deprecated quantization flow. I’d recommend using our two new flows (quantize_ and pt2e quantization) instead: Quick Start Guide — torchao main documentation