Hi:
I am trying to use quantization aware training for my CNN network. What I want to do is I load a pretrained RestNet18 and finetune it with other dataset. I run into some problem to quantitize the network during training.
I followed the tutorial (pytorch.org/docs/stable/quantization.html).
However, I run into an NotImplementedError: Cannot fuse modules: in fusemodule.py:
Could you help me understand this error?
Many thanks
Shixian Wen
Error Messages
Traceback (most recent call last):
File âattetion_delta_quantization.pyâ, line 468, in
model_one_fused = torch.quantization.fuse_modules(model_one,[[âbase.0.weightâ,âbase.1.weightâ,âbase.1.biasâ,âbase.4.0.conv1.weightâ,âbase.4.0.bn1.weightâ,âbase.4.0.bn1.biasâ,âbase.4.0.conv2.weightâ,âbase.4.0.bn2.weightâ,âbase.4.0.bn2.biasâ,âbase.4.1.conv1.weightâ,âbase.4.1.bn1.weightâ,âbase.4.1.bn1.biasâ,âbase.4.1.conv2.weightâ,âbase.4.1.bn2.weightâ,âbase.4.1.bn2.biasâ,âbase.5.0.conv1.weightâ,âbase.5.0.bn1.weightâ,âbase.5.0.bn1.biasâ,âbase.5.0.conv2.weightâ,âbase.5.0.bn2.weightâ,âbase.5.0.bn2.biasâ,âbase.5.0.downsample.0.weightâ,âbase.5.0.downsample.1.weightâ,âbase.5.0.downsample.1.biasâ,âbase.5.1.conv1.weightâ,âbase.5.1.bn1.weightâ,âbase.5.1.bn1.biasâ,âbase.5.1.conv2.weightâ,âbase.5.1.bn2.weightâ,âbase.5.1.bn2.biasâ,âbase.6.0.conv1.weightâ,âbase.6.0.bn1.weightâ,âbase.6.0.bn1.biasâ,âbase.6.0.conv2.weightâ,âbase.6.0.bn2.weightâ,âbase.6.0.bn2.biasâ,âbase.6.0.downsample.0.weightâ,âbase.6.0.downsample.1.weightâ,âbase.6.0.downsample.1.biasâ,âbase.6.1.conv1.weightâ,âbase.6.1.bn1.weightâ,âbase.6.1.bn1.biasâ,âbase.6.1.conv2.weightâ,âbase.6.1.bn2.weightâ,âbase.6.1.bn2.biasâ,âbase.7.0.conv1.weightâ,âbase.7.0.bn1.weightâ,âbase.7.0.bn1.biasâ,âbase.7.0.conv2.weightâ,âbase.7.0.bn2.weightâ,âbase.7.0.bn2.biasâ,âbase.7.0.downsample.0.weightâ,âbase.7.0.downsample.1.weightâ,âbase.7.0.downsample.1.biasâ,âbase.7.1.conv1.weightâ,âbase.7.1.bn1.weightâ,âbase.7.1.bn1.biasâ,âbase.7.1.conv2.weightâ,âbase.7.1.bn2.weightâ,âbase.7.1.bn2.biasâ,âlinear_subâ,âlinear_birdâ,âlinear_boatâ,âlinear_carâ,âlinear_catâ,âlinear_fungusâ,âlinear_insectâ,âlinear_monkeyâ,âlinear_truckâ,âlinear_dogâ,âlinear_fruitâ]])
File â/home/shixian/anaconda3/envs/pytorchnew/lib/python3.7/site-packages/torch/quantization/fuse_modules.pyâ, line 198, in fuse_modules
_fuse_modules(model, module_list, fuser_func)
File â/home/shixian/anaconda3/envs/pytorchnew/lib/python3.7/site-packages/torch/quantization/fuse_modules.pyâ, line 141, in _fuse_modules
new_mod_list = fuser_func(mod_list)
File â/home/shixian/anaconda3/envs/pytorchnew/lib/python3.7/site-packages/torch/quantization/fuse_modules.pyâ, line 124, in fuse_known_modules
raise NotImplementedError(âCannot fuse modules: {}â.format(types))
NotImplementedError: Cannot fuse modules: (<class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.parameter.Parameterâ>, <class âtorch.nn.modules.linear.Linearâ>, <class âtorch.nn.modules.linear.Linearâ>, <class âtorch.nn.modules.linear.Linearâ>, <class âtorch.nn.modules.linear.Linearâ>, <class âtorch.nn.modules.linear.Linearâ>, <class âtorch.nn.modules.linear.Linearâ>, <class âtorch.nn.modules.linear.Linearâ>, <class âtorch.nn.modules.linear.Linearâ>, <class âtorch.nn.modules.linear.Linearâ>, <class âtorch.nn.modules.linear.Linearâ>, <class âtorch.nn.modules.linear.Linearâ>)
Here is my CNN model:
class Resnet18_ONE(nn.Module):
def __init__(self):
super(Resnet18_ONE,self).__init__()
self.quant = torch.quantization.QuantStub()
resnet18 = torchvision.models.resnet18(pretrained=True)
num_ftrs = resnet18.fc.in_features
self.base= nn.Sequential(*list(resnet18.children())[:-1])
self.linear_sub = nn.Linear(num_ftrs, superclasses)
self.linear_bird = nn.Linear(num_ftrs, classes_bird)
self.linear_boat = nn.Linear(num_ftrs, classes_boat)
self.linear_car = nn.Linear(num_ftrs, classes_car)
self.linear_cat = nn.Linear(num_ftrs, classes_cat)
self.linear_fungus = nn.Linear(num_ftrs, classes_fungus)
self.linear_insect = nn.Linear(num_ftrs, classes_insect)
self.linear_monkey = nn.Linear(num_ftrs, classes_monkey)
self.linear_truck = nn.Linear(num_ftrs, classes_truck)
self.linear_dog = nn.Linear(num_ftrs, classes_dog)
self.linear_fruit = nn.Linear(num_ftrs, classes_fruit)
self.dequant = torch.quantization.DeQuantStub()
def forward(self,x):
x = self.quant(x)
x = self.base(x)
x = torch.flatten(x, 1)
if task == 'SUB':
x = self.linear_sub(x)
elif task == 'BIRD':
x = self.linear_bird(x)
elif task == 'BOAT':
x = self.linear_boat(x)
elif task == 'CAR':
x = self.linear_car(x)
elif task == 'CAT':
x = self.linear_cat(x)
elif task == 'FUNGUS':
x = self.linear_fungus(x)
elif task == 'INSECT':
x = self.linear_insect(x)
elif task == 'MONKEY':
x = self.linear_monkey(x)
elif task == 'TRUCK':
x = self.linear_truck(x)
elif task == 'DOG':
x = self.linear_dog(x)
else:
x = self.linear_fruit(x)
x = self.dequant(x)
return x
Here is the model settings according to the tutorial:
model_one = Resnet18_ONE()
model_one.train()
model_one.qconfig = torch.quantization.get_default_qat_qconfig(âfbgemmâ)
model_one_fused = torch.quantization.fuse_modules(model_one,[[âbase.0.weightâ,âbase.1.weightâ,âbase.1.biasâ,âbase.4.0.conv1.weightâ,âbase.4.0.bn1.weightâ,âbase.4.0.bn1.biasâ,âbase.4.0.conv2.weightâ,âbase.4.0.bn2.weightâ,âbase.4.0.bn2.biasâ,âbase.4.1.conv1.weightâ,âbase.4.1.bn1.weightâ,âbase.4.1.bn1.biasâ,âbase.4.1.conv2.weightâ,âbase.4.1.bn2.weightâ,âbase.4.1.bn2.biasâ,âbase.5.0.conv1.weightâ,âbase.5.0.bn1.weightâ,âbase.5.0.bn1.biasâ,âbase.5.0.conv2.weightâ,âbase.5.0.bn2.weightâ,âbase.5.0.bn2.biasâ,âbase.5.0.downsample.0.weightâ,âbase.5.0.downsample.1.weightâ,âbase.5.0.downsample.1.biasâ,âbase.5.1.conv1.weightâ,âbase.5.1.bn1.weightâ,âbase.5.1.bn1.biasâ,âbase.5.1.conv2.weightâ,âbase.5.1.bn2.weightâ,âbase.5.1.bn2.biasâ,âbase.6.0.conv1.weightâ,âbase.6.0.bn1.weightâ,âbase.6.0.bn1.biasâ,âbase.6.0.conv2.weightâ,âbase.6.0.bn2.weightâ,âbase.6.0.bn2.biasâ,âbase.6.0.downsample.0.weightâ,âbase.6.0.downsample.1.weightâ,âbase.6.0.downsample.1.biasâ,âbase.6.1.conv1.weightâ,âbase.6.1.bn1.weightâ,âbase.6.1.bn1.biasâ,âbase.6.1.conv2.weightâ,âbase.6.1.bn2.weightâ,âbase.6.1.bn2.biasâ,âbase.7.0.conv1.weightâ,âbase.7.0.bn1.weightâ,âbase.7.0.bn1.biasâ,âbase.7.0.conv2.weightâ,âbase.7.0.bn2.weightâ,âbase.7.0.bn2.biasâ,âbase.7.0.downsample.0.weightâ,âbase.7.0.downsample.1.weightâ,âbase.7.0.downsample.1.biasâ,âbase.7.1.conv1.weightâ,âbase.7.1.bn1.weightâ,âbase.7.1.bn1.biasâ,âbase.7.1.conv2.weightâ,âbase.7.1.bn2.weightâ,âbase.7.1.bn2.biasâ,âlinear_subâ,âlinear_birdâ,âlinear_boatâ,âlinear_carâ,âlinear_catâ,âlinear_fungusâ,âlinear_insectâ,âlinear_monkeyâ,âlinear_truckâ,âlinear_dogâ,âlinear_fruitâ]])
model_one_prepared = torch.quantization.prepare(model_one_fused)