How to train the original U-Net model?

Hi guys!

I’m trying to implement and train the original U-Net model, but I’m stuck when trying to train it on the ISBI Challenge dataset.

According to the original U-Net paper, the network outputs an image with 2 channels and a size of 388 x 388. So my data loader for training generates a tensor of size [batch, channels=1, width=572, height=572] for the input images and [batch, channels=2, width=388, height=388] for the target/output images.

My problem is that when I try to use nn.CrossEntropyLoss(), the following error is raised:

RuntimeError: invalid argument 3: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4 at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/THNN/generic/SpatialClassNLLCriterion.c:59
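
For reference, here is a minimal snippet (random tensors with the same sizes as above, nothing U-Net specific) that triggers the same kind of shape complaint; the exact error message may differ between PyTorch versions:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
y_hat = torch.randn(2, 2, 388, 388)                # model output: [batch, classes, H, W]
y = torch.zeros(2, 2, 388, 388, dtype=torch.long)  # 4D target, like my current data loader produces
loss = criterion(y_hat, y)                         # fails: the target has one dimension too many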

I’m just starting with PyTorch (newbie here), so I’d really appreciate it if someone could help me overcome this problem.

The source code is available on GitHub: https://github.com/dalifreire/cnn_unet_pytorch
https://github.com/dalifreire/cnn_unet_pytorch/blob/master/unet_pytorch.ipynb

Best regards!

How many classes do you have?

Since you say it outputs two channels, I assume you’re doing binary segmentation.

According to https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss, your target should be just a 2D map (per sample) where each pixel has the value of the appropriate class, i.e. a binary mask of 0s and 1s. Hope this makes sense.
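
Something along these lines (made-up tensors, just to show the shapes nn.CrossEntropyLoss expects):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 2, 388, 388)             # network output: [batch, num_classes, H, W]
target = torch.randint(0, 2, (4, 388, 388))      # target: [batch, H, W], one class index per pixel
loss = criterion(logits, target)                 # works: the target is a 3D LongTensor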

Hi Simon,

First of all, thanks for your reply!
I’m trying to implement the same thing presented in the original paper… i.e. two classes for cell segmentation. So the masks contain only pixels with values ‘0’ and ‘1’.

Here is the point in the source code where I’m creating the masks.

Kind regards

Hi again

I’ve recently done something similar :)

I think for nn.CrossEntropyLoss to work,

your mask/target should have the shape:
torch.Size([batch_size, 388, 388])

while your output from the unet should have the shape:
torch.Size([batch_size, 2, 388, 388])

I’d guess a simple fix is to change this line:

# Crop the mask to the desired output size
mask = transforms.CenterCrop(img_output_size)(mask)

so your target doesn’t keep the channel dimension that the output has :)
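
I don’t know exactly how the mask tensor is built in your notebook, so treat this as a sketch (mask here is just a stand-in for whatever your transform returns):

import torch

mask = torch.zeros(2, 388, 388)            # hypothetical mask with an extra channel dimension

# If the two channels are a one-hot encoding, collapse them to class indices:
target = mask.argmax(dim=0).long()         # -> torch.Size([388, 388]), values 0/1

# If the mask is instead a single channel [1, 388, 388], just drop that axis:
# target = mask.squeeze(0).long()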

BR


Simon,

Thanks a lot!
I just removed the channel dimension from my masks and everything works well… now I’m generating masks with shape [width=388, height=388].

After that, I’m working with input images (X), target masks (y) and predicted output masks (y_hat) as follows:

X     --> torch.Size([10, 1, 572, 572])
y     --> torch.Size([10, 388, 388])
y_hat --> torch.Size([10, 2, 388, 388])
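
In case it helps anyone else, the loss call now looks roughly like this (random stand-in tensors with the shapes above; y_hat would of course come from the U-Net):

import torch
import torch.nn as nn

y = torch.randint(0, 2, (10, 388, 388))     # target masks: one class index per pixel
y_hat = torch.randn(10, 2, 388, 388)        # model output: one score per class per pixel

loss = nn.CrossEntropyLoss()(y_hat, y)      # loss over the whole batch
pred = y_hat.argmax(dim=1)                  # hard prediction: [10, 388, 388], same shape as y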

But I don’t understand why the target masks (y) and predicted masks (y_hat) must have different shapes. It seems so weird to me…

Here is the Jupyter notebook that I’m using in my tests.