qat_swap_modules should consider non_traceable_module_class

If non_traceable_module_class is set in prepare_custom_config_dict, qat_swap_modules still swaps the modules inside the listed non-traceable classes. I think this is unexpected behavior: if a module is non-traceable, it should not be quantized.

Example code:

import torch
from torch import nn
from torch.quantization.quantize_fx import prepare_qat_fx


class UnTraceableModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.untraceable_module = UnTraceableModule()

    def forward(self, x):
        x = self.linear(x)
        x = self.untraceable_module(x)
        return x


model = MyModel()
qconfig_dict = {
    "": torch.quantization.get_default_qat_qconfig()
}
prepare_custom_config_dict = {
    "non_traceable_module_class": [UnTraceableModule]
}
prepared_model = prepare_qat_fx(model.train(), qconfig_dict, prepare_custom_config_dict)
print(type(prepared_model.untraceable_module.linear))    # <---- prints <class 'torch.nn.qat.modules.linear.Linear'>
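Until this is fixed, a possible workaround (a sketch, assuming the "module_name" qconfig override in qconfig_dict behaves as documented) is to explicitly assign a None qconfig to the non-traceable submodule, so the QAT swap has no qconfig to act on:

```python
import torch
from torch import nn
from torch.quantization.quantize_fx import prepare_qat_fx


class UnTraceableModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.untraceable_module = UnTraceableModule()

    def forward(self, x):
        x = self.linear(x)
        x = self.untraceable_module(x)
        return x


model = MyModel()
qconfig_dict = {
    "": torch.quantization.get_default_qat_qconfig(),
    # Workaround: give the non-traceable submodule an explicit None
    # qconfig so the QAT module swap skips everything under it.
    "module_name": [("untraceable_module", None)],
}
prepare_custom_config_dict = {
    "non_traceable_module_class": [UnTraceableModule]
}
prepared_model = prepare_qat_fx(model.train(), qconfig_dict, prepare_custom_config_dict)
print(type(prepared_model.untraceable_module.linear))
```

With the override in place, the submodule's linear layer should remain a plain nn.Linear instead of being swapped to torch.nn.qat.modules.linear.Linear.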

I made an issue:

I'll see if I can put together a fix.
