When running this code I can’t fuse Linear + ReLU, though the documentation says that it is possible (https://pytorch.org/docs/stable/_modules/torch/quantization/fuse_modules.html#fuse_modules)
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.quant = torch.quantization.QuantStub()
self.linear = nn.Linear(32, 64)
self.relu = nn.ReLU()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, input):
y = self.quant(input)
y = self.linear(input)
y = self.relu(y)
y = self.dequant(y)
return y
def fuse_model(self):
torch.quantization.fuse_modules(self, modules_to_fuse=[[self.linear, self.relu]], inplace=True)
print("Create model...")
model_ = MyModule()
print("Create model... [OK]")
in_ = torch.ones(32, 32)
print("Forward pass...")
y = model_(in_)
print("Forward pass... [OK]")
print("Fusing model... ")
model_.eval()
model_.fuse_model()
print("Fusing model... [OK]")
I got the following error:
---------------------------------------------------------------------------
ModuleAttributeError Traceback (most recent call last)
<ipython-input-38-7ffe2cc10aa7> in <module>()
24
25 model_.eval()
---> 26 model_.fuse_model()
27
4 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
777 return modules[name]
778 raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
--> 779 type(self).__name__, name))
780
781 def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
ModuleAttributeError: 'Linear' object has no attribute 'split'
- Torch version: 1.7.0+cu101
- Torch vision version: 0.8.1+cu101