Calibrating a model in post-training static quantization using the FX API

Hi there,

I am working on quantizing a semantic segmentation model using the FX API provided by PyTorch.

The model has a Swin Transformer backbone, an ASPP module, and some up-convolutions, following a DeepLabv3+ architecture.

I have followed the steps in the tutorial by @jerryzh168.

I am able to run the following lines of code

# Use the default "fbgemm" qconfig for the whole model (the "" key is global).
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
# PatchEmbed and BasicLayer could not be symbolically traced, so they are
# listed as non-traceable module classes for prepare_fx.
prepare_custom_config_dict = {
    "non_traceable_module_class": [PatchEmbed, BasicLayer]
}
prepared_model = prepare_fx(model, qconfig_dict, prepare_custom_config_dict)
# Dump the prepared graph and the generated forward code for inspection.
print(prepared_model.graph)
print(prepared_model.code)

I am excluding PatchEmbed and BasicLayer module classes in prepare_fx.
Here’s their code:

class PatchEmbed(nn.Module):
    """Project an image into patch embeddings with a strided convolution.

    Args:
        patch_size: Patch size; expanded to a pair by ``to_2tuple``.
        in_chans: Number of input channels of the projection.
        embed_dim: Number of output channels of the projection.
        norm_layer: Optional callable that builds a normalization layer from
            ``embed_dim``. When ``None`` no normalization is applied.
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        # Kernel size and stride are both the patch size.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """Pad ``x`` to a multiple of the patch size, project it and normalize.

        Args:
            x: 4-D input tensor whose last two dimensions are height and width.

        Returns:
            The projected tensor (B C Wh Ww), normalized over the channel
            dimension when a norm layer was configured.
        """
        patch_h, patch_w = self.patch_size[0], self.patch_size[1]
        _, _, height, width = x.size()

        # Pad on the right, then on the bottom, so that both spatial
        # dimensions are divisible by the patch size.
        overhang_w = width % patch_w
        if overhang_w != 0:
            x = F.pad(x, (0, patch_w - overhang_w))
        overhang_h = height % patch_h
        if overhang_h != 0:
            x = F.pad(x, (0, 0, 0, patch_h - overhang_h))

        x = self.proj(x)  # B C Wh Ww
        if self.norm is None:
            return x

        # The norm layer is applied with channels last, so flatten the
        # spatial dims, normalize, then restore the B C Wh Ww layout.
        out_h, out_w = x.size(2), x.size(3)
        tokens = x.flatten(2).transpose(1, 2)
        tokens = self.norm(tokens)
        return tokens.transpose(1, 2).view(-1, self.embed_dim, out_h, out_w)

class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Builds ``depth`` ``SwinTransformerBlock`` modules followed by an optional
    downsampling module.

    Args:
        dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop,
        norm_layer: Forwarded unchanged to every ``SwinTransformerBlock``.
        depth: Number of blocks to build.
        window_size: Window size; half of it is used as the shift size of
            every odd-indexed block.
        drop_path: Either one value shared by all blocks or a list indexed
            by block position.
        downsample: Optional callable invoked as
            ``downsample(dim=dim, norm_layer=norm_layer)``.
        use_checkpoint: When true, blocks are run through
            ``checkpoint.checkpoint``.
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = window_size // 2
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Build the blocks; even-indexed blocks use no shift, odd-indexed
        # blocks shift by half a window.
        stage_blocks = []
        for index in range(depth):
            block_shift = window_size // 2 if index % 2 else 0
            if isinstance(drop_path, list):
                block_drop_path = drop_path[index]
            else:
                block_drop_path = drop_path
            stage_blocks.append(
                SwinTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=block_shift,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=block_drop_path,
                    norm_layer=norm_layer,
                )
            )
        self.blocks = nn.ModuleList(stage_blocks)

        # Patch merging layer.
        if downsample is None:
            self.downsample = None
        else:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)

    def forward(self, x, H, W):
        """Run all blocks of the stage and the optional downsampling.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.

        Returns:
            A 6-tuple ``(x, H, W, x_down, Wh, Ww)``. Without a downsample
            module the last three entries repeat the first three.
        """
        win = self.window_size
        shift = self.shift_size

        # Attention mask for SW-MSA: label each region of the padded
        # feature map with a distinct id.
        padded_h = int(np.ceil(H / win)) * win
        padded_w = int(np.ceil(W / win)) * win
        img_mask = torch.zeros((1, padded_h, padded_w, 1), device=x.device)  # 1 Hp Wp 1
        bands = (slice(0, -win), slice(-win, -shift), slice(-shift, None))
        for row, h_band in enumerate(bands):
            for col, w_band in enumerate(bands):
                img_mask[:, h_band, w_band, :] = row * len(bands) + col

        mask_windows = window_partition(img_mask, win)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, win * win)
        # Position pairs with different region ids get -100, equal ids get 0.
        id_diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = id_diff.masked_fill(id_diff != 0, -100.0).masked_fill(id_diff == 0, 0.0)

        for blk in self.blocks:
            blk.H = H
            blk.W = W
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, attn_mask)
            else:
                x = blk(x, attn_mask)

        if self.downsample is None:
            return x, H, W, x, H, W

        x_down = self.downsample(x, H, W)
        return x, H, W, x_down, (H + 1) // 2, (W + 1) // 2

PatchEmbed gave me problems during tracing because of the if statements in its forward method.

BasicLayer was failing when NumPy operations were executed on FX Proxy objects in these lines:

Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size

My problem arises when I try to calibrate the model.

def calibrate(model, data_loader):
    """Feed every batch of ``data_loader`` through ``model`` without gradients.

    Intended for the calibration step of post-training static quantization,
    where only the forward passes matter and the outputs are discarded.

    Args:
        model: Callable model; invoked once per batch with the batch input.
        data_loader: Sized iterable of batches whose first element is the
            model input. Any further elements (targets, metadata) are ignored.
    """
    with torch.no_grad():
        # The loop advances once per batch, so the progress-bar total must be
        # the number of batches, not the number of samples in the dataset.
        for inp, *_ in tqdm(data_loader, total=len(data_loader),
                            desc='Calibrating model for post training static quantization...'):
            model(inp)


calibrate(prepared_model, data_loader)

That’s when I get the following error:

Traceback (most recent call last):
  File "/home/victor/proyectos/roof_segmentation/increase_speed_model_production/roof_condition_semseg/utils/quantization.py", line 37, in calibrate
    model(inp)
  File "/home/victor/proyectos/roof_segmentation/increase_speed_model_production/roof_condition_semseg/venv/lib/python3.8/site-packages/torch/fx/graph_module.py", line 513, in wrapped_call
    raise e.with_traceback(None)
AttributeError: 'int' object has no attribute 'numel'

The weird part is that my input is a torch.Tensor, not an int. So I guess the graph generated by prepare_fx may have a problem somewhere.

I haven’t found any information regarding problems when calibrating the model.

Any ideas on how to solve the problem?

Thanks.

Note: I have already tried post-training dynamic quantization using eager mode and it works fine. However, it only allows me to quantize the nn.Linear layers and activation functions of my model.

Hey,

If you could give us a minimal reproducible example, that would be helpful.

My guess is that an observer is getting attached to one of the modules with an int input and, assuming it’s a tensor, calls numel() on it.

Hey,

Thanks for the answer.

I can’t share all the details on the model’s architecture due to my company policy but I can share the workflow I use:

from models import DeeplabV3X
from models.backbone import PatchEmbed, BasicLayer
from datasets import Dataset, build_dataloader

model = DeeplabV3X()

# Use a dataset to compare inference times and evaluation metrics
test_dataset = Dataset()

# Dataloader getitem generates input (n, 3, 512, 512) tensor image, target (n, 1, 512, 512) tensor with masks
test_dataloader = build_dataloader(test_dataset)

# Deepcopying the original model because quantization api changes the model inplace and we want
# to keep the original model for future comparison
q_model = copy.deepcopy(model)

# Function to calibrate graph module
def calibrate(model, data_loader):
    """Feed every batch of ``data_loader`` through ``model`` without gradients.

    Args:
        model: Callable model; invoked once per batch with the batch input.
        data_loader: Sized iterable of batches whose first element is the
            model input. Any further elements (targets, metadata) are ignored,
            so both (input, target) and longer batch tuples are accepted.
    """
    with torch.no_grad():
        # The loop advances once per batch, so the progress-bar total must be
        # the number of batches, not the number of samples in the dataset.
        for inp, *_ in tqdm(data_loader, total=len(data_loader),
                            desc='Calibrating model for post training static quantization...'):
            model(inp)

# Function to convert model
def quantizate_ptq_static_fx(model, data_loader):
    """Prepare, calibrate and convert ``model`` with the FX quantization API.

    Args:
        model: Model handed to ``prepare_fx``.
        data_loader: Batches passed to ``calibrate`` for the prepared model.

    Returns:
        The result of ``convert_fx`` on the calibrated model.
    """
    prepared_model = prepare_fx(
        model,
        {"": get_default_qconfig("fbgemm")},
        {"non_traceable_module_class": [PatchEmbed, BasicLayer]},
    )  # Generate graph
    print(prepared_model.graph)
    print(prepared_model.code)
    calibrate(prepared_model, data_loader)  # Code breaks here
    return convert_fx(prepared_model)

q_model = quantizate_ptq_static_fx(q_model, test_dataloader)

# Run the original and the quantized model on the same batches so their
# outputs can be compared.
for inp, target in test_dataloader:
    logits = model(inp)
    q_logits = q_model(inp)

I have generated the logs of graph and code for prepared_model.

Let me know if this is helpful.

A minimal reproducible example doesn’t generally involve details about model architecture, ideally it’d be a toy model with only the problematic piece.

The issue is in one of the modules and without code I can’t determine more than that. You could probably just use print statements to figure out which one is the issue and go from there.

Again, if I were to guess, it’s probably because you’re passing in/returning a mixture of different types (I know you said your input is not an int, but the input to BasicLayer does contain an int) in some of these modules, while FX is probably assuming they are all tensors. You could fix this by passing/returning everything as tensors (if that’s actually the issue).

@HDCharles we can probably get some hint from the graph and code that is attached.

here:

view = model_model_backbone_norm0_activation_post_process_0.view(-1, getitem_2_activation_post_process_0, getitem_3_activation_post_process_0, 96); 

It looks like getitem_2_activation_post_process_0 and getitem_3_activation_post_process_0 are expected to be ints, yet they are being observed?

Is this intended? We could potentially remove the quantization support for view; it’s probably not needed.

Good catch.

It’s bigger than view, though: pretty much every quant pattern in GeneralTensorShapeOpQuantizeHandler has the same issue if you do anything but hard-code the non-tensor arguments. To be honest, I’m not sure why these need a quant handler, since they can handle both normal tensors and qtensors; they don’t break anything if they are excluded from the flow.

e.g.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class Net(nn.Module):
    """Toy model whose ``view`` receives a size passed in at call time."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.lin = nn.Linear(in_features=5, out_features=1)

    def forward(self, x, y):
        """Apply conv, ReLU and pooling, reshape to ``(-1, y)``, then project.

        Args:
            x: Input tensor for ``conv1``.
            y: Size of the last dimension requested from ``view``; it is
                passed straight through as a ``view`` argument.
        """
        features = self.pool(F.relu(self.conv1(x)))
        flat = torch.flatten(features, 1)  # flatten all dimensions except batch
        rows = flat.view(-1, y)
        return self.lin(rows)

# Build the toy model in eval mode and run it once as a plain float model.
model=Net().eval()
model(torch.randn(5,3,32,32), 5)
# Default "fbgemm" qconfig applied globally through the "" key.
qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
# Alternative qconfig_dict that additionally assigns a None qconfig to 'view':
# qconfig_dict = {"": qconfig, "object_type": [('view', None)]}
prepared_model = prepare_fx(model, qconfig_dict)
print(prepared_model.code)
# Run the prepared model with the same mix of tensor and int arguments.
# NOTE(review): presumably this is the call that reproduces the
# "'int' object has no attribute 'numel'" error from this thread - confirm.
prepared_model(torch.randn(5,3,32,32), 5)
# Convert the prepared model and run the converted model once.
final_model = convert_fx(prepared_model)
print(final_model.code)
final_model(torch.randn(5,3,32,32), 5)

@vlc
You can solve the issue by specifying None as the qconfig for view (see the commented-out qconfig dict in the repro) to exclude it.