When `non_traceable_module_class` is set in `prepare_custom_config_dict`, `qat_swap_modules` still swaps the submodules of the classes listed there for their QAT counterparts. I think this behavior is unintended: if a module is marked as non-traceable, it should not be quantized.
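Example code (the final `print` shows that `untraceable_module.linear` has been swapped to the QAT `Linear`, even though `UnTraceableModule` is listed as non-traceable):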
import torch
from torch import nn
from torch.quantization.quantize_fx import prepare_qat_fx
class UnTraceableModule(nn.Module):
    """Toy submodule that this repro registers as non-traceable.

    It wraps a single ``nn.Linear(2, 2)`` and applies it in ``forward``.
    The repro lists this class under ``non_traceable_module_class``. Its
    ``linear`` child is nevertheless swapped for a QAT module by
    ``prepare_qat_fx``, which is the behavior being reported.
    """

    def __init__(self):
        super().__init__()
        # The child whose type the repro inspects after QAT preparation.
        self.linear = nn.Linear(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped linear layer to ``x`` (last dim assumed to be 2)."""
        return self.linear(x)
class MyModel(nn.Module):
    """Top-level model for the repro: a Linear followed by an UnTraceableModule.

    The outer ``linear`` is expected to be QAT-swapped. The
    ``untraceable_module`` child is the one that, per this report, should be
    left alone when its class is marked non-traceable.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        # Submodule whose class is passed as non-traceable to prepare_qat_fx.
        self.untraceable_module = UnTraceableModule()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through ``linear`` and then through ``untraceable_module``."""
        x = self.linear(x)
        x = self.untraceable_module(x)
        return x
# Repro: mark UnTraceableModule as non-traceable, prepare for QAT, then inspect
# the type of the Linear nested inside the non-traceable submodule.
model = MyModel()

# A single global QAT qconfig, keyed by the empty string (applies to every module).
qconfig_dict = {"": torch.quantization.get_default_qat_qconfig()}

# Ask FX not to trace into UnTraceableModule instances.
prepare_custom_config_dict = {"non_traceable_module_class": [UnTraceableModule]}

# Put the model in training mode for QAT preparation. nn.Module.train() returns
# the module itself, so this is equivalent to passing model.train() inline.
model.train()
# NOTE(review): the third positional argument is prepare_custom_config_dict in
# the qconfig_dict-era API used here; later releases changed this signature,
# so confirm against the installed torch version.
prepared_model = prepare_qat_fx(model, qconfig_dict, prepare_custom_config_dict)

# Observed: <class 'torch.nn.qat.modules.linear.Linear'>, meaning the Linear
# inside the non-traceable module was swapped anyway. Per this report it
# should have stayed a plain nn.Linear.
swapped_linear = prepared_model.untraceable_module.linear
print(type(swapped_linear))