Quantization for VGG model

For the code below, adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py, I am trying to implement quantization. Since the model does not currently group Conv, BatchNorm and ReLU into a single block, what should `fuse_model` look like for the classification class?

# Backbone layout consumed by vgg_make: an int is the output-channel count of a
# 3x3 conv block, and "M" inserts a 2x2 / stride-2 max-pool. The five pools
# halve the spatial size five times (e.g. 224x224 -> 7x7), which is what the
# classifier's 512 * 7 * 7 input expects.
cfg = [
    64, "M",
    128, "M",
    256, 256, "M",
    512, 512, "M",
    512, 512, "M",
]

def vgg_make(cfg, batch_norm, in_channels=3):
    """Build the VGG feature extractor described by ``cfg``.

    Args:
        cfg: Sequence of layer specs. An int ``c`` appends a 3x3 convolution
            (padding 1) with ``c`` output channels, an optional BatchNorm2d,
            and an in-place ReLU. The string ``"M"`` appends a 2x2 max-pool
            with stride 2.
        batch_norm: If true, insert a BatchNorm2d between each conv and its
            ReLU.
        in_channels: Channel count of the input tensor (3 for RGB images).

    Returns:
        An ``nn.Sequential`` holding the layers in order. Each conv block
        occupies consecutive indices -- (conv, bn, relu) or (conv, relu) --
        which is the layout eager-mode ``fuse_modules`` needs. Plain
        ``nn.ReLU`` is used rather than ReLU6 because only ReLU has a
        conv/bn fusion pattern.
    """
    layers = []
    channels = in_channels
    for spec in cfg:
        if spec == "M":
            layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=2))
            continue
        layers.append(nn.Conv2d(channels, spec, kernel_size=(3, 3), padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(spec))
        layers.append(nn.ReLU(inplace=True))
        # The next conv consumes this conv's output channels.
        channels = spec
    return nn.Sequential(*layers)
# Convolutional feature extractor built from ``cfg``, with a BatchNorm2d after
# every conv so each block is a (Conv2d, BatchNorm2d, ReLU) run.
backbone = vgg_make(cfg, batch_norm=True)

# Fully connected head. Its first layer expects the backbone output flattened
# from 512 channels x 7 x 7. Each hidden Linear is directly followed by a ReLU,
# a pair that eager-mode fuse_modules can fuse; the last Linear emits 10 logits.
classifer = nn.Sequential(
    *[
        nn.Linear(512 * 7 * 7, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5),
        nn.Linear(4096, 4096),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5),
        nn.Linear(4096, 10),
    ]
)
class classfication(nn.Module):
    """VGG-style classifier wrapped with eager-mode quantization stubs.

    ``forward`` passes the input through ``quant``, the convolutional
    ``backbone``, a flatten of every dimension but the batch, the fully
    connected ``classifer`` head, and finally ``dequant``. Until the model
    is prepared/converted for quantization the stubs return their input
    unchanged.

    The first Linear of ``classifer`` must match the flattened backbone
    output size, since there is no adaptive pooling in between.
    The (misspelled) names ``classfication`` and ``classifer`` are kept so
    existing callers continue to work.
    """

    # Consecutive child-type runs that eager-mode fuse_modules can fuse,
    # longest first so Conv+BN+ReLU wins over Conv+BN. Exact types are
    # used because fuse_modules looks patterns up by exact module type
    # (e.g. nn.ReLU6 is NOT accepted in place of nn.ReLU).
    _FUSION_PATTERNS = (
        (nn.Conv2d, nn.BatchNorm2d, nn.ReLU),
        (nn.Conv2d, nn.BatchNorm2d),
        (nn.Conv2d, nn.ReLU),
        (nn.Linear, nn.ReLU),
    )

    def __init__(self, backbone, classifer):
        super().__init__()
        self.backbone = backbone
        self.classifer = classifer
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.backbone(x)
        x = torch.flatten(x, 1)
        x = self.classifer(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        """Fuse Conv/BN/ReLU and Linear/ReLU runs in place for quantization.

        Scans the direct children of ``backbone`` and ``classifer`` (expected
        to be ``nn.Sequential``) for these consecutive patterns:

        * Conv2d, BatchNorm2d, ReLU
        * Conv2d, BatchNorm2d
        * Conv2d, ReLU
        * Linear, ReLU

        Each matched run is replaced by one fused module at its first index
        and ``nn.Identity`` at the remaining indices, so positions are kept.
        Modules without a fusion pattern (e.g. ReLU6) are left alone.

        Call ``self.eval()`` first for post-training quantization: fusing a
        Conv2d with a BatchNorm2d outside eval mode is rejected by
        ``fuse_modules`` (use ``fuse_modules_qat`` for quantization-aware
        training instead).
        """
        try:
            from torch.ao.quantization import fuse_modules
        except ImportError:  # older PyTorch without the torch.ao namespace
            from torch.quantization import fuse_modules

        for seq in (self.backbone, self.classifer):
            groups = self._fusion_groups(seq)
            # fuse_modules raises NotImplementedError on an empty group list.
            if groups:
                fuse_modules(seq, groups, inplace=True)

    @classmethod
    def _fusion_groups(cls, seq):
        """Return lists of child names in ``seq`` that form fusable runs."""
        children = list(seq.named_children())
        kinds = [type(module) for _, module in children]
        groups = []
        i = 0
        while i < len(kinds):
            for pattern in cls._FUSION_PATTERNS:
                width = len(pattern)
                if tuple(kinds[i:i + width]) == pattern:
                    groups.append([name for name, _ in children[i:i + width]])
                    i += width
                    break
            else:
                i += 1
        return groups

Unable to fuse the convolution, batch norm and ReLU layers.

Each entry in the module list looks something like this:

  (0): ConvBNReLU(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
  )

When fusing with the code below, I get the following error:

 def fuse_model(self):
     """Fuse each ConvBNReLU block in place for eager-mode quantization.

     ``fuse_modules`` has no Conv2d + BatchNorm2d + ReLU6 pattern, so
     passing ['0', '1', '2'] for a block whose activation is ReLU6 raises
     NotImplementedError. All three children are fused only when the
     activation is a plain ``nn.ReLU``. Otherwise just the conv and batch
     norm are fused and the activation stays a separate module, so the
     network's function is unchanged.

     Call ``self.eval()`` first for post-training quantization; fusing
     Conv2d with BatchNorm2d outside eval mode is rejected.
     """
     from torch import nn  # local import: only nn.ReLU is needed here

     for m in self.modules():
         if isinstance(m, ConvBNReLU):
             # Assumes ConvBNReLU is an nn.Sequential whose children
             # '0', '1', '2' are conv, batch norm, activation (as in the
             # module listing) -- confirm for your ConvBNReLU definition.
             names = ['0', '1', '2'] if type(m[2]) is nn.ReLU else ['0', '1']
             fuse_modules(m, names, inplace=True)
NotImplementedError: Cannot fuse modules: (<class 'torch.nn.modules.conv.Conv2d'>, <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, <class 'torch.nn.modules.activation.ReLU6'>)

Please let me know how to work around this.

Did you find a way around this? I'm having the same problem.

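Try replacing ReLU6 with ReLU, because the fusion step does not currently support ReLU6: there is no Conv2d + BatchNorm2d + ReLU6 fusion pattern. Note that this changes the activation (ReLU6 clips outputs at 6), so accuracy may need to be re-checked. Alternatively, fuse only the Conv2d and BatchNorm2d (`['0', '1']`) and leave ReLU6 as a separate module.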