Multi-task learning: single-encoder, multi-decoder module

Hello all,
I am trying to set up multi-task learning with a single-encoder, multi-decoder U-Net architecture, taking a cue from Y-Net / Fork-Net.


I can split the model into an encoder and a decoder submodule as follows:

model = UNet3D(in_channel=1, n_classes=2)
encoder = nn.Sequential(*list(model.children())[0:11])
decoder = nn.Sequential(*list(model.children())[11:])

But with this split the skip connections are lost: nn.Sequential simply chains the child modules, so the torch.cat concatenations from the original forward() are never applied (and the three pooling layers end up after all the conv blocks instead of interleaved), meaning the encoder features are not concatenated to the decoder outputs.
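Concretely, running the split modules from the snippet above on a small dummy volume (the shapes here are just for illustration) fails at the first concatenation stage of the decoder, because dc8_768_256 expects 512 + 256 = 768 input channels but only receives the 512 channels coming out of dc9_512_512:

import torch

x = torch.randn(1, 1, 32, 32, 32)   # dummy (N, C, D, H, W) volume
feat = encoder(x)                    # bottleneck features, 512 channels
try:
    out = decoder(feat)              # dc8_768_256 never gets the 256-channel skip from syn2
except RuntimeError as e:
    print(e)                         # PyTorch reports 768 expected channels vs. 512 given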

I am using the following U-Net model:

import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet3D(nn.Module):
    def __init__(self, in_channel, n_classes):
        super(UNet3D, self).__init__()
        self.in_channel = in_channel
        self.n_classes = n_classes
        bn = True  # use BatchNorm3d in the encoder blocks
        bs = True  # use bias in the conv layers
        self.ec0_1_32 = self.encoder(self.in_channel, 32, bias=bs, batchnorm=bn)
        self.ec1_32_64 = self.encoder(32, 64, bias=bs, batchnorm=bn)
        self.ec2_64_64 = self.encoder(64, 64, bias=bs, batchnorm=bn)
        self.ec3_64_128 = self.encoder(64, 128, bias=bs, batchnorm=bn)
        self.ec4_128_128 = self.encoder(128, 128, bias=bs, batchnorm=bn)
        self.ec5_128_256 = self.encoder(128, 256, bias=bs, batchnorm=bn)
        self.ec6_256_256 = self.encoder(256, 256, bias=bs, batchnorm=bn)
        self.ec7_256_512 = self.encoder(256, 512, bias=bs, batchnorm=bn)

        self.pool0 = nn.MaxPool3d(2)
        self.pool1 = nn.MaxPool3d(2)
        self.pool2 = nn.MaxPool3d(2)

        self.dc9_512_512 = self.decoder(512, 512, kernel_size=2, stride=2, bias=True)
        self.dc8_768_256 = self.decoder(256 + 512, 256, kernel_size=3, stride=1, padding=1, bias=True)
        self.dc7_256_256 = self.decoder(256, 256, kernel_size=3, stride=1, padding=1, bias=True)
        self.dc6_256_256 = self.decoder(256, 256, kernel_size=2, stride=2, bias=True)
        self.dc5_384_128 = self.decoder(128 + 256, 128, kernel_size=3, stride=1, padding=1, bias=True)
        self.dc4_128_128 = self.decoder(128, 128, kernel_size=3, stride=1, padding=1, bias=True)
        self.dc3_128_128 = self.decoder(128, 128, kernel_size=2, stride=2, bias=True)
        self.dc2_192_64 = self.decoder(64 + 128, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.dc1_64_64 = self.decoder(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.dc0_64_nClasses = self.decoder(64, n_classes, kernel_size=1, stride=1, bias=True)

    def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                bias=True, batchnorm=False):
        if batchnorm:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm3d(out_channels),
                nn.ReLU())
        else:
            layer = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
                nn.ReLU())
        return layer

    def decoder(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                output_padding=0, bias=True):
        layer = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                               padding=padding, output_padding=output_padding, bias=bias),
            nn.ReLU())
        return layer

    def center_crop(self, layer, target_size):
        """
        :param layer: feature map from the encoder path (syn*)
        :param target_size: size() of the current decoder output (d*), a 5-tuple (N, C, D, H, W)
        :return: center-cropped layer whose spatial dims match target_size
        """
        batch_size, n_channels, layer_width, layer_height, layer_depth = layer.size()
        xy1 = (layer_width - target_size[2]) // 2
        xy2 = (layer_height - target_size[3]) // 2
        xy3 = (layer_depth - target_size[4]) //2
        return layer[:, :, xy1:(xy1 + target_size[2]), xy2:(xy2 + target_size[3]), xy3:(xy3 + target_size[4])]

    def forward(self, x):
        e0 = self.ec0_1_32(x)
        syn0 = self.ec1_32_64(e0)
        e1 = self.pool0(syn0)
        e2 = self.ec2_64_64(e1)
        syn1 = self.ec3_64_128(e2)
        del e0, e1, e2
        e3 = self.pool1(syn1)
        e4 = self.ec4_128_128(e3)
        syn2 = self.ec5_128_256(e4)
        del e3, e4
        e5 = self.pool2(syn2)
        e6 = self.ec6_256_256(e5)
        e7 = self.ec7_256_512(e6)
        del e5, e6
        d9_demo = self.dc9_512_512(e7)
        d9 = torch.cat((d9_demo, self.center_crop(syn2, d9_demo.size())), 1) #[16, 512, 10, 10, 10] , syn2
        del e7, syn2, d9_demo
        d8 = self.dc8_768_256(d9)
        d7 = self.dc7_256_256(d8)
        del d9, d8
        d6_demo = self.dc6_256_256(d7)
        d6 = torch.cat((d6_demo, self.center_crop(syn1, d6_demo.size())), 1)    
        del d7, syn1, d6_demo
        d5 = self.dc5_384_128(d6)
        d4 = self.dc4_128_128(d5)
        del d6, d5
        d3_demo = self.dc3_128_128(d4)
        d3 = torch.cat((d3_demo, self.center_crop(syn0, d3_demo.size())), 1)
        del d4, syn0, d3_demo
        d2 = self.dc2_192_64(d3)
        d1 = self.dc1_64_64(d2)
        del d3, d2
        d0 = self.dc0_64_nClasses(d1)
        out = F.softmax(d0, dim=1)
        # out = torch.sigmoid(d0)
        return out

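Because the skips are produced inside forward(), the workaround I am considering is a thin wrapper that reuses the encoder submodules of a UNet3D instance and returns the bottleneck together with the skip tensors. This is only a rough, untested sketch, and the name UNetEncoder3D is mine:

import torch.nn as nn

class UNetEncoder3D(nn.Module):
    """Shared encoder: reuses the ec*/pool* blocks of UNet3D and also returns the skips."""
    def __init__(self, unet):
        super().__init__()
        self.ec0, self.ec1 = unet.ec0_1_32, unet.ec1_32_64
        self.ec2, self.ec3 = unet.ec2_64_64, unet.ec3_64_128
        self.ec4, self.ec5 = unet.ec4_128_128, unet.ec5_128_256
        self.ec6, self.ec7 = unet.ec6_256_256, unet.ec7_256_512
        self.pool0, self.pool1, self.pool2 = unet.pool0, unet.pool1, unet.pool2

    def forward(self, x):
        syn0 = self.ec1(self.ec0(x))                  # skip 1, 64 channels
        syn1 = self.ec3(self.ec2(self.pool0(syn0)))   # skip 2, 128 channels
        syn2 = self.ec5(self.ec4(self.pool1(syn1)))   # skip 3, 256 channels
        e7 = self.ec7(self.ec6(self.pool2(syn2)))     # bottleneck, 512 channels
        return e7, (syn0, syn1, syn2)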
Is there an easier way to create/implement a single-encoder, multi-decoder module? My current plan for the decoder side is sketched below.
Problems with this architecture:
– computationally heavy, with a lot of trainable parameters.
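For reference, this is roughly how I picture wiring two task-specific decoder heads onto the shared encoder: a decoder module built from the same dc* blocks (so each head gets its own decoder weights), plus a small multi-task model that feeds both heads from one UNetEncoder3D. Again only an untested sketch, assuming the UNetEncoder3D wrapper above; all class names are mine, and the heads return raw scores so each task can apply its own softmax/sigmoid and loss:

import torch
import torch.nn as nn

class UNetDecoder3D(nn.Module):
    """One task-specific decoder head, mirroring the dc* path of UNet3D."""
    def __init__(self, unet):
        super().__init__()
        self.center_crop = unet.center_crop
        self.dc9, self.dc8, self.dc7 = unet.dc9_512_512, unet.dc8_768_256, unet.dc7_256_256
        self.dc6, self.dc5, self.dc4 = unet.dc6_256_256, unet.dc5_384_128, unet.dc4_128_128
        self.dc3, self.dc2, self.dc1 = unet.dc3_128_128, unet.dc2_192_64, unet.dc1_64_64
        self.dc0 = unet.dc0_64_nClasses

    def forward(self, bottleneck, skips):
        syn0, syn1, syn2 = skips
        d = self.dc9(bottleneck)
        d = self.dc7(self.dc8(torch.cat((d, self.center_crop(syn2, d.size())), 1)))
        d = self.dc6(d)
        d = self.dc4(self.dc5(torch.cat((d, self.center_crop(syn1, d.size())), 1)))
        d = self.dc3(d)
        d = self.dc1(self.dc2(torch.cat((d, self.center_crop(syn0, d.size())), 1)))
        return self.dc0(d)   # raw per-voxel scores for this task

class MultiTaskUNet3D(nn.Module):
    """One shared encoder feeding two independent decoder heads."""
    def __init__(self, in_channel=1, n_classes=2):
        super().__init__()
        # throwaway UNet3D instances are built only to reuse their layer definitions
        self.encoder = UNetEncoder3D(UNet3D(in_channel, n_classes))
        self.head_a = UNetDecoder3D(UNet3D(in_channel, n_classes))
        self.head_b = UNetDecoder3D(UNet3D(in_channel, n_classes))

    def forward(self, x):
        bottleneck, skips = self.encoder(x)
        return self.head_a(bottleneck, skips), self.head_b(bottleneck, skips)

Both heads would then be trained jointly, e.g. with a weighted sum of the two task losses back-propagating through the shared encoder.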