Multiclass Segmentation

I’m unsure which error you are referring to, as this topic now contains several different issues.
Could you describe what kind of shape error you are seeing?


Sorry, I forgot to include the quote.
I was referring to the shape of the model output.

Hi @ptrblck,
I’m trying to adapt a binary segmentation U-Net model to train a multi-class U-Net on the German Asphalt Pavement Distress (GAPs) dataset. I understand from your reply above that I should use nn.CrossEntropyLoss instead of nn.BCELoss. I did so, but I’m getting the following error:
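For context, here is a minimal sketch of what typically changes when going from a binary to a multi-class segmentation head. The feature width (64), the class count (8), and the tensor sizes are hypothetical values chosen for illustration, not taken from the linked code. The binary head outputs one channel, applies a sigmoid, and compares against a float 0/1 mask with nn.BCELoss. The multi-class head outputs one channel per class, passes the raw logits to nn.CrossEntropyLoss (which applies log_softmax internally), and uses an integer mask of class indices with no channel dimension.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 64 feature channels leaving the U-Net decoder,
# NUM_CLASSES output classes, a batch of 2 images of 16 x 16 pixels.
FEATURES, NUM_CLASSES = 64, 8
N, H, W = 2, 16, 16
decoder_out = torch.randn(N, FEATURES, H, W)

# Binary setup: one output channel, sigmoid probabilities, float 0/1 mask.
binary_head = nn.Conv2d(FEATURES, 1, kernel_size=1)
binary_probs = torch.sigmoid(binary_head(decoder_out))     # [N, 1, H, W]
binary_mask = torch.randint(0, 2, (N, 1, H, W)).float()    # [N, 1, H, W]
binary_loss = nn.BCELoss()(binary_probs, binary_mask)

# Multi-class setup: one output channel per class, raw logits (no sigmoid
# or softmax), and an int64 mask holding one class index per pixel.
multi_head = nn.Conv2d(FEATURES, NUM_CLASSES, kernel_size=1)
multi_logits = multi_head(decoder_out)                      # [N, NUM_CLASSES, H, W]
multi_mask = torch.randint(0, NUM_CLASSES, (N, H, W))       # [N, H, W]
multi_loss = nn.CrossEntropyLoss()(multi_logits, multi_mask)

print(binary_loss.item(), multi_loss.item())
```

Note that the target loses its channel dimension in the multi-class case: each pixel stores a class index rather than a probability.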

Traceback (most recent call last):
  File "/content/drive/Othercomputers/My Laptop/crack_segmentation_khanhha/crack_segmentation-master/train_unet_GAPs.py", line 263, in <module>
    train(train_loader, model, criterion, optimizer, validate, args)
  File "/content/drive/Othercomputers/My Laptop/crack_segmentation_khanhha/crack_segmentation-master/train_unet_GAPs.py", line 124, in train
    loss = criterion(masks_probs_flat, true_masks_flat)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py", line 1165, in forward
    label_smoothing=self.label_smoothing)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 2996, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: size mismatch (got input: [6422528], target: [802816])
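The error can be reproduced with small, hypothetical shapes. This sketch assumes the model output has shape [N, C, H, W] and the mask has shape [N, H, W] before both are flattened; the flattened output then has C times as many elements as the flattened target, so nn.CrossEntropyLoss cannot pair them up.

```python
import torch
import torch.nn as nn

# Hypothetical small shapes standing in for the real batch:
# N images, C classes, H x W pixels.
N, C, H, W = 2, 8, 4, 4

logits = torch.randn(N, C, H, W)          # model output: one score per class per pixel
masks = torch.randint(0, C, (N, H, W))    # target: one class index per pixel

criterion = nn.CrossEntropyLoss()

# Flattening both tensors discards the class dimension of the output.
logits_flat = logits.view(-1)   # N*C*H*W elements
masks_flat = masks.view(-1)     # N*H*W elements

try:
    criterion(logits_flat, masks_flat)
    raised = False
except RuntimeError as err:
    raised = True
    print(err)  # prints the "size mismatch" message

# The element counts differ by exactly the number of classes.
ratio = logits_flat.numel() // masks_flat.numel()
print(ratio)  # 8
```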

The code files, and the dataset are available through the following link:
https://drive.google.com/drive/folders/14NQdtMXokIixBJ5XizexVECn23Jh9aTM?usp=sharing

The following link is for my previous Stack Overflow question (from before I changed the criterion to nn.CrossEntropyLoss). I’m totally new to PyTorch, and I look forward to receiving your valuable advice.

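Check the shape requirements for multi-class segmentation with nn.CrossEntropyLoss in this post.
Based on the error message, it seems you are flattening both the model output and the targets, which creates the shape mismatch: the flattened input has 6422528 elements, exactly 8 times the 802816 elements of the target, as would be expected if the model outputs 8 class logits per pixel. nn.CrossEntropyLoss expects the model output in the shape [batch_size, nb_classes, height, width] and the target in the shape [batch_size, height, width] containing class indices, so pass both tensors without flattening them.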
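A minimal sketch of the corrected loss call, using hypothetical shapes. It assumes the masks have already been converted to integer class indices in [0, C-1]; how that mapping is done for the GAPs masks is not covered here. If a flattened form is really needed, the class dimension has to be moved last first, so each row holds the C logits of one pixel.

```python
import torch
import torch.nn as nn

N, C, H, W = 2, 8, 4, 4  # hypothetical batch size, class count, and spatial size

logits = torch.randn(N, C, H, W)                           # raw model output, no softmax/sigmoid
masks = torch.randint(0, C, (N, H, W), dtype=torch.long)   # class index per pixel

criterion = nn.CrossEntropyLoss()

# Preferred: pass the 4D output and 3D target directly, without flattening.
loss = criterion(logits, masks)

# Equivalent flattened form: permute to [N, H, W, C], then reshape so the
# input is [N*H*W, C] and the target is [N*H*W].
logits_2d = logits.permute(0, 2, 3, 1).reshape(-1, C)
masks_1d = masks.reshape(-1)
loss_flat = criterion(logits_2d, masks_1d)

# Both average the per-pixel loss over all N*H*W pixels, so the values
# agree up to floating-point rounding.
print(loss.item(), loss_flat.item())
```

Passing the unflattened tensors is usually simpler and avoids accidentally mixing up the class and spatial dimensions.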