Concatenate feature maps of different size from layers (P-Net)

Hi,

I am currently implementing the 3D variant of the P-Net from the paper "DeepIGeoS: A Deep Interactive Geodesic Framework for Medical Image Segmentation" (arXiv:1707.00652).

The architecture looks like the following: [image of the P-Net architecture from the paper, not reproduced here]

At the end of each block, the feature maps are passed through a 1x1x1 conv to compress them. The compressed feature maps of all blocks are then concatenated along the channel dimension before being fed to the classifier block.

However, it turns out that the compressed feature maps all have different spatial sizes and therefore cannot be concatenated along the channel dimension. The paper does not describe in detail how the feature maps are concatenated.

Is there a trick I am missing?
Right now I work around the problem by upsampling all feature maps to the spatial shape of the original input, but I don’t know if that is the correct solution. The paper does not mention upsampling either.

Thanks in advance!

It’s not completely clear, but I guess the authors make sure to keep the spatial size equal as described in 3.2:

[…]
To obtain an exponential increase of the receptive field, VGG-16 uses a max-pooling and downsampling layer after each block. However, this implementation would decrease the resolution of feature maps exponentially. Therefore, to preserve resolution through the network, we remove the max-pooling and downsampling layers and use dilated convolution in each block.
[…]

Thanks for the answer! However, I am already using dilated convolutions, and the feature maps still get smaller with each block.

Should dilated convolutions normally prevent this?
Did I implement something wrong?

Here is my implementation: (Note: This is currently still with the upsampling at the end)

class P_Net(nn.Module):
    def __init__(self, in_channels=2, out_channels=16):  # or out_channels = 64
        super(P_Net, self).__init__()

        self.block1 = nn.Sequential(
          nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1]),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1]), # or kernel_size=[3, 3, 3]
          nn.ReLU(),
        )
        self.block2 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[2, 2, 2]),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[2, 2, 2]), # or kernel_size=[3, 3, 3]
          nn.ReLU(),
        )
        self.block3 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[3, 3, 3]), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[3, 3, 3]),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[3, 3, 3]),
          nn.ReLU(),
        )
        self.block4 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[4, 4, 4]), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[4, 4, 4]),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[4, 4, 4]),
          nn.ReLU(),
        )
        self.block5 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[5, 5, 5]), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[5, 5, 5]),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=[3, 3, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[5, 5, 5]),
          nn.ReLU(),
        )
        self.block6 = nn.Sequential(
          nn.Conv3d(in_channels=int(out_channels/4)*5, out_channels=out_channels, kernel_size=[1, 1, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1]), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=2, kernel_size=[3, 3, 3], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1]),
          nn.ReLU(),
        )

        self.compress1 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=[1, 1, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1]),
          nn.ReLU(),
        )
        self.compress2 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=[1, 1, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1]),
          nn.ReLU(),
        )
        self.compress3 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=[1, 1, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1]),
          nn.ReLU(),
        )
        self.compress4 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=[1, 1, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1]),
          nn.ReLU(),
        )
        self.compress5 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=[1, 1, 1], stride=[1, 1, 1], padding=[0, 0, 0], dilation=[1, 1, 1]),
          nn.ReLU(),
        )

        self.upsample1 = nn.Upsample(size=[96, 160, 160], mode='trilinear', align_corners=False)
        self.upsample2 = nn.Upsample(size=[96, 160, 160], mode='trilinear', align_corners=False)
        self.upsample3 = nn.Upsample(size=[96, 160, 160], mode='trilinear', align_corners=False)
        self.upsample4 = nn.Upsample(size=[96, 160, 160], mode='trilinear', align_corners=False)
        self.upsample5 = nn.Upsample(size=[96, 160, 160], mode='trilinear', align_corners=False)
        self.upsample6 = nn.Upsample(size=[96, 160, 160], mode='trilinear', align_corners=False)

    def forward(self, x):
        x = self.block1(x)
        compress1 = self.compress1(x)
        x = self.block2(x)
        compress2 = self.compress2(x)
        x = self.block3(x)
        compress3 = self.compress3(x)
        x = self.block4(x)
        compress4 = self.compress4(x)
        x = self.block5(x)
        compress5 = self.compress5(x)
        compress1 = self.upsample1(compress1)
        compress2 = self.upsample2(compress2)
        compress3 = self.upsample3(compress3)
        compress4 = self.upsample4(compress4)
        compress5 = self.upsample5(compress5)
        x = torch.cat((compress1, compress2, compress3, compress4, compress5), dim=1)
        x = self.block6(x)
        x = self.upsample6(x)
        return x

To be sure, you could contact the authors and in the meantime increase the padding etc. to keep the spatial shape equal.
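
The padding needed to keep the spatial shape equal follows from the output-size formula given in the nn.Conv3d documentation. The helper below is only a sketch of that arithmetic (the function name conv_out_size is made up for illustration); it suggests that for a 3-wide kernel with stride 1, setting padding equal to the dilation leaves the size unchanged, while padding=0 removes 2 * dilation voxels per axis and per conv.

```python
def conv_out_size(size, kernel_size, dilation=1, padding=0, stride=1):
    """Output length along one axis, following the formula in the nn.Conv3d docs."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Compare an unpadded dilated conv with one where padding == dilation.
for d in (1, 2, 3, 4, 5):
    unpadded = conv_out_size(160, kernel_size=3, dilation=d, padding=0)
    padded = conv_out_size(160, kernel_size=3, dilation=d, padding=d)
    print(f"dilation={d}: padding=0 -> {unpadded}, padding={d} -> {padded}")
```

Dilation on its own only enlarges the receptive field; it is the matching padding that preserves the resolution.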

I contacted the author but sadly did not receive an answer. Currently I have the problem that the model is not learning: when given random input, the prediction and the loss are always the same. I implemented the P-Net once with upsampling and once with padding, and neither version learns.
Any idea why it does not learn?

Version with upsampling:

class P_Net(nn.Module):
    def __init__(self, in_channels=2, out_channels=16, deep_supervision=False):  # or out_channels = 16/64
        super(P_Net, self).__init__()

        self.do_ds = False

        self.block1 = nn.Sequential(
          nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=1),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=1), # or kernel_size=[3, 3, 3]
          nn.ReLU(),
        )
        self.block2 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=2),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=2), # or kernel_size=[3, 3, 3]
          nn.ReLU(),
        )
        self.block3 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=3), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=3),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=3),
          nn.ReLU(),
        )
        self.block4 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=4), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=4),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=4),
          nn.ReLU(),
        )
        self.block5 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=5), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=5),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=5),
          nn.ReLU(),
        )
        self.block6 = nn.Sequential(
          nn.Conv3d(in_channels=int(out_channels/4)*5, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=1), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=2, kernel_size=3, stride=1, padding=0, dilation=1),
          # nn.ReLU(),
        )

        self.compress1 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress2 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress3 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress4 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress5 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )

        self.upsample1 = nn.Upsample(size=[96, 160, 160], mode='trilinear')
        self.upsample2 = nn.Upsample(size=[96, 160, 160], mode='trilinear')
        self.upsample3 = nn.Upsample(size=[96, 160, 160], mode='trilinear')
        self.upsample4 = nn.Upsample(size=[96, 160, 160], mode='trilinear')
        self.upsample5 = nn.Upsample(size=[96, 160, 160], mode='trilinear')
        self.upsample6 = nn.Upsample(size=[96, 160, 160], mode='trilinear')

    def forward(self, x):
        x = self.block1(x)
        compress1 = self.compress1(x)
        x = self.block2(x)
        compress2 = self.compress2(x)
        x = self.block3(x)
        compress3 = self.compress3(x)
        x = self.block4(x)
        compress4 = self.compress4(x)
        x = self.block5(x)
        compress5 = self.compress5(x)
        compress1 = self.upsample1(compress1)
        compress2 = self.upsample2(compress2)
        compress3 = self.upsample3(compress3)
        compress4 = self.upsample4(compress4)
        compress5 = self.upsample5(compress5)
        x = torch.cat((compress1, compress2, compress3, compress4, compress5), dim=1)
        x = self.block6(x)
        x = self.upsample6(x)
        return x

Version with padding:

class P_Net(nn.Module):
    def __init__(self, in_channels=2, out_channels=16, deep_supervision=False):  # or out_channels = 16/64
        super(P_Net, self).__init__()

        self.do_ds = False

        self.block1 = nn.Sequential(
          nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=1),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=1), # or kernel_size=[3, 3, 3]
          nn.ReLU(),
        )
        self.block2 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=2),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=2), # or kernel_size=[3, 3, 3]
          nn.ReLU(),
        )
        self.block3 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=3), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=3),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=3),
          nn.ReLU(),
        )
        self.block4 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=4), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=4),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=4),
          nn.ReLU(),
        )
        self.block5 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=5), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=5),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=5),
          nn.ReLU(),
        )
        self.block6 = nn.Sequential(
          nn.Conv3d(in_channels=int(out_channels/4)*5, out_channels=out_channels, kernel_size=3, stride=1, padding=0, dilation=1), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=2, kernel_size=3, stride=1, padding=0, dilation=1),
          # nn.ReLU(),
        )

        self.compress1 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress2 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress3 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress4 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress5 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )

        self.pad1 = nn.ReplicationPad3d((2, 2, 2, 2, 2, 2))
        self.pad2 = nn.ReplicationPad3d((4, 4, 4, 4, 4, 4))
        self.pad3 = nn.ReplicationPad3d((9, 9, 9, 9, 9, 9))
        self.pad4 = nn.ReplicationPad3d((12, 12, 12, 12, 12, 12))
        self.pad5 = nn.ReplicationPad3d((15, 15, 15, 15, 15, 15))
        self.pad6 = nn.ReplicationPad3d((2, 2, 2, 2, 2, 2))

    def forward(self, x):
        x = self.block1(x)
        x = self.pad1(x)
        compress1 = self.compress1(x)
        x = self.block2(x)
        x = self.pad2(x)
        compress2 = self.compress2(x)
        x = self.block3(x)
        x = self.pad3(x)
        compress3 = self.compress3(x)
        x = self.block4(x)
        x = self.pad4(x)
        compress4 = self.compress4(x)
        x = self.block5(x)
        x = self.pad5(x)
        compress5 = self.compress5(x)
        x = torch.cat((compress1, compress2, compress3, compress4, compress5), dim=1)
        x = self.block6(x)
        x = self.pad6(x)
        return x

For training I use the following test code:

import torch
import torch.nn as nn

model = P_Net()
model = model.to("cuda:5")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.BCEWithLogitsLoss()
model.train()

while True:
    # Fresh random input and random target on every iteration
    input = torch.rand((1, 2, 96, 160, 160)).to("cuda:5")
    label = torch.rand((1, 2, 96, 160, 160)).to("cuda:5")
    prediction = model(input)
    loss = criterion(prediction, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # apply the update; without this call the weights never change

Upsampling or padding the activations only before concatenation seems to be a bit strange.
Could you add the padding to the conv layers directly and/or pass the upsampled activations to the next layer?
At the moment all layers still see the “smaller” activation, while the padding/upsampling is only performed in the path leading to the concatenation.
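
A minimal sketch of what adding the padding to the conv layers directly could look like. The helper name dilated_block and the small test volume are assumptions made for illustration, not code from the paper or from this thread. With padding equal to the dilation, every layer in the main path sees the full-size activation.

```python
import torch
import torch.nn as nn


def dilated_block(in_channels, out_channels, dilation, n_convs):
    """Stack of 3x3x3 dilated convs; padding=dilation keeps D, H, W unchanged."""
    layers = []
    for i in range(n_convs):
        layers.append(nn.Conv3d(in_channels if i == 0 else out_channels,
                                out_channels, kernel_size=3,
                                padding=dilation, dilation=dilation))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)


block = dilated_block(2, 16, dilation=3, n_convs=3)
x = torch.rand(1, 2, 12, 16, 16)  # small volume so the check runs quickly on CPU
y = block(x)
print(y.shape)  # spatial dims should match the input: 12 x 16 x 16
```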

I added the padding directly to the layers, but the problem remains the same.

class P_Net(nn.Module):
    def __init__(self, in_channels=2, out_channels=16):  # or out_channels = 16/64
        super(P_Net, self).__init__()

        self.block1 = nn.Sequential(
          nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, dilation=1),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, dilation=1), # or kernel_size=[3, 3, 3]
          nn.ReLU(),
        )
        self.block2 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=2, dilation=2),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=2, dilation=2), # or kernel_size=[3, 3, 3]
          nn.ReLU(),
        )
        self.block3 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=3, dilation=3), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=3, dilation=3),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=3, dilation=3),
          nn.ReLU(),
        )
        self.block4 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=4, dilation=4), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=4, dilation=4),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=4, dilation=4),
          nn.ReLU(),
        )
        self.block5 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=5, dilation=5), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=5, dilation=5),
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=5, dilation=5),
          nn.ReLU(),
        )
        self.block6 = nn.Sequential(
          nn.Conv3d(in_channels=int(out_channels/4)*5, out_channels=out_channels, kernel_size=3, stride=1, padding=1, dilation=1), # or kernel_size=[3, 3, 1]
          nn.ReLU(),
          nn.Conv3d(in_channels=out_channels, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1),
          # nn.ReLU(),
        )

        self.compress1 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress2 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress3 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress4 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )
        self.compress5 = nn.Sequential(
          nn.Conv3d(in_channels=out_channels, out_channels=int(out_channels/4), kernel_size=1, stride=1, padding=0, dilation=1),
          nn.ReLU(),
        )

    def forward(self, x):
        x = self.block1(x)
        compress1 = self.compress1(x)
        x = self.block2(x)
        compress2 = self.compress2(x)
        x = self.block3(x)
        compress3 = self.compress3(x)
        x = self.block4(x)
        compress4 = self.compress4(x)
        x = self.block5(x)
        compress5 = self.compress5(x)
        x = torch.cat((compress1, compress2, compress3, compress4, compress5), dim=1)
        x = self.block6(x)
        return x

You could try to overfit a tiny dataset by playing around with hyperparameters. Once your model is able to overfit this dataset, you could then try to scale up the use case again.
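
A rough sketch of such an overfitting check, under a few assumptions: the tiny two-layer model is only a stand-in for P_Net, Adam with lr=1e-2 is an arbitrary choice, and the "dataset" is a single fixed batch with binary targets derived from the input. The batch is created once outside the loop, and optimizer.step() is called after backward(), so the loss should drop if the training code is wired correctly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in model (hypothetical); swap in P_Net() to test the real network.
model = nn.Sequential(
    nn.Conv3d(2, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv3d(8, 2, kernel_size=3, padding=1),  # raw logits, no final ReLU
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
criterion = nn.BCEWithLogitsLoss()

# One fixed batch, created once, with binary targets the model can fit.
inputs = torch.rand(1, 2, 8, 16, 16)
targets = (inputs > 0.5).float()

losses = []
for step in range(100):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()  # without this call the weights are never updated
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

If the loss stays flat even on a fixed batch like this, the issue is more likely in the training loop or the model wiring than in the hyperparameters.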

It is learning now, thank you. However, it seems to be extremely slow. I compared the P-Net with only 93878 parameters against a normal U-Net with 30786048 parameters (same data, same hardware). Even though the P-Net has so few parameters, it trains 3x slower than the U-Net.
Why could this be the case?
Could it be because in my forward pass the compression is computed first and only then is x forwarded to the next block? I am currently looking into torch.cuda.Stream(). Can this be used to speed it up on the same GPU?
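
One possible factor, offered as an assumption rather than something confirmed in this thread: the parameter count says little about the compute cost, because every conv in the P-Net runs at full resolution. The sketch below (the helper names and the 256-channel comparison layer are hypothetical) counts multiply-accumulates for a single Conv3d as output voxels times in_channels times out_channels times kernel volume.

```python
def conv3d_macs(in_channels, out_channels, kernel_size, out_voxels):
    """Multiply-accumulates for one Conv3d layer (bias ignored)."""
    return out_voxels * in_channels * out_channels * kernel_size ** 3


def conv3d_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of one Conv3d layer."""
    return in_channels * out_channels * kernel_size ** 3 + out_channels


voxels = 96 * 160 * 160  # full-resolution volume used in this thread

# One 16 -> 16 channel 3x3x3 conv applied at full resolution.
small = (conv3d_params(16, 16, 3), conv3d_macs(16, 16, 3, voxels))

# A hypothetical 256 -> 256 conv deep in an encoder, after the volume
# has been downsampled by a factor of 16 along each axis.
large = (conv3d_params(256, 256, 3), conv3d_macs(256, 256, 3, voxels // 16 ** 3))

print("params, MACs (16 ch, full res):  ", small)
print("params, MACs (256 ch, 1/16 res): ", large)
```

In this comparison the 16-channel layer has far fewer parameters yet needs more multiply-accumulates, since it touches every voxel of the full volume. Profiling the actual models would be needed to confirm whether this explains the 3x difference.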