PTSQ for model with layers and tensor ops

Could you point me to how to do post-training static quantization (PTSQ) in Eager mode on a custom model (snippet below) that mixes supported modules (Conv2d, BatchNorm, ReLU, etc.) with regular tensor ops such as matmul, cat, slice, permute, and view?

import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.pad in forward()

class CustomLayer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        num_joints,
        stride,
        num_partitions,
        dropout,
        residual):
        
        super().__init__()

        assert kernel_size % 2 == 1

        self.num_partitions = num_partitions
        self.num_joints = num_joints
        self.stride = stride
        self.kernel_size = kernel_size

        self.out_channels = out_channels

        self.conv = nn.Conv2d(in_channels, out_channels*num_partitions, kernel_size=1, bias=False)
        
        self.bn_relu = nn.Sequential(
            nn.BatchNorm2d(out_channels, track_running_stats=False),
            nn.ReLU())

        if not residual:
            self.residual = lambda _: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels, track_running_stats=False))

        if not residual:
            self.do = nn.Dropout(dropout)
        else:
            self.do = nn.Sequential(
                nn.ReLU(),
                nn.Dropout(dropout))


    def forward(self, x, A):
        capture_length = x.size(2)
        # Build helper tensors on the same device as the input
        device = x.device
        
        lt_matrix = torch.zeros(capture_length, capture_length, device=device)
        for i in range(self.kernel_size//self.stride):
            lt_matrix += F.pad(
                torch.eye(
                    capture_length - self.stride * i,
                    device=device),
                (i*self.stride,0,0,i*self.stride))

        res = self.residual(x)
        x = self.conv(x)
        x = torch.split(x, self.out_channels, dim=1)
        x = torch.stack(x, -1)
        x = x.permute(0,2,4,1,3)
        x = torch.matmul(x, A)
        x = x.permute(0,2,3,4,1)
        x = torch.matmul(x, lt_matrix)
        x = torch.sum(x, dim=(1))
        x = x.permute(0,1,3,2)
        x = self.bn_relu(x)

        return self.do(x + res)

Which API should I extend (e.g. FloatFunctional)? Should I define swap modules for my containing Module and pass them to QConfig? Should I add an Observer before the tensor ops, and if so, how can I update the “scale” and “zero_point” for that op's tensor from the Observer stats collected during calibration?

Desired outcome:
Quantize the whole model to INT8 (weights and activations): statically quantize both the MATMUL ops and the following SUM as if they were wrapped as part of a single layer.

Couldn’t find an answer up till now. Would greatly appreciate any help, especially if you could briefly mention the methodology that I can follow to achieve this and to subsequently make a contribution. Thank you!

Did you see this tutorial? "(beta) Static Quantization with Eager Mode in PyTorch" (PyTorch Tutorials 2.0.1+cu117 documentation)

For Conv2d/BN/ReLU, you need to make sure their inputs are quantized by placing QuantStub/DeQuantStub around them, and set the qconfig for them properly. If you need to quantize (conv - relu) as a fused module, you'll need to call fuse_modules first; the same tutorial has an example of that as well.
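As a rough illustration of that workflow, here is a minimal sketch on a toy model. SmallNet, its layer sizes, and the random calibration data are made up for the example and are not taken from the model in the question; it also assumes a CPU build with the fbgemm or qnnpack engine available. Unlike the question's model, the BatchNorm here keeps running statistics, which fusion in eval mode relies on.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    QuantStub, DeQuantStub, fuse_modules, get_default_qconfig, prepare, convert,
)

class SmallNet(nn.Module):  # hypothetical toy model, for illustration only
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> quantized at the model input
        self.conv = nn.Conv2d(3, 8, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()  # quantized -> float at the model output

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.bn(self.conv(x)))
        return self.dequant(x)

torch.manual_seed(0)
model = SmallNet().eval()

# Pick a quantized engine that this build supports (assumption: one of these two).
engines = torch.backends.quantized.supported_engines
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = backend
model.qconfig = get_default_qconfig(backend)

# 1) Fuse conv + bn + relu, 2) insert observers, 3) calibrate, 4) convert.
fused = fuse_modules(model, [["conv", "bn", "relu"]])
prepared = prepare(fused)
with torch.no_grad():
    for _ in range(4):  # random tensors stand in for real calibration data
        prepared(torch.randn(2, 3, 16, 16))

# The observer attached to the input stub turns its collected stats into qparams.
scale, zero_point = prepared.quant.activation_post_process.calculate_qparams()

quantized = convert(prepared)
out = quantized(torch.randn(2, 3, 16, 16))
```

During convert, each observed module reads scale and zero_point from its observer in the same way, so the calibration loop is the only place where statistics need to be collected.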

Example of using FloatFunctional: vision/mobilenetv2.py (pytorch/vision on GitHub, main branch).
You'll need to make the same modification for add, mul, cat, etc. Here is the list of supported ops: pytorch/functional_modules.py (pytorch/pytorch on GitHub, main branch).
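To make the FloatFunctional modification concrete, here is a small sketch of a residual add. ResidualBlock and its sizes are hypothetical, and the calibration input is random; the point is only that the `+` is replaced by a FloatFunctional call so that prepare can attach an observer to the result of the add.

```python
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import (
    QuantStub, DeQuantStub, get_default_qconfig, prepare, convert,
)

class ResidualBlock(nn.Module):  # hypothetical block, for illustration only
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(4, 4, kernel_size=1, bias=False)
        # Stateful stand-in for "+": it owns the observer for the add's output.
        self.skip_add = nnq.FloatFunctional()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        out = self.skip_add.add(self.conv(x), x)  # instead of self.conv(x) + x
        return self.dequant(out)

torch.manual_seed(0)
engines = torch.backends.quantized.supported_engines
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = backend

block = ResidualBlock().eval()
block.qconfig = get_default_qconfig(backend)
prepared = prepare(block)
with torch.no_grad():
    prepared(torch.randn(2, 4, 8, 8))  # random calibration input
quantized = convert(prepared)
out = quantized(torch.randn(2, 4, 8, 8))
```

After convert, skip_add is swapped for its quantized counterpart, which carries the scale and zero_point derived from calibration. Use one FloatFunctional instance per call site, since each one tracks the statistics of a single op output.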

I think matmul is probably not supported right now. For permute/slice/view you don't need to do anything; they work with both floating-point and quantized Tensors.
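A quick way to check the second point is to apply those ops to a quantized tensor directly. This is only a sketch: the scale and zero_point below are arbitrary values chosen for the example.

```python
import torch

x = torch.randn(2, 3, 4)
# Arbitrary quantization parameters, chosen only for this example.
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)

qp = qx.permute(0, 2, 1)  # permute on a quantized tensor
qs = qx[:, :2]            # slice
qv = qx.view(2, 12)       # view (qx is contiguous)

# Each result is still quantized and keeps the original scale and zero_point.
for t in (qp, qs, qv):
    print(tuple(t.shape), t.is_quantized, t.q_scale(), t.q_zero_point())
```

Because these ops only rearrange or select elements, no observer or requantization step is involved for them.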