Optional submodule in ScriptedModule

Hi! How should I declare an optional BatchNorm submodule so that the module can still be compiled with TorchScript?

import torch
import torch.nn as nn
import torch.nn.functional as F

class Downsample(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        apply_batchnorm: bool = False,
    ):
        super(Downsample, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=2,
            bias=False,
            padding=1,
        )
        self.apply_batchnorm = apply_batchnorm
        if self.apply_batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        if self.apply_batchnorm:
            x = self.bn(x)
        x = F.leaky_relu(x, 0.3)
        return x

When I try to run torch.jit.script(Downsample(1, 2, 3)), I get the following error:

RuntimeError: 
Module 'Downsample' has no attribute 'bn' :
  File "path/to/file.py", line 100
        x = self.conv(x)
        if self.apply_batchnorm:
            x = self.bn(x)
                ~~~~~~~ <--- HERE
        x = F.leaky_relu(x, 0.3)
        return x

I understand that in TorchScript every variable must have a single static type. Adding bn: Optional[nn.BatchNorm2d] to the class definition does not help.

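This is close. The issue is that when self.apply_batchnorm is False, the module has no bn attribute at all. Because apply_batchnorm is an ordinary attribute, TorchScript compiles both sides of the if in forward regardless of its runtime value, so self.bn cannot be accessed or checked. So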

        if self.apply_batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

turns into

        if self.apply_batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = None

The other piece is that to call an optional module, the compiler must be able to prove that it is not None at the point where it is called. So instead of checking if self.apply_batchnorm, you have to check if self.bn is not None.
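The same refinement rule applies to any Optional value in TorchScript, not just submodules. As a minimal sketch (maybe_add_bias is a hypothetical function, not part of the original code), an Optional[Tensor] argument can only be used as a Tensor inside a branch guarded by an is not None check:

```python
from typing import Optional

import torch


@torch.jit.script
def maybe_add_bias(x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
    # Outside the if, `bias` has type Optional[Tensor], so `x + bias` would not compile here.
    if bias is not None:
        # Inside this branch the compiler refines `bias` from Optional[Tensor] to Tensor.
        return x + bias
    return x


print(maybe_add_bias(torch.ones(3), None))
print(maybe_add_bias(torch.ones(3), torch.full((3,), 2.0)))
```

Guarding with an unrelated bool flag instead of the is not None check gives the compiler no such guarantee, which is why the condition on self.bn itself matters.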

Here is the full working example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Downsample(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        apply_batchnorm: bool = False,
    ):
        super(Downsample, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=2,
            bias=False,
            padding=1,
        )
        self.apply_batchnorm = apply_batchnorm
        if self.apply_batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        x = F.leaky_relu(x, 0.3)
        return x


torch.jit.script(Downsample(1, 2, 3))

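You can also use nn.Identity in place of the batchnorm. Since nn.Identity simply returns its input, self.bn then always exists, and forward can either keep its original check or call self.bn(x) unconditionally. For example, in __init__ you can write: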

if self.apply_batchnorm:
    self.bn = nn.BatchNorm2d(out_channels)
else:
    self.bn = nn.Identity()
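To show the nn.Identity approach end to end, here is a sketch that assumes the rest of the module stays as in the question. DownsampleIdentity is a hypothetical name chosen to keep it apart from the original class. It scripts both configurations and checks that the scripted module agrees with eager mode:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownsampleIdentity(nn.Module):
    """Hypothetical variant of Downsample that uses nn.Identity for the optional batchnorm."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        apply_batchnorm: bool = False,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=2,
            bias=False,
            padding=1,
        )
        # self.bn is always a module, so TorchScript never sees a missing attribute or None.
        self.bn = nn.BatchNorm2d(out_channels) if apply_batchnorm else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn(x)  # no None check needed: nn.Identity returns x unchanged
        return F.leaky_relu(x, 0.3)


# Script both configurations and push the same dummy batch through each.
x = torch.randn(4, 1, 8, 8)
outputs = {}
for flag in (False, True):
    model = DownsampleIdentity(1, 2, 3, apply_batchnorm=flag).eval()
    scripted = torch.jit.script(model)
    outputs[flag] = scripted(x)
    # The scripted module should produce the same result as eager mode.
    assert torch.allclose(outputs[flag], model(x), atol=1e-6)
    print(flag, tuple(outputs[flag].shape))
```

Compared with assigning None, this keeps forward free of branches. Because nn.Identity has no parameters or buffers, the state_dict contains no bn entries in the disabled case, just as with self.bn = None.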

Thank you, driazati and G.M! In the meantime, I have also found another solution in the GitHub issues.
