For the code below, adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py, I am trying to implement quantization. Since the model currently doesn't have Conv, BatchNorm and ReLU grouped as a single block, what should `fuse_model` look like for the `classfication` class?
# Layer layout consumed by vgg_make: an integer is the output channel count
# of a 3x3 convolution, and "M" inserts a 2x2 max-pool.
cfg = [
    64, "M",
    128, "M",
    256, 256, "M",
    512, 512, "M",
    512, 512, "M",
]
def vgg_make(cfg, batch_norm, in_channels=3):
    """Build a VGG-style convolutional feature extractor from ``cfg``.

    Args:
        cfg: Iterable of layer specs. The string ``"M"`` appends a 2x2
            max-pool with stride 2; any other entry is used as the output
            channel count of a 3x3 convolution with padding 1.
        batch_norm: When true, an ``nn.BatchNorm2d`` is inserted between
            each convolution and its ReLU.
        in_channels: Channel count of the input fed to the first
            convolution. Defaults to 3, the value that was previously
            hard-coded.

    Returns:
        An ``nn.Sequential`` holding the layers in ``cfg`` order. An empty
        ``cfg`` yields an empty ``nn.Sequential``.
    """
    layers = []
    for val in cfg:
        if val == "M":
            layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=2))
            continue
        layers.append(nn.Conv2d(in_channels, val, kernel_size=(3, 3), padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(val))
        layers.append(nn.ReLU(inplace=True))
        # Only convolutions change the channel count; pooling keeps it.
        in_channels = val
    return nn.Sequential(*layers)
# Convolutional feature extractor built from the module-level layout, with
# batch normalisation after every convolution.
backbone = vgg_make(cfg, batch_norm=True)

# Fully connected head ending in 10 outputs. The first Linear expects
# 512*7*7 flattened features -- presumably a 512x7x7 backbone output
# (e.g. a 224x224 input halved by the five "M" pools); confirm against the
# actual input size.
classifer = nn.Sequential(
    nn.Linear(in_features=512 * 7 * 7, out_features=4096),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=4096, out_features=4096),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=4096, out_features=10),
)
class classfication(nn.Module):
    """Backbone + classifier wrapper prepared for eager-mode quantization.

    The input passes through a ``QuantStub``, the backbone, a flatten, the
    classifier and finally a ``DeQuantStub``.

    Args:
        backbone: Feature-extractor module applied first.
        classifer: Head module applied to the flattened backbone output.
    """

    # Layer runs that fuse_model merges. The longest pattern comes first so a
    # Conv+BN+ReLU run is fused as a whole rather than matched partially.
    # Matching is by exact type because the fusion lookup is keyed on the
    # exact module classes.
    _FUSE_PATTERNS = (
        (nn.Conv2d, nn.BatchNorm2d, nn.ReLU),
        (nn.Conv2d, nn.ReLU),
        (nn.Linear, nn.ReLU),
    )

    def __init__(self, backbone, classifer):
        super().__init__()
        self.backbone = backbone
        self.classifer = classifer
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Run quantize -> backbone -> flatten -> classifier -> dequantize."""
        x = self.quant(x)
        x = self.backbone(x)
        # Keep dim 0 and collapse the rest so the head receives a 2-D tensor.
        x = torch.flatten(x, 1)
        x = self.classifer(x)
        return self.dequant(x)

    def fuse_model(self):
        """Fuse Conv/BN/ReLU and Linear/ReLU runs in place.

        Every plain ``nn.Sequential`` reachable from this module is scanned
        for adjacent children matching ``_FUSE_PATTERNS``; each match is
        merged with ``torch.quantization.fuse_modules``. The first slot of a
        run receives the fused module and the remaining slots become
        ``nn.Identity``, so indices inside the container are unchanged.

        Call this after ``model.eval()`` and before preparing the model for
        post-training quantization: folding BatchNorm into the convolution
        is an inference-time transformation. Quantization-aware training
        needs the QAT fusion API instead.
        """
        # Snapshot first: fusion swaps children while we would be iterating.
        # Exact-type check on purpose: fused modules are themselves
        # Sequential subclasses and must not be scanned (and re-fused) again.
        containers = [m for m in self.modules() if type(m) is nn.Sequential]
        for container in containers:
            groups = self._fusable_groups(container)
            if groups:
                torch.quantization.fuse_modules(container, groups, inplace=True)

    @classmethod
    def _fusable_groups(cls, container):
        """Return child-name lists of adjacent runs matching a fuse pattern."""
        children = list(container.named_children())
        groups = []
        idx = 0
        while idx < len(children):
            for pattern in cls._FUSE_PATTERNS:
                window = children[idx:idx + len(pattern)]
                if len(window) == len(pattern) and all(
                    type(mod) is kind for (_, mod), kind in zip(window, pattern)
                ):
                    groups.append([name for name, _ in window])
                    idx += len(pattern)
                    break
            else:
                # No pattern starts at this child; move on to the next one.
                idx += 1
        return groups