PyTorch Quantization Aware Training

Hi:

I am trying to use quantization-aware training for my CNN. What I want to do is load a pretrained ResNet18 and fine-tune it on another dataset. I run into a problem when quantizing the network during training.

I followed the tutorial (pytorch.org/docs/stable/quantization.html).

However, I run into a NotImplementedError: Cannot fuse modules, raised in fuse_modules.py:

Could you help me understand this error?

Many thanks
Shixian Wen

Error Messages

Traceback (most recent call last):
File "attetion_delta_quantization.py", line 468, in <module>
model_one_fused = torch.quantization.fuse_modules(model_one,[['base.0.weight','base.1.weight','base.1.bias','base.4.0.conv1.weight','base.4.0.bn1.weight','base.4.0.bn1.bias','base.4.0.conv2.weight','base.4.0.bn2.weight','base.4.0.bn2.bias','base.4.1.conv1.weight','base.4.1.bn1.weight','base.4.1.bn1.bias','base.4.1.conv2.weight','base.4.1.bn2.weight','base.4.1.bn2.bias','base.5.0.conv1.weight','base.5.0.bn1.weight','base.5.0.bn1.bias','base.5.0.conv2.weight','base.5.0.bn2.weight','base.5.0.bn2.bias','base.5.0.downsample.0.weight','base.5.0.downsample.1.weight','base.5.0.downsample.1.bias','base.5.1.conv1.weight','base.5.1.bn1.weight','base.5.1.bn1.bias','base.5.1.conv2.weight','base.5.1.bn2.weight','base.5.1.bn2.bias','base.6.0.conv1.weight','base.6.0.bn1.weight','base.6.0.bn1.bias','base.6.0.conv2.weight','base.6.0.bn2.weight','base.6.0.bn2.bias','base.6.0.downsample.0.weight','base.6.0.downsample.1.weight','base.6.0.downsample.1.bias','base.6.1.conv1.weight','base.6.1.bn1.weight','base.6.1.bn1.bias','base.6.1.conv2.weight','base.6.1.bn2.weight','base.6.1.bn2.bias','base.7.0.conv1.weight','base.7.0.bn1.weight','base.7.0.bn1.bias','base.7.0.conv2.weight','base.7.0.bn2.weight','base.7.0.bn2.bias','base.7.0.downsample.0.weight','base.7.0.downsample.1.weight','base.7.0.downsample.1.bias','base.7.1.conv1.weight','base.7.1.bn1.weight','base.7.1.bn1.bias','base.7.1.conv2.weight','base.7.1.bn2.weight','base.7.1.bn2.bias','linear_sub','linear_bird','linear_boat','linear_car','linear_cat','linear_fungus','linear_insect','linear_monkey','linear_truck','linear_dog','linear_fruit']])
File "/home/shixian/anaconda3/envs/pytorchnew/lib/python3.7/site-packages/torch/quantization/fuse_modules.py", line 198, in fuse_modules
_fuse_modules(model, module_list, fuser_func)
File "/home/shixian/anaconda3/envs/pytorchnew/lib/python3.7/site-packages/torch/quantization/fuse_modules.py", line 141, in _fuse_modules
new_mod_list = fuser_func(mod_list)
File "/home/shixian/anaconda3/envs/pytorchnew/lib/python3.7/site-packages/torch/quantization/fuse_modules.py", line 124, in fuse_known_modules
raise NotImplementedError("Cannot fuse modules: {}".format(types))
NotImplementedError: Cannot fuse modules: (<class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.parameter.Parameter'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>)

Here is my CNN model:

class Resnet18_ONE(nn.Module):

    def __init__(self):
        super(Resnet18_ONE, self).__init__()
        self.quant = torch.quantization.QuantStub()
        resnet18 = torchvision.models.resnet18(pretrained=True)
        num_ftrs = resnet18.fc.in_features
        self.base = nn.Sequential(*list(resnet18.children())[:-1])
        self.linear_sub = nn.Linear(num_ftrs, superclasses)
        self.linear_bird = nn.Linear(num_ftrs, classes_bird)
        self.linear_boat = nn.Linear(num_ftrs, classes_boat)
        self.linear_car = nn.Linear(num_ftrs, classes_car)
        self.linear_cat = nn.Linear(num_ftrs, classes_cat)
        self.linear_fungus = nn.Linear(num_ftrs, classes_fungus)
        self.linear_insect = nn.Linear(num_ftrs, classes_insect)
        self.linear_monkey = nn.Linear(num_ftrs, classes_monkey)
        self.linear_truck = nn.Linear(num_ftrs, classes_truck)
        self.linear_dog = nn.Linear(num_ftrs, classes_dog)
        self.linear_fruit = nn.Linear(num_ftrs, classes_fruit)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.base(x)
        x = torch.flatten(x, 1)
        if task == 'SUB':
            x = self.linear_sub(x)
        elif task == 'BIRD':
            x = self.linear_bird(x)
        elif task == 'BOAT':
            x = self.linear_boat(x)
        elif task == 'CAR':
            x = self.linear_car(x)
        elif task == 'CAT':
            x = self.linear_cat(x)
        elif task == 'FUNGUS':
            x = self.linear_fungus(x)
        elif task == 'INSECT':
            x = self.linear_insect(x)
        elif task == 'MONKEY':
            x = self.linear_monkey(x)
        elif task == 'TRUCK':
            x = self.linear_truck(x)
        elif task == 'DOG':
            x = self.linear_dog(x)
        else:
            x = self.linear_fruit(x)
        x = self.dequant(x)
        return x

Here are the model settings according to the tutorial:

model_one = Resnet18_ONE()
model_one.train()
model_one.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_one_fused = torch.quantization.fuse_modules(model_one,[['base.0.weight','base.1.weight','base.1.bias','base.4.0.conv1.weight','base.4.0.bn1.weight','base.4.0.bn1.bias','base.4.0.conv2.weight','base.4.0.bn2.weight','base.4.0.bn2.bias','base.4.1.conv1.weight','base.4.1.bn1.weight','base.4.1.bn1.bias','base.4.1.conv2.weight','base.4.1.bn2.weight','base.4.1.bn2.bias','base.5.0.conv1.weight','base.5.0.bn1.weight','base.5.0.bn1.bias','base.5.0.conv2.weight','base.5.0.bn2.weight','base.5.0.bn2.bias','base.5.0.downsample.0.weight','base.5.0.downsample.1.weight','base.5.0.downsample.1.bias','base.5.1.conv1.weight','base.5.1.bn1.weight','base.5.1.bn1.bias','base.5.1.conv2.weight','base.5.1.bn2.weight','base.5.1.bn2.bias','base.6.0.conv1.weight','base.6.0.bn1.weight','base.6.0.bn1.bias','base.6.0.conv2.weight','base.6.0.bn2.weight','base.6.0.bn2.bias','base.6.0.downsample.0.weight','base.6.0.downsample.1.weight','base.6.0.downsample.1.bias','base.6.1.conv1.weight','base.6.1.bn1.weight','base.6.1.bn1.bias','base.6.1.conv2.weight','base.6.1.bn2.weight','base.6.1.bn2.bias','base.7.0.conv1.weight','base.7.0.bn1.weight','base.7.0.bn1.bias','base.7.0.conv2.weight','base.7.0.bn2.weight','base.7.0.bn2.bias','base.7.0.downsample.0.weight','base.7.0.downsample.1.weight','base.7.0.downsample.1.bias','base.7.1.conv1.weight','base.7.1.bn1.weight','base.7.1.bn1.bias','base.7.1.conv2.weight','base.7.1.bn2.weight','base.7.1.bn2.bias','linear_sub','linear_bird','linear_boat','linear_car','linear_cat','linear_fungus','linear_insect','linear_monkey','linear_truck','linear_dog','linear_fruit']])

model_one_prepared = torch.quantization.prepare(model_one_fused)
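
For reference, the remaining QAT steps from the tutorial look roughly like this once fusion succeeds (a sketch only; num_epochs, train_one_epoch, and train_loader are placeholders, not part of this thread):

# Sketch of the rest of the eager-mode QAT flow; the QAT recipe uses prepare_qat
# (rather than prepare) to insert fake-quant modules and observers.
model_one_prepared = torch.quantization.prepare_qat(model_one_fused)

for epoch in range(num_epochs):                       # fine-tune with fake quantization enabled
    train_one_epoch(model_one_prepared, train_loader)  # placeholder training step

model_one_prepared.eval()
model_one_int8 = torch.quantization.convert(model_one_prepared)  # swap in real int8 modules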

Fusion should be applied to modules, not parameters; it looks like you are applying fusion to parameters (weights and biases)?
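
For example (a sketch using names from the traceback above): 'base.0.weight' refers to an nn.Parameter, while 'base.0' refers to the Conv2d module itself, which is what fuse_modules expects:

# Wrong: these names point at Parameters, hence the NotImplementedError above
# torch.quantization.fuse_modules(model_one, [['base.0.weight', 'base.1.weight', 'base.1.bias']])

# Right: name the modules themselves (Conv2d + BatchNorm2d)
torch.quantization.fuse_modules(model_one, [['base.0', 'base.1']])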

Hi Jerry:
Thank you for your reply.
I am using the nightly versions: PyTorch '1.9.0.dev20210422+cu111', torchvision '0.10.0.dev20210422+cu111'.

I tested your idea.

When I set model_one_fused = torch.quantization.fuse_modules(model_one,[['base.0','base.1']]),
the model works fine.

However, as soon as I add more modules to it,
e.g., model_one_fused = torch.quantization.fuse_modules(model_one,[['base.0','base.1','base.4.0.conv1','base.4.0.bn1','base.4.0.conv2','base.4.0.bn2']])

the system tells me:
assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
AssertionError: did not find fuser method for: (<class 'torch.nn.modules.conv.Conv2d'>, <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, <class 'torch.nn.modules.conv.Conv2d'>, <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, <class 'torch.nn.modules.conv.Conv2d'>, <class 'torch.nn.modules.batchnorm.BatchNorm2d'>)

Many thanks
Shixian

Oh, please specify them as a list of lists. What are you fusing?
We support conv + bn and conv + bn + relu fusion.

For example, if we have 0.conv, 0.bn, 0.relu and 1.conv, 1.bn, we would call:
fuse_modules(model, [["0.conv", "0.bn", "0.relu"], ["1.conv", "1.bn"]])
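
Applying that pattern to the model above, one way to build the list of lists (a sketch, assuming the base indices from the traceback: base.0 = conv1, base.1 = bn1, base.4 to base.7 = layer1 to layer4, each with two BasicBlocks) could be:

modules_to_fuse = [['base.0', 'base.1']]                       # stem: Conv2d + BatchNorm2d
for layer in (4, 5, 6, 7):                                     # base.4 .. base.7 = layer1 .. layer4
    for block in (0, 1):                                       # two BasicBlocks per layer
        prefix = 'base.{}.{}'.format(layer, block)
        modules_to_fuse.append([prefix + '.conv1', prefix + '.bn1'])
        modules_to_fuse.append([prefix + '.conv2', prefix + '.bn2'])
        if layer > 4 and block == 0:                           # these blocks also have a downsample conv + bn
            modules_to_fuse.append([prefix + '.downsample.0', prefix + '.downsample.1'])

model_one_fused = torch.quantization.fuse_modules(model_one, modules_to_fuse)

The Linear heads are left out, since there is nothing adjacent to fuse them with.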

Hi Jerry:

It solved my problem!
Thank you!

Many thanks
Shixian
