Argmax for semantic segmentation

If your target currently has the shape [batch_size, c, h, w] (one channel per class), try converting it to class indices of shape [batch_size, h, w] using:

target = torch.argmax(target, dim=1)
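
For illustration, here is a minimal sketch of how the conversion might fit into a loss computation. It assumes the target is one-hot encoded along dim 1 and that the loss is cross-entropy; neither is stated in the question. The tensor sizes and the names logits, class_indices and one_hot_target are made up for the example.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for this example
batch_size, num_classes, h, w = 2, 3, 4, 4

# Stand-in for the model output: raw logits, one channel per class
logits = torch.randn(batch_size, num_classes, h, w)

# Build a one-hot target of shape [batch_size, c, h, w] (assumed input format)
class_indices = torch.randint(0, num_classes, (batch_size, h, w))
one_hot_target = F.one_hot(class_indices, num_classes).permute(0, 3, 1, 2).float()

# Collapse the class dimension into integer class indices
target = torch.argmax(one_hot_target, dim=1)  # shape [batch_size, h, w], dtype int64

# Cross-entropy takes logits of shape [N, C, H, W] and class indices of shape [N, H, W]
loss = F.cross_entropy(logits, target)
```

If the target is not one-hot (for example, soft labels or probabilities), argmax will still run but discards the per-class weights, so check that this is what you want.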

See also the related thread: Semantic segmentation loss function / shape of prediction and target