Validating UNet architecture

Hi, this is my first time building a model with PyTorch, and I am translating a UNet from TensorFlow. Could someone give me a sanity check that the model looks valid?


class UNet(nn.Module):
    """U-Net style encoder/decoder with skip connections.

    The network has four parts:

    * ``pooling_steps`` encoder blocks, each followed by a 2x2 max-pool.
    * A bottleneck block.
    * ``pooling_steps`` decoder stages. Each stage upsamples with a stride-2
      transposed convolution, concatenates the matching encoder output along
      the channel dimension, and applies a double-conv block.
    * A final 1x1 convolution that maps to ``out_channels``, followed by a
      sigmoid.

    Channel widths double at every level. Encoder level ``i`` outputs
    ``init_features * 2**i`` channels, and the bottleneck outputs
    ``init_features * 2**pooling_steps``.

    NOTE(review): because ``forward`` already applies a sigmoid, pair it with
    ``nn.BCELoss``, not ``nn.BCEWithLogitsLoss``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        init_features: Channel width of the first encoder level.
        pooling_steps: Number of down/up-sampling levels (must be >= 0).

    Raises:
        ValueError: If ``pooling_steps`` is negative.
    """

    def __init__(self, in_channels=1, out_channels=1, init_features=64, pooling_steps=2):
        super(UNet, self).__init__()

        if pooling_steps < 0:
            raise ValueError(f"pooling_steps must be >= 0, got {pooling_steps}")

        features = init_features
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.upconv = nn.ModuleList()

        # Channel count flowing into the next encoder block. Tracking it
        # explicitly means level i receives the output width of level i-1
        # (features * 2**(i-1)) rather than always ``features``. The latter
        # broke every level past the second.
        prev_channels = in_channels
        for i in range(pooling_steps):
            level_channels = features * (2**i)
            self.encoders.append(UNet._block(prev_channels, level_channels, name=f"enc{i+1}"))
            # Decoders/upconvs are built shallow-to-deep but inserted at the
            # front, so index 0 is the deepest stage. That is the order in
            # which forward() consumes them.
            self.decoders.insert(0, UNet._block(level_channels * 2, level_channels, name=f"dec{i+1}"))
            self.upconv.insert(
                0, nn.ConvTranspose2d(level_channels * 2, level_channels, kernel_size=2, stride=2)
            )
            prev_channels = level_channels

        self.bottleneck = UNet._block(prev_channels, features * (2**pooling_steps), name="bottleneck")
        self.conv = nn.Conv2d(in_channels=features, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Input of shape (N, in_channels, H, W). BatchNorm2d requires 4-D
                input.

        Returns:
            Tensor of shape (N, out_channels, H, W) with sigmoid-activated
            values in [0, 1].
        """
        skips = []
        for encoder in self.encoders:
            x = encoder(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Deepest decoder first, paired with the deepest skip connection.
        for upconv, decoder, skip in zip(self.upconv, self.decoders, reversed(skips)):
            x = upconv(x)
            # MaxPool2d floors odd sizes, so the upsampled map can be smaller
            # than the skip. Pad it back so torch.cat sees matching H/W.
            diff_h = skip.shape[-2] - x.shape[-2]
            diff_w = skip.shape[-1] - x.shape[-1]
            if diff_h or diff_w:
                x = nn.functional.pad(
                    x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
                )
            x = torch.cat((x, skip), dim=1)
            x = decoder(x)

        return torch.sigmoid(self.conv(x))

    @staticmethod
    def _block(in_channels, features, name):
        """Build a (Conv3x3 -> BatchNorm -> ReLU) x 2 block.

        The 3x3 convolutions use padding=1, so H and W are preserved. Bias is
        disabled because each convolution is immediately followed by
        BatchNorm.

        Args:
            in_channels: Channels entering the block.
            features: Channels produced by both convolutions.
            name: Label for the block. Not referenced in the body; kept so
                callers can pass it.

        Returns:
            An ``nn.Sequential`` implementing the block.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=features,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=features),
            nn.ReLU(inplace=True),
        )

Also, what strategies could I follow to speed up inference? I read that it is possible to use half precision (16-bit floats), but I am not sure how to implement it. Is it applied to the tensors, or is it an attribute of the model?

Any other tips are welcome.

I don’t see anything obviously wrong in the model implementation.

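For mixed-precision training, check out this tutorial. For half-precision inference, the model and its inputs must share a dtype. Either convert both (`model.half()` and `x.half()`), or run the forward pass inside `torch.autocast(device_type="cuda", dtype=torch.float16)`.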

1 Like

Thanks for the feedback!