Fuse ConvBnReLU model error

I am using ResNet-18 to test post-training static quantization. Step 2 is fusing modules such as conv+bn+relu, conv+bn, and so on. The code I wrote is as follows:

```python
def fuse_model(self):
    modules_names = [m for m in self.named_modules()]
    modules_list = [m for m in self.modules()]
    for ind, m in enumerate(modules_list):
        if type(m) == nn.Conv2d and type(modules_list[ind+1]) == nn.BatchNorm2d and type(modules_list[ind+2]) == nn.ReLU:
            print("Find ConvBNReLu: ", modules_names[ind][0], '-->', modules_names[ind+1][0], '-->', modules_names[ind+2][0])
            torch.quantization.fuse_modules(self, [modules_names[ind][0], modules_names[ind+1][0], modules_names[ind+2][0]], inplace=True)
        elif type(m) == nn.Conv2d and type(modules_list[ind+1]) == nn.BatchNorm2d:
            print("Find ConvBN: ", modules_names[ind][0], '-->', modules_names[ind+1][0])
            torch.quantization.fuse_modules(self, [modules_names[ind][0], modules_names[ind+1][0]], inplace=True)
        elif type(m) == nn.Conv2d and type(modules_list[ind+1]) == nn.ReLU:
            print("Find ConvReLU: ", modules_names[ind][0], '-->', modules_names[ind+1][0])
            torch.quantization.fuse_modules(self, [modules_names[ind][0], modules_names[ind+1][0]], inplace=True)
```

The layers it finds and fuses are printed as follows:

```
Find ConvBN: conv1 --> bn1
Find ConvBN: layer1.0.conv1 --> layer1.0.bn1
Find ConvBNReLu: layer1.0.conv2 --> layer1.0.bn2 --> layer1.0.relu
Find ConvBN: layer1.1.conv1 --> layer1.1.bn1
Find ConvBNReLu: layer1.1.conv2 --> layer1.1.bn2 --> layer1.1.relu
Find ConvBN: layer2.0.conv1 --> layer2.0.bn1
Find ConvBNReLu: layer2.0.conv2 --> layer2.0.bn2 --> layer2.0.relu
Find ConvBN: layer2.0.shortcut.0 --> layer2.0.shortcut.1
Find ConvBN: layer2.1.conv1 --> layer2.1.bn1
Find ConvBNReLu: layer2.1.conv2 --> layer2.1.bn2 --> layer2.1.relu
Find ConvBN: layer3.0.conv1 --> layer3.0.bn1
Find ConvBNReLu: layer3.0.conv2 --> layer3.0.bn2 --> layer3.0.relu
Find ConvBN: layer3.0.shortcut.0 --> layer3.0.shortcut.1
Find ConvBN: layer3.1.conv1 --> layer3.1.bn1
Find ConvBNReLu: layer3.1.conv2 --> layer3.1.bn2 --> layer3.1.relu
Find ConvBN: layer4.0.conv1 --> layer4.0.bn1
Find ConvBNReLu: layer4.0.conv2 --> layer4.0.bn2 --> layer4.0.relu
Find ConvBN: layer4.0.shortcut.0 --> layer4.0.shortcut.1
Find ConvBN: layer4.1.conv1 --> layer4.1.bn1
Find ConvBNReLu: layer4.1.conv2 --> layer4.1.bn2 --> layer4.1.relu
```

Everything seems fine, but when I evaluate the pretrained model, the accuracy on CIFAR-100 is 1%, whereas the original accuracy is about 73%. While debugging, I found that when I remove the ConvBnReLU fusion branch, the result is good:

```python
def fuse_model(self):
    modules_names = [m for m in self.named_modules()]
    modules_list = [m for m in self.modules()]
    for ind, m in enumerate(modules_list):
        #if type(m) == nn.Conv2d and type(modules_list[ind+1]) == nn.BatchNorm2d and type(modules_list[ind+2]) == nn.ReLU:
        #    print("Find ConvBNReLu: ", modules_names[ind][0], '-->', modules_names[ind+1][0], '-->', modules_names[ind+2][0])
        #    torch.quantization.fuse_modules(self, [modules_names[ind][0], modules_names[ind+1][0], modules_names[ind+2][0]], inplace=True)
        if type(m) == nn.Conv2d and type(modules_list[ind+1]) == nn.BatchNorm2d:
            print("Find ConvBN: ", modules_names[ind][0], '-->', modules_names[ind+1][0])
            torch.quantization.fuse_modules(self, [modules_names[ind][0], modules_names[ind+1][0]], inplace=True)
        elif type(m) == nn.Conv2d and type(modules_list[ind+1]) == nn.ReLU:
            print("Find ConvReLU: ", modules_names[ind][0], '-->', modules_names[ind+1][0])
            torch.quantization.fuse_modules(self, [modules_names[ind][0], modules_names[ind+1][0]], inplace=True)
```
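One plausible cause, offered as an assumption because the block's forward() is not shown: self.modules() and self.named_modules() return submodules in the order they were registered in __init__, not the order forward() runs them. In a CIFAR-style BasicBlock the ReLU registered after bn2 is typically applied only after the residual addition, so fusing conv2 → bn2 → relu moves that ReLU in front of the addition. The hypothetical block below (layout and names are illustrative, not the poster's code) reproduces the mismatch:

```python
import torch
import torch.nn as nn
from torch.quantization import fuse_modules

# Hypothetical CIFAR-style BasicBlock (an assumption, not the poster's code):
# relu is registered right after bn2, but forward() applies it only
# AFTER the residual addition.
class BasicBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + x          # residual addition
        return self.relu(out)  # ReLU runs after the addition

torch.manual_seed(0)
block = BasicBlock(4).eval()
x = torch.randn(2, 4, 8, 8)
ref = block(x)

# Registration order, which is what modules()/named_modules() return:
print([name for name, _ in block.named_modules()][1:])
# ['conv1', 'bn1', 'conv2', 'bn2', 'relu']

# Fusing by registration adjacency folds the ReLU into conv2/bn2 and replaces
# block.relu with nn.Identity, so the ReLU now runs before the addition.
fused = fuse_modules(block, [["conv2", "bn2", "relu"]])  # inplace=False: returns a copy
print(torch.allclose(ref, fused(x), atol=1e-5))  # False
```

If a block instead reuses a single nn.ReLU instance both after bn1 and after the addition, fusion replaces that shared instance with nn.Identity and removes both activations; either way the fused float model no longer computes the same function.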

So I am wondering whether ConvBnReLU fusion has a problem. I tested it on both PyTorch 1.4 and PyTorch 1.5 and see the problem in both.
Please help, thanks.
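A common alternative is to list the fusion groups explicitly per block instead of inferring them from module order, similar to the explicit fuse_modules lists in the PyTorch static quantization tutorial. The helper below is a sketch that assumes the attribute names from the printed output (conv1/bn1, layerN.i.conv1/bn1/conv2/bn2, shortcut.0/shortcut.1) and a ReLU applied after the residual addition; the name fuse_resnet_blocks is hypothetical.

```python
import torch.nn as nn
from torch.quantization import fuse_modules

def fuse_resnet_blocks(model: nn.Module) -> None:
    """Hypothetical helper: fuse conv+bn pairs of a CIFAR-style ResNet in place.

    Assumes the attribute names seen in the printed output. The ReLU after
    bn2 is deliberately NOT fused, because forward() applies it after the
    residual addition.
    """
    # Stem: conv1 -> bn1 (a functional F.relu here cannot be fused).
    fuse_modules(model, [["conv1", "bn1"]], inplace=True)
    for layer_name in ["layer1", "layer2", "layer3", "layer4"]:
        for block in getattr(model, layer_name):
            fuse_modules(block, [["conv1", "bn1"], ["conv2", "bn2"]], inplace=True)
            # Downsampling shortcut is Sequential(conv, bn); identity shortcut is empty.
            shortcut = getattr(block, "shortcut", None)
            if isinstance(shortcut, nn.Sequential) and len(shortcut) == 2:
                fuse_modules(shortcut, [["0", "1"]], inplace=True)
```

ReLUs that forward() calls functionally (for example F.relu after bn1) cannot be fused this way; they would first have to be registered as nn.ReLU modules used at exactly one point in forward().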

Can you share more details? Are you calling fusion after the model is set to eval? Are you quantizing the model?
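Fusion for post-training static quantization is meant to run on a model in eval mode, and it should leave the float model's output unchanged up to rounding. A quick check before quantizing is to compare outputs before and after fusion. The helper name check_fusion and its fuse_fn argument are hypothetical, a sketch rather than a PyTorch API:

```python
import copy
import torch
import torch.nn as nn

def check_fusion(model: nn.Module, fuse_fn, example: torch.Tensor, atol: float = 1e-4) -> bool:
    """Hypothetical helper: compare float outputs before and after fusion.

    Works on a deep copy so the caller's model is untouched. In eval mode
    the BatchNorm running statistics are folded into the conv, so a correct
    fusion should change the output only by floating-point rounding.
    """
    model = copy.deepcopy(model).eval()
    with torch.no_grad():
        before = model(example)
        fuse_fn(model)  # e.g. lambda m: m.fuse_model()
        after = model(example)
    return torch.allclose(before, after, atol=atol)
```

For instance, check_fusion(model, lambda m: m.fuse_model(), torch.randn(1, 3, 32, 32)) returning False would point at the fusion step rather than at quantization itself.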

Hello Raghuraman,

I came across this question while searching for details about module fusion.

In your quantization tutorial, it was explicitly mentioned that module fusion will help make the model faster by saving on memory access while also improving numerical accuracy.

I am curious about what exactly PyTorch does when fusing modules. I can see how it saves memory accesses, but why does fusion improve numerical accuracy?
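For conv followed by BatchNorm in eval mode, fusion folds the BN statistics into the convolution. Per output channel, with scale s = γ / √(σ² + ε), the fused weight is W' = W · s and the fused bias is b' = (b − μ) · s + β. The sketch below does this by hand and checks it against the unfused pair. It mirrors what torch.nn.utils.fusion.fuse_conv_bn_eval computes, though fold_bn_into_conv itself is a hypothetical helper. As for accuracy, one reading (an interpretation, not a quote from the tutorial) is that a fused quantized module rounds to 8 bits once, after conv+BN+ReLU, instead of also rounding the intermediate conv output.

```python
import torch
import torch.nn as nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Hypothetical helper: return a Conv2d equal to bn(conv(x)) in eval mode."""
    # Per-output-channel scale s = gamma / sqrt(running_var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    with torch.no_grad():
        # W' = W * s, broadcast over (in_channels / groups, kH, kW)
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        # b' = (b - mean) * s + beta
        fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3, padding=1)
bn = nn.BatchNorm2d(8)
# Give BN non-trivial statistics so the fold actually changes something.
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
conv.eval()
bn.eval()

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    print(torch.allclose(bn(conv(x)), fold_bn_into_conv(conv, bn)(x), atol=1e-4))  # True
```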

Thanks a lot,