Can't quantize Linear + ReLU

When I run this code, I can’t fuse Linear + ReLU, even though the documentation says that it is possible:

class MyModule(nn.Module):
    """Minimal Linear + ReLU model wrapped between quant and dequant stubs."""

    def __init__(self):
        # nn.Module must be initialised before any submodule is assigned,
        # otherwise the attribute assignments below raise.
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear = nn.Linear(32, 64)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, input):
        """Apply quant -> linear -> relu -> dequant to ``input`` and return the result."""
        y = self.quant(input)
        # Feed the stub's output (not the raw input) into the linear layer so
        # the quant stub is actually part of the data path.
        y = self.linear(y)
        y = self.relu(y)
        y = self.dequant(y)
        return y

    def fuse_model(self):
        """Try to fuse the linear and ReLU layers of this model in place."""
        # NOTE: the module objects themselves are passed here rather than their
        # attribute names; this is the call that raises the error reported below.
        torch.quantization.fuse_modules(self, modules_to_fuse=[[self.linear, self.relu]], inplace=True)

print("Create model...")
model_ = MyModule()
print("Create model... [OK]")

# Run one forward pass on a 32x32 tensor of ones before attempting fusion.
in_ = torch.ones(32, 32)
print("Forward pass...")
y = model_(in_)
print("Forward pass... [OK]")

print("Fusing model... ")
# These are the two calls that appear in the traceback below; without them
# the script would report success without ever attempting the fusion.
model_.eval()
model_.fuse_model()
print("Fusing model... [OK]")

I got the following error:

ModuleAttributeError                      Traceback (most recent call last)
<ipython-input-38-7ffe2cc10aa7> in <module>()
     25 model_.eval()
---> 26 model_.fuse_model()

4 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
    777                 return modules[name]
    778         raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
--> 779             type(self).__name__, name))
    781     def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

ModuleAttributeError: 'Linear' object has no attribute 'split'
  • Torch version: 1.7.0+cu101
  • Torch vision version: 0.8.1+cu101

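I think you need to pass the names of the submodules as strings rather than the module objects themselves (the error shows a string method, `split`, being called on the `Linear` module you passed in):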

def fuse_model(self):
    """Fuse the ``linear`` and ``relu`` submodules of ``self`` in place.

    The layers are identified by their attribute names (strings), not by
    the module objects.
    """
    fusion_groups = [["linear", "relu"]]
    torch.quantization.fuse_modules(
        self,
        modules_to_fuse=fusion_groups,
        inplace=True,
    )
1 Like

For sure… thanks a lot!