Error: FX mode quantization ('function' object has no attribute 'shape')

Hello,

I have a UNet-like inpainting model (shown below) and I'm trying to quantize it for deployment to an Android mobile environment. I have been working through the PyTorch quantization tutorials, but some parts are beyond my experience. I started with FX graph mode post-training static quantization, but I keep getting the error:

AttributeError: 'function' object has no attribute 'shape'

whenever I execute the following:

# 1. FX graph mode quantization (post-training static)
import copy

import torch
from torch.quantization import quantize_fx

model_to_quantize = copy.deepcopy(model_fp)
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
model_to_quantize.eval()

# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)

# calibrate (not shown here; see the sketch below)

# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

# Error Line
output_quantized = model_quantized(img, mask)
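The calibration step (marked "not shown" above) is just a handful of forward passes through the prepared model with representative inputs, roughly along these lines (calibration_pairs is a placeholder name for a small set of (img, mask) batches from my dataset):

# calibration sketch: run a few forward passes through the prepared model
# so the observers can record activation ranges
with torch.no_grad():
    for calib_img, calib_mask in calibration_pairs:
        model_prepared(calib_img, calib_mask)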

C:\Users\er\Anaconda3\envs\deeplearning\lib\site-packages\torch\quantization\observer.py:1090: UserWarning: must run observer before calling calculate_qparams.                                    Returning default scale and zero point 
  warnings.warn(
Traceback (most recent call last):
  File "C:\Users\er\Anaconda3\envs\deeplearning\lib\site-packages\torch\fx\graph_module.py", line 504, in wrapped_call
    return cls_call(self, *args, **kwargs)
  File "C:\Users\er\Anaconda3\envs\deeplearning\lib\site-packages\torch\nn\modules\module.py", line 1056, in _call_impl
    return forward_call(*input, **kwargs)
  File "<eval_with_key_4>", line 468, in forward
    quantize_per_tensor_41 = torch.quantize_per_tensor(decoder_1_input_conv_bias, decoder_1_input_scale_0, decoder_1_input_zero_point_0, torch.quint8);  decoder_1_input_conv_bias = decoder_1_input_scale_0 = decoder_1_input_zero_point_0 = None
TypeError: quantize_per_tensor() received an invalid combination of arguments - got (method, Tensor, Tensor, torch.dtype), but expected one of:
 * (Tensor input, Tensor scale, Tensor zero_point, torch.dtype dtype)
      didn't match because some of the arguments have invalid types: (!method!, Tensor, Tensor, torch.dtype)
 * (Tensor input, float scale, int zero_point, torch.dtype dtype)
      didn't match because some of the arguments have invalid types: (!method!, !Tensor!, !Tensor!, torch.dtype)
 * (tuple of Tensors tensors, Tensor scales, Tensor zero_points, torch.dtype dtype)
      didn't match because some of the arguments have invalid types: (!method!, Tensor, Tensor, torch.dtype)


Call using an FX-traced Module, line 468 of the traced Module's generated forward function:
    decoder_1_input_zero_point_0 = self.decoder_1_input_zero_point_0
    quantize_per_tensor_41 = torch.quantize_per_tensor(decoder_1_input_conv_bias, decoder_1_input_scale_0, decoder_1_input_zero_point_0, torch.quint8);  decoder_1_input_conv_bias = decoder_1_input_scale_0 = decoder_1_input_zero_point_0 = None

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    reshape = quantize_per_tensor_41.reshape(1, -1, 1, 1);  quantize_per_tensor_41 = None

    clamp_13 = torch.clamp(decoder_1_mask_conv, 0, 1)

Traceback (most recent call last):
  File "C:\Users\er\PycharmProjects\ZB_PartialConv_SOCOFing\mymodel.py", line 257, in <module>
    quantized_output = model_quantized(img, mask)
  File "C:\Users\er\Anaconda3\envs\deeplearning\lib\site-packages\torch\fx\graph_module.py", line 512, in wrapped_call
    raise e.with_traceback(None)
TypeError: quantize_per_tensor() received an invalid combination of arguments - got (method, Tensor, Tensor, torch.dtype), but expected one of:
 * (Tensor input, Tensor scale, Tensor zero_point, torch.dtype dtype)
      didn't match because some of the arguments have invalid types: (!method!, Tensor, Tensor, torch.dtype)
 * (Tensor input, float scale, int zero_point, torch.dtype dtype)
      didn't match because some of the arguments have invalid types: (!method!, !Tensor!, !Tensor!, torch.dtype)
 * (tuple of Tensors tensors, Tensor scales, Tensor zero_points, torch.dtype dtype)
      didn't match because some of the arguments have invalid types: (!method!, Tensor, Tensor, torch.dtype)
For reference, here is the model definition:

import torch
import torch.nn as nn
import torch.nn.functional as F


class PartialConvLayer(nn.Module):

    def __init__(self, in_channels, out_channels, kernel, bn=True, bias=False, sample="none-3", activation="relu"):
        super().__init__()
        self.bn = bn
        self.kernel_size = kernel
        self.in_channel = in_channels

        if sample == "down-7":
            # Kernel Size = 7, Stride = 2, Padding = 3
            self.input_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 3, bias=bias)
            self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 3, bias=False)

        elif sample == "down-5":
            self.input_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 2, bias=bias)
            self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 2, bias=False)

        elif sample == "down-3":
            self.input_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 1, bias=bias)
            self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel, 2, 1, bias=False)

        else:
            self.input_conv = nn.Conv2d(in_channels, out_channels, kernel, 1, 1, bias=bias)
            self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel, 1, 1, bias=False)

        nn.init.constant_(self.mask_conv.weight, 1.0)

        # Initialize weight using Kaiming Initialization
        # a: negative slope of relu set to 0, same as relu
        # "fan_in" preserved variance from forward pass
        nn.init.kaiming_normal_(self.input_conv.weight, a=0, mode="fan_in")

        for param in self.mask_conv.parameters():
            param.requires_grad = False

        if bn:
            # Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
            # Applying BatchNorm2d layer after Conv will remove the channel mean
            self.batch_normalization = nn.BatchNorm2d(out_channels)

        if activation == "relu":
            # Used between all encoding layers
            self.activation = nn.ReLU()
        elif activation == "leaky_relu":
            # Used between all decoding layers (Leaky RELU with alpha = 0.2)
            self.activation = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input_x, mask):

        # output = W^T dot (X .* M) + b
        output = self.input_conv(input_x * mask)

        # requires_grad = False
        with torch.no_grad():
            # mask = (1 dot M) + 0 = M
            output_mask = self.mask_conv(mask)

        if self.input_conv.bias is None:
            output_bias = 0
        else:
            # Only the last layer has a bias term
            output_bias = self.input_conv.bias.reshape(1, -1, 1, 1)

        # Mask Update
        updated_mask = torch.clamp(output_mask, 0, 1)

        # Output image Update
        num_scaling = torch.ones_like(output) * self.kernel_size * self.kernel_size
        denom_scaling = output_mask / self.in_channel
        scaling_factor = num_scaling / denom_scaling

        updated_output = (output - output_bias) * scaling_factor + output_bias
        updated_output = updated_output * updated_mask

        if self.bn:
            updated_output = self.batch_normalization(updated_output)

        if hasattr(self, 'activation'):
            updated_output = self.activation(updated_output)

        return updated_output, updated_mask


class PartialConvUNet(nn.Module):

    # 256 x 256 image input, 256 = 2^8
    def __init__(self, input_size=256, layers=7):
        if 2 ** (layers + 1) != input_size:
            raise AssertionError("input_size must equal 2 ** (layers + 1)")

        super().__init__()
        self.freeze_enc_bn = False
        self.layers = layers

        # ======================= ENCODING LAYERS =======================
        # 3x256x256 --> 64x128x128
        self.encoder_1 = PartialConvLayer(3, 64, 7, bn=False, sample="down-7")

        # 64x128x128 --> 128x64x64
        self.encoder_2 = PartialConvLayer(64, 128, 5, sample="down-5")

        # 128x64x64 --> 256x32x32
        self.encoder_3 = PartialConvLayer(128, 256, 3, sample="down-3")

        # 256x32x32 --> 512x16x16
        self.encoder_4 = PartialConvLayer(256, 512, 3, sample="down-3")

        # 512x16x16 --> 512x8x8 --> 512x4x4 --> 512x2x2
        for i in range(5, layers + 1):
            name = "encoder_{:d}".format(i)
            setattr(self, name, PartialConvLayer(512, 512, 3, sample="down-3"))

        # ======================= DECODING LAYERS =======================
        # dec_7: UP(512x2x2) + 512x4x4(enc_6 output) = 1024x4x4 --> 512x4x4
        # dec_6: UP(512x4x4) + 512x8x8(enc_5 output) = 1024x8x8 --> 512x8x8
        # dec_5: UP(512x8x8) + 512x16x16(enc_4 output) = 1024x16x16 --> 512x16x16
        for i in range(layers, 0, -1):
            name = "decoder_{:d}".format(i)
            setattr(self, name, PartialConvLayer(512 + 512, 512, 3, activation="leaky_relu"))

        # UP(512x16x16) + 256x32x32(enc_3 output) = 768x32x32 --> 256x32x32
        self.decoder_4 = PartialConvLayer(512 + 256, 256, 3, activation="leaky_relu")

        # UP(256x32x32) + 128x64x64(enc_2 output) = 384x64x64 --> 128x64x64
        self.decoder_3 = PartialConvLayer(256 + 128, 128, 3, activation="leaky_relu")

        # UP(128x64x64) + 64x128x128(enc_1 output) = 192x128x128 --> 64x128x128
        self.decoder_2 = PartialConvLayer(128 + 64, 64, 3, activation="leaky_relu")

        # UP(64x128x128) + 3x256x256(original image) = 67x256x256 --> 3x256x256(final output)
        self.decoder_1 = PartialConvLayer(64 + 3, 3, 3, bn=False, activation="", bias=True)

    def forward(self, input_x, mask):
        encoder_dict = {}
        mask_dict = {}

        key_prev = "h_0"
        encoder_dict[key_prev], mask_dict[key_prev] = input_x, mask

        # Encoder Path
        for i in range(1, self.layers + 1):
            encoder_key = "encoder_{:d}".format(i)
            key = "h_{:d}".format(i)
            # Passes input and mask through encoding layer
            encoder_dict[key], mask_dict[key] = getattr(self, encoder_key)(encoder_dict[key_prev], mask_dict[key_prev])
            key_prev = key

        # Gets the final output data and mask from the encoding layers
        # 512 x 2 x 2
        out_key = "h_{:d}".format(self.layers)
        out_data, out_mask = encoder_dict[out_key], mask_dict[out_key]

        # Decoder Path
        for i in range(self.layers, 0, -1):
            encoder_key = "h_{:d}".format(i - 1)
            decoder_key = "decoder_{:d}".format(i)

            # Upsample to 2 times scale, matching dimensions of previous encoding layer output
            out_data = F.interpolate(out_data, scale_factor=2)
            out_mask = F.interpolate(out_mask, scale_factor=2)

            # concatenate upsampled decoder output with encoder output of same H x W dimensions
            # s.t. final decoding layer input will contain the original image
            out_data = torch.cat([out_data, encoder_dict[encoder_key]], dim=1)
            # also concatenate the masks
            out_mask = torch.cat([out_mask, mask_dict[encoder_key]], dim=1)

            # feed through decoder layers
            out_data, out_mask = getattr(self, decoder_key)(out_data, out_mask)

        return out_data

    def train(self, mode=True):
        super().train(mode)
        if self.freeze_enc_bn:
            for name, module in self.named_modules():
                if isinstance(module, nn.BatchNorm2d) and "enc" in name:
                    # Sets batch normalization layers to evaluation mode
                    module.eval()

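For completeness, the inputs are a 3-channel image and a same-shaped binary mask at 256 x 256. A minimal call that reproduces the setup looks roughly like this (random tensors used here as stand-ins for my actual data):

# minimal sketch of how the model is built and called (stand-in data)
model_fp = PartialConvUNet(input_size=256, layers=7)
model_fp.eval()

img = torch.rand(1, 3, 256, 256)   # 3-channel input image
mask = torch.ones(1, 3, 256, 256)  # binary mask, same shape as the image

with torch.no_grad():
    out = model_fp(img, mask)      # the fp32 model runs fine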
I don't have a clear idea of the cause, but my guess is that FX mode doesn't support the self.input_conv.bias.reshape(1, -1, 1, 1) call in PartialConvLayer's forward, so the module isn't properly traceable. Any ideas?
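One thing I was considering trying (I'm not sure it's the right approach, and the exact config keys may differ by PyTorch version) is telling prepare_fx to skip tracing PartialConvLayer via the custom prepare config, something like:

# hedged sketch: mark PartialConvLayer as non-traceable so FX does not trace into it
# (key name taken from the FX graph mode quantization docs; may vary by version)
prepare_custom_config_dict = {
    "non_traceable_module_class": [PartialConvLayer],
}
model_prepared = quantize_fx.prepare_fx(
    model_to_quantize,
    qconfig_dict,
    prepare_custom_config_dict=prepare_custom_config_dict,
)

Would that be the recommended way to handle a module like this, or is there a better workaround?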