I have been trying to follow this code snippet and adapt it to a ResNet20 to reproduce results at different bitwidths, without any luck: the validation accuracy always comes out equal to the bitwidth B I'm using (8% for 8 bits, 4% for 4 bits, and so on). It's not clear to me where, and how many times, the `self.quant` and `self.dequant` calls should go in the ResNet definition, nor how to fuse the model correctly.
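For context, to get bitwidths below 8 I build a custom qconfig roughly along these lines (a sketch, not my exact code: it assumes a PyTorch version whose observers accept `quant_min`/`quant_max`, and `make_qconfig` is just my helper name; for B = 8 I use `torch.quantization.get_default_qconfig('fbgemm')` instead):

```python
import torch
from torch.quantization import QConfig, MinMaxObserver

# Hypothetical helper for a B-bit qconfig (sketch).
def make_qconfig(bitwidth):
    # Unsigned activations: [0, 2^B - 1]
    activation = MinMaxObserver.with_args(
        dtype=torch.quint8,
        quant_min=0,
        quant_max=2 ** bitwidth - 1)
    # Symmetric signed weights: [-2^(B-1), 2^(B-1) - 1]
    weight = MinMaxObserver.with_args(
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric,
        quant_min=-(2 ** (bitwidth - 1)),
        quant_max=2 ** (bitwidth - 1) - 1)
    return QConfig(activation=activation, weight=weight)
```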
These are the changes I made to the BasicBlock (the added lines are marked with `# added` comments):
```python
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.quant = torch.quantization.QuantStub()  # added
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                self.shortcut = LambdaLayer(lambda x:
                    F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )
        self.dequant = torch.quantization.DeQuantStub()  # added

    def forward(self, x):
        out = self.quant(x)  # added
        out = F.relu(self.bn1(self.conv1(out)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        out = self.dequant(out)  # added
        return out
```
and these to the ResNet module (again marked with `# added`):
```python
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.quant = torch.quantization.QuantStub()  # added
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(_weights_init)
        self.dequant = torch.quantization.DeQuantStub()  # added

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.quant(x)  # added
        out = F.relu(self.bn1(self.conv1(out)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        out = self.dequant(out)  # added
        return out
```
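For completeness, the `model_fp32` that appears in the fusion snippet below is built the usual way for a CIFAR-10 ResNet20 (sketch; `[3, 3, 3]` is the standard block count per stage):

```python
# Standard CIFAR-10 ResNet20: three stages of three BasicBlocks each.
model_fp32 = ResNet(BasicBlock, [3, 3, 3], num_classes=10)
model_fp32.eval()  # conv+bn fusion for post-training quantization expects eval mode
```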
And this is how I fuse the model:
```python
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [["conv1", "bn1"]], inplace=True)
for module_name, module in model_fp32_fused.named_children():
    if "layer" in module_name:
        for basic_block_name, basic_block in module.named_children():
            torch.quantization.fuse_modules(
                basic_block, [["conv1", "bn1"], ["conv2", "bn2"]],
                inplace=True)
```
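After fusing, I calibrate and convert along the lines of the static quantization tutorial (again a rough sketch: `bitwidth` and `calib_loader` are placeholders for the target bitwidth and my CIFAR-10 calibration loader, and `make_qconfig` is the helper sketched above):

```python
# Attach the qconfig, insert observers, run calibration data through, convert.
model_fp32_fused.qconfig = make_qconfig(bitwidth)
model_prepared = torch.quantization.prepare(model_fp32_fused)
with torch.no_grad():
    for images, _ in calib_loader:  # placeholder calibration loader
        model_prepared(images)
model_int8 = torch.quantization.convert(model_prepared)
```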
What am I doing wrong? I'm sorry if it's something trivial, but from the documentation alone it's very hard to work out how to apply this to more complex models. Thank you