How to fuse custom layers

Hey,

I’m trying to fuse an EfficientUnet architecture adapted from zhoudaxia233/EfficientUnet-PyTorch on GitHub (a PyTorch 1.0 implementation of U-Net with an EfficientNet encoder).

I have a problem fusing layers that inherit from nn.Conv2d and nn.BatchNorm2d. Here is one of the blocks I’m trying to fuse:

(19): MBConvBlock(
    (swish): Swish()
    (_expand_conv): Conv2dSamePadding(208, 1248, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (_bn0): BatchNorm2d(1248, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_depthwise_conv): Conv2dSamePadding(1248, 1248, kernel_size=(5, 5), stride=(1, 1), groups=1248, bias=False)
    (_bn1): BatchNorm2d(1248, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_se_reduce): Conv2dSamePadding(1248, 52, kernel_size=(1, 1), stride=(1, 1))
    (_se_expand): Conv2dSamePadding(52, 1248, kernel_size=(1, 1), stride=(1, 1))
    (_project_conv): Conv2dSamePadding(1248, 208, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (_bn2): BatchNorm2d(208, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  )

Here is the code:

# Layers 
class Conv2dSamePadding(nn.Conv2d):
    """``nn.Conv2d`` with TensorFlow-style "SAME" padding computed per input.

    The parent convolution is built with ``padding=0``. Instead, ``forward``
    measures the incoming height and width and zero-pads so that each spatial
    output size is ``ceil(input_size / stride)``. When the total pad along an
    axis is odd, the extra row/column goes on the bottom/right.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, groups, bias:
            passed through to ``nn.Conv2d``.
        name: optional label stored on ``self.name``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True, name=None):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Make sure there is exactly one stride value per spatial axis.
        if len(self.stride) != 2:
            self.stride = [self.stride[0]] * 2
        self.name = name

    @staticmethod
    def _same_pad(size, kernel, stride, dilation):
        """Return the total zero padding needed along one axis (never negative)."""
        target = math.ceil(size / stride)
        needed = (target - 1) * stride + (kernel - 1) * dilation + 1 - size
        return max(needed, 0)

    def forward(self, x):
        height, width = x.size()[2:]
        k_height, k_width = self.weight.size()[2:]
        total_h = self._same_pad(height, k_height, self.stride[0], self.dilation[0])
        total_w = self._same_pad(width, k_width, self.stride[1], self.dilation[1])
        if total_h > 0 or total_w > 0:
            # F.pad order for the last two dims is (left, right, top, bottom);
            # the odd extra element goes to the right/bottom side.
            top = total_h // 2
            left = total_w // 2
            x = F.pad(x, [left, total_w - left, top, total_h - top])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class BatchNorm2d(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` that additionally stores an optional ``name`` label.

    All other arguments and defaults are forwarded unchanged to the parent.
    """

    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, name=None):
        parent_options = dict(
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
        super().__init__(num_features, **parent_options)
        self.name = name
# fuse code
def fuse_model(self):
    """Fold each MBConvBlock's batch-norm layers into the preceding convolution.

    ``torch.quantization.fuse_modules`` looks fusers up by the exact module
    classes. The custom ``Conv2dSamePadding`` / ``BatchNorm2d`` subclasses are
    therefore rejected with "did not find fuser method" unless PyTorch's
    installed ``DEFAULT_OP_LIST_TO_FUSER_METHOD`` is patched.

    This version avoids touching site-packages by calling
    ``torch.nn.utils.fusion.fuse_conv_bn_eval`` directly. That helper
    deep-copies the conv, so the fused layer keeps its class and its
    same-padding ``forward``. It then folds the BN running statistics and
    affine parameters into the copy's weight and bias. The BN slot is replaced
    with ``nn.Identity`` (as ``fuse_modules`` does), so the block's forward can
    keep calling it unchanged.

    The model must be in eval mode (call ``self.eval()`` first);
    ``fuse_conv_bn_eval`` refuses to fuse modules that are in training mode.
    Modifies the model in place and returns ``None``.
    """
    from torch.nn.utils.fusion import fuse_conv_bn_eval

    conv_bn_pairs = (
        ('_expand_conv', '_bn0'),
        ('_depthwise_conv', '_bn1'),
        ('_project_conv', '_bn2'),
    )
    # Snapshot the module tree first: the loop below swaps submodules in place.
    for block in list(self.modules()):
        if not isinstance(block, MBConvBlock):
            continue
        for conv_name, bn_name in conv_bn_pairs:
            conv = getattr(block, conv_name, None)
            bn = getattr(block, bn_name, None)
            # Skip pairs that are absent or already fused (BN already replaced
            # by nn.Identity), so repeated calls are harmless.
            # NOTE(review): presumably blocks built with expand_ratio == 1 have
            # no _expand_conv/_bn0, as in the reference EfficientNet code — confirm.
            if conv is None or not isinstance(bn, nn.BatchNorm2d):
                continue
            setattr(block, conv_name, fuse_conv_bn_eval(conv, bn))
            setattr(block, bn_name, nn.Identity())

The conv and batch-norm layers are thin subclasses of the standard PyTorch ones, but the fuser does not recognize them, and I get this error:

AssertionError: did not find fuser method for: (<class 'efficientunet.layers.Conv2dSamePadding'>, <class 'efficientunet.layers.BatchNorm2d'>)

How can I make the fuser accept these layers?

I found a solution by modifying the PyTorch source (in site-packages/torch/ao/quantization/). Note that such an edit is lost whenever PyTorch is reinstalled or upgraded.

You can register your own module types by adding an entry to DEFAULT_OP_LIST_TO_FUSER_METHOD in fuser_method_mappings.py that maps the tuple of module classes to a fuser function:

# Excerpt of torch/ao/quantization/fuser_method_mappings.py as patched in the
# installed PyTorch package (site-packages). Each key is a tuple of module
# classes and each value is the fuser used for that sequence of modules.
# The fuse-time error "did not find fuser method for: (<class ...Conv2dSamePadding>,
# <class ...BatchNorm2d>)" shows the lookup is keyed on the modules' exact
# classes, so subclasses of nn.Conv2d / nn.BatchNorm2d need their own entry.
# NOTE(review): edits to site-packages are lost whenever torch is reinstalled
# or upgraded. The `efficientunet` name used below must also be importable
# inside this torch module — confirm an import was added there as well.
# Newer torch versions accept extra entries without patching, via
# fuse_modules(..., fuse_custom_config_dict={"additional_fuser_method_mapping": {...}});
# verify against the installed version.
DEFAULT_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {

   (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,

   (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,

   (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,

   # Custom entry for EfficientUnet's Conv2dSamePadding / BatchNorm2d subclasses.
   # NOTE(review): fuse_conv_bn presumably only handles these in eval mode (its
   # training/QAT path maps standard conv classes to fused QAT modules) —
   # call model.eval() before fusing and confirm for the installed torch version.
   (efficientunet.layers.Conv2dSamePadding, efficientunet.layers.BatchNorm2d): fuse_conv_bn,

   (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,

   (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn,

   (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu,

   (nn.Conv1d, nn.ReLU): nni.ConvReLU1d,

   (nn.Conv2d, nn.ReLU): nni.ConvReLU2d,

   (nn.Conv3d, nn.ReLU): nni.ConvReLU3d,

   (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,

   (nn.Linear, nn.ReLU): nni.LinearReLU,

   (nn.BatchNorm2d, nn.ReLU): nni.BNReLU2d,

   (nn.BatchNorm3d, nn.ReLU): nni.BNReLU3d,

}

Hope this helps someone.