I am testing post-training static quantization with ResNet-18. Step 2 requires fusing modules such as conv+bn+relu or conv+bn. The code is written as follows:
def fuse_model(self, fuse_relu=False):
    """Fuse Conv2d->BatchNorm2d (and optionally ->ReLU) submodules in place.

    Candidate groups are found by walking ``self.named_modules()``, i.e. in
    the order submodules were *registered*. That order need not match the
    order in which ``forward`` applies them. ``torch.quantization.fuse_modules``
    replaces every module after the first in a fused group with
    ``nn.Identity``. Fusing a ReLU is therefore only correct if that ReLU
    module is called exactly once, directly after the conv/BN it is fused with.

    That assumption likely fails for a ResNet BasicBlock that registers a
    single ``self.relu`` after ``bn2``. Such a block conventionally applies it
    after ``bn1`` and again after the residual add (NOTE(review): confirm
    against the block's ``forward``). Fusing ``conv2 -> bn2 -> relu`` then
    moves a ReLU in front of the shortcut add and turns every call of
    ``self.relu`` into an identity. This matches the accuracy collapse seen
    when Conv+BN+ReLU fusion is enabled, so ReLU fusion is opt-in here.

    Args:
        fuse_relu: If True, also fuse Conv2d->BatchNorm2d->ReLU triples and
            Conv2d->ReLU pairs. Only enable this when every such ReLU module
            is used once, immediately after the preceding conv/BN.
    """
    names = []
    types = []
    for name, mod in self.named_modules():
        names.append(name)
        # Exact type comparison (as before): fuse_modules looks up its
        # fuser function by the exact module types of the group.
        types.append(type(mod))
    # Pad the look-ahead so a Conv2d near the end cannot raise IndexError.
    types += [None, None]

    for ind, name in enumerate(names):
        if types[ind] is not nn.Conv2d:
            continue
        nxt, nxt2 = types[ind + 1], types[ind + 2]
        if fuse_relu and nxt is nn.BatchNorm2d and nxt2 is nn.ReLU:
            group = names[ind:ind + 3]
            print("Find ConvBNReLu: ", group[0], '-->', group[1], '-->', group[2])
        elif nxt is nn.BatchNorm2d:
            group = names[ind:ind + 2]
            print("Find ConvBN: ", group[0], '-->', group[1])
        elif fuse_relu and nxt is nn.ReLU:
            group = names[ind:ind + 2]
            print("Find ConvReLU: ", group[0], '-->', group[1])
        else:
            continue
        torch.quantization.fuse_modules(self, group, inplace=True)
The layers it finds to fuse are printed as follows:
Find ConvBN: conv1 --> bn1
Find ConvBN: layer1.0.conv1 --> layer1.0.bn1
Find ConvBNReLu: layer1.0.conv2 --> layer1.0.bn2 --> layer1.0.relu
Find ConvBN: layer1.1.conv1 --> layer1.1.bn1
Find ConvBNReLu: layer1.1.conv2 --> layer1.1.bn2 --> layer1.1.relu
Find ConvBN: layer2.0.conv1 --> layer2.0.bn1
Find ConvBNReLu: layer2.0.conv2 --> layer2.0.bn2 --> layer2.0.relu
Find ConvBN: layer2.0.shortcut.0 --> layer2.0.shortcut.1
Find ConvBN: layer2.1.conv1 --> layer2.1.bn1
Find ConvBNReLu: layer2.1.conv2 --> layer2.1.bn2 --> layer2.1.relu
Find ConvBN: layer3.0.conv1 --> layer3.0.bn1
Find ConvBNReLu: layer3.0.conv2 --> layer3.0.bn2 --> layer3.0.relu
Find ConvBN: layer3.0.shortcut.0 --> layer3.0.shortcut.1
Find ConvBN: layer3.1.conv1 --> layer3.1.bn1
Find ConvBNReLu: layer3.1.conv2 --> layer3.1.bn2 --> layer3.1.relu
Find ConvBN: layer4.0.conv1 --> layer4.0.bn1
Find ConvBNReLu: layer4.0.conv2 --> layer4.0.bn2 --> layer4.0.relu
Find ConvBN: layer4.0.shortcut.0 --> layer4.0.shortcut.1
Find ConvBN: layer4.1.conv1 --> layer4.1.bn1
Find ConvBNReLu: layer4.1.conv2 --> layer4.1.bn2 --> layer4.1.relu
Everything seems fine, but when I evaluate the pretrained model, the accuracy on CIFAR-100 is 1%, whereas the original accuracy is about 73%. While debugging, I found that when I remove the ConvBNReLU fusion (code below), the result is good.
def fuse_model(self, fuse_relu=False):
    """Fuse adjacent Conv2d->BatchNorm2d pairs (optionally Conv2d->ReLU) in place.

    Pairs are found by walking ``self.named_modules()``, i.e. in submodule
    *registration* order, which need not match the order ``forward`` applies
    them in. ``torch.quantization.fuse_modules`` replaces the second module
    of each fused pair with ``nn.Identity``. A ReLU module that ``forward``
    calls more than once, or not directly after the conv, must therefore not
    be fused. Fusing a shared ReLU of this kind is what broke accuracy in the
    Conv+BN+ReLU variant. Conv2d->ReLU fusion is opt-in for the same reason.

    Args:
        fuse_relu: If True, also fuse Conv2d->ReLU pairs. Only enable this
            when each such ReLU module is used once, right after its conv.
    """
    named = list(self.named_modules())
    # Pairing with the shifted list visits each adjacent pair once and never
    # indexes past the end (a trailing Conv2d no longer raises IndexError).
    for (name, mod), (next_name, next_mod) in zip(named, named[1:]):
        # Exact type comparison (as before): fuse_modules looks up its
        # fuser function by the exact module types of the group.
        if type(mod) is not nn.Conv2d:
            continue
        if type(next_mod) is nn.BatchNorm2d:
            print("Find ConvBN: ", name, '-->', next_name)
        elif fuse_relu and type(next_mod) is nn.ReLU:
            print("Find ConvReLU: ", name, '-->', next_name)
        else:
            continue
        torch.quantization.fuse_modules(self, [name, next_name], inplace=True)
So I am wondering whether ConvBNReLU fusion has a problem. I tested it on both PyTorch 1.5 and PyTorch 1.4, and the problem occurs in both.
Any help would be appreciated, thanks.