UNet training in PyTorch

Hi guys, I’m trying to train a UNet on breast images. In particular, I have 3 tensors:

  • an input tensor of shape [32, 1, 64, 64]
  • a labels tensor of shape [32]
  • a maps tensor of shape [32, 1, 64, 64]

The code I used is the following…
import torch
import torch.nn as nn
import torch.nn.functional as F


class UNet(nn.Module):
    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block


    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2,
                                     padding=1, output_padding=1)
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block

    def __init__(self, in_channel=1, out_channel=2):
        super(UNet, self).__init__()
        # Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=1)
        # Bottleneck
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=2, in_channels=256, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.Conv2d(kernel_size=2, in_channels=512, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2, padding=1,
                                     output_padding=1)
        )
        # Decode
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        if crop:
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
            bypass = F.pad(bypass, (-c, -c, -c, -c))
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encode
        #x = x.view(x.size(0), -1)
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode
        #print(x.shape, encode_block1.shape, encode_block2.shape, encode_block3.shape, bottleneck1.shape)
        #print('Decode Block 3')
        #print(bottleneck1.shape, encode_block3.shape)
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        #print(decode_block3.shape)
        #print('Decode Block 2')
        cat_layer2 = self.conv_decode3(decode_block3)
        #print(cat_layer2.shape, encode_block2.shape)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        #print(cat_layer1.shape, encode_block1.shape)
        #print('Final Layer')
        #print(cat_layer1.shape, encode_block1.shape)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        #print(decode_block1.shape)
        final_layer = self.final_layer(decode_block1)
        #print(final_layer.shape)
        return final_layer
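
To sanity-check the forward pass before training, I run the net on a random batch shaped like my data (a minimal sketch, on CPU):

net = UNet(in_channel=1, out_channel=2)
with torch.no_grad():
    dummy = torch.randn(32, 1, 64, 64)  # same shape as my input tensor
    print(net(dummy).shape)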


And here is the part I wrote to call the net…

        for i, data in enumerate(dataloader, 0):

            total_time_data_load += time.time() - t0_data_load

            # get the inputs
            t0_other = time.time()
            inputs, labels, maps = data

            print("...Inputs has shape:", inputs.shape)
            print("...Labels shape:", labels.shape)
            print("...Maps shape:", maps.shape)
         
            # send to GPU
            #inputs, labels = inputs.to(DEVICE, non_blocking=True), labels.to(DEVICE, non_blocking=True)
            inputs, labels, maps = inputs.to(DEVICE, non_blocking=True), labels.to(DEVICE, non_blocking=True), maps.to(DEVICE, non_blocking=True)
            # update data statistics
            if ARGS.data_stats:
                inputs_sum += inputs.sum().detach().cpu()
                inputs_sum_sq += inputs.pow(2).sum().detach().cpu()
                inputs_min = min(inputs_min, inputs.min().detach().cpu())
                inputs_max = max(inputs_max, inputs.max().detach().cpu())

            # zero the parameter gradients
            optimizer.zero_grad()
            total_time_other += time.time() - t0_other

            # forward
            t0_forward = time.time()
            outputs = net(inputs)
            total_time_forward += time.time() - t0_forward

            # backward
            t0_backward = time.time()
            
            print("The selected loss is:", criterion)
            print("new outputs is:", outputs.shape)
            loss = criterion(outputs, labels, maps)



The question in my case is: can I pass 3 elements to the criterion? If I try to call the criterion with 3 elements, I get the error: TypeError: forward() takes 3 positional arguments but 4 were given
The loss function used is the CrossEntropyLoss from the PyTorch docs.
The batch_size is 32
width = 64
height = 64
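
For context, the forward() of nn.CrossEntropyLoss only takes (input, target), which is why the third positional argument raises the TypeError. If the maps are supposed to weight the loss per pixel, I imagine I’d need a custom criterion, something like this sketch (the weighting by maps is purely my assumption, and the labels would then have to be per-pixel targets of shape [32, 64, 64] rather than [32]):

import torch.nn as nn
import torch.nn.functional as F

class MapWeightedCrossEntropy(nn.Module):
    # hypothetical 3-argument criterion: per-pixel cross entropy weighted by maps
    def forward(self, outputs, labels, maps):
        # outputs: [N, 2, H, W]; labels: [N, H, W] (dtype long); maps: [N, 1, H, W]
        per_pixel = F.cross_entropy(outputs, labels, reduction='none')  # -> [N, H, W]
        return (per_pixel * maps.squeeze(1)).mean()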

Can anyone help me fix this and start the training? I’ve been stuck on this problem for 2 weeks… and I’m new to this context.
Thanks a lot.

I added some modifications… In particular:

class UNet(nn.Module):
    def contracting_block(self, in_channels, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2,
                                     padding=1, output_padding=1)
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(mid_channel),
            torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(out_channels),
        )
        return block

    def __init__(self, in_channel=1, out_channel=2):
        super(UNet, self).__init__()
        # Encode
        self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(64, 128)
        self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(128, 256)
        self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=1, ceil_mode=True)
        # Bottleneck
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(512),
            torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1,
                                     output_padding=1)
        )
        # Decode
        self.conv_decode3 = self.expansive_block(512, 256, 128)
        self.conv_decode2 = self.expansive_block(256, 128, 64)
        self.final_layer = self.final_block(128, 64, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        if crop:
            #print(bypass.shape)
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
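            # note: c can be negative when upsampled is larger than bypass;
            # then -c is positive and F.pad below pads bypass instead of cropping it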
            bypass = F.pad(bypass, (-c, -c, -c, -c))
        print("CROP",upsampled.shape, bypass.shape)
        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encode
        #x = x.view(x.size(0), -1)
        encode_block1 = self.conv_encode1(x)
        print("econde block1", encode_block1.shape)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        print("econde pool1", encode_pool1.shape)
        encode_block2 = self.conv_encode2(encode_pool1)
        print("econde block2", encode_block2.shape)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        print("econde pool2", encode_pool2.shape)
        encode_block3 = self.conv_encode3(encode_pool2)
        print("econde block3", encode_block3.shape)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        print("econde pool3", encode_pool3.shape)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        print("Bottleneck1", bottleneck1.shape)
        # Decode
        print('Decode Block 3')
        print(bottleneck1.shape, encode_block3.shape)
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
        print("Decoded block3", decode_block3.shape)
        print('Decode Block 2')
        cat_layer2 = self.conv_decode3(decode_block3)
        print(cat_layer2.shape, encode_block2.shape)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
        cat_layer1 = self.conv_decode2(decode_block2)
        print(cat_layer1.shape, encode_block1.shape)
        print('Final Layer')
        print(cat_layer1.shape, encode_block1.shape)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
        print(decode_block1.shape)
        final_layer = self.final_layer(decode_block1)
        print(final_layer.shape)
        return final_layer

But when I run the code, I get these shapes followed by this error…

…Inputs has shape: torch.Size([32, 1, 64, 64])
…Labels shape: torch.Size([32])
…Maps shape: torch.Size([32, 1, 64, 64])
width = 64
height = 64
n_chans = 1
Batch_size= 32
encode block1 torch.Size([32, 64, 60, 60])
encode pool1 torch.Size([32, 64, 30, 30])
encode block2 torch.Size([32, 128, 26, 26])
encode pool2 torch.Size([32, 128, 13, 13])
encode block3 torch.Size([32, 256, 9, 9])
encode pool3 torch.Size([32, 256, 9, 9])
Bottleneck1 torch.Size([32, 256, 10, 10])
Decode Block 3
torch.Size([32, 256, 10, 10]) torch.Size([32, 256, 9, 9])
CROP torch.Size([32, 256, 10, 10]) torch.Size([32, 256, 11, 11])
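
If I trace the arithmetic (as far as I understand it), these shapes make sense: each unpadded 3×3 conv shrinks each spatial dimension by 2 (64 → 62 → 60), each MaxPool2d(kernel_size=2) halves it, so the encoder goes 60 → 30 → 26 → 13 → 9. The bottleneck’s two unpadded 3×3 convs give 9 → 7 → 5, and its ConvTranspose2d (kernel=3, stride=2, padding=1, output_padding=1) maps 5 → (5−1)·2 − 2·1 + 3 + 1 = 10. Then in crop_and_concat, c = (9 − 10) // 2 = −1, so F.pad(bypass, (1, 1, 1, 1)) pads the 9×9 skip tensor up to 11×11 instead of cropping it, and torch.cat fails on 10 vs 11.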

Traceback (most recent call last):
  File "/Users/.../work.py", line 627, in <module>
    main()
  File "/Users/.../work.py", line 624, in main
    crossvalid()
  File "/Users/.../work.py", line 583, in crossvalid
    train(cross_valid_folder, i)
  File "/Users/.../work.py", line 324, in train
    outputs = net(inputs)
  File "/Users/.../lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/.../mynets.py", line 571, in forward
    decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
  File "/Users/.../mynets.py", line 548, in crop_and_concat
    return torch.cat((upsampled, bypass), 1)
RuntimeError: Sizes of tensors must match except in dimension 1. Got 10 and 11 in dimension 2 (The offending index is 1)

The error seems to come from the mismatched tensor sizes printed inside the crop_and_concat function…
Can anyone help me?
Thanks a lot.