If your current target is one-hot encoded with shape [batch_size, c, h, w], convert it to class indices of shape [batch_size, h, w] using:
target = torch.argmax(target, 1)
See this thread for more details: Semantic segmentation loss function / shape of prediction and target
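Here is a minimal sketch of the conversion end to end. It assumes the loss is `nn.CrossEntropyLoss`, which takes logits of shape [batch_size, c, h, w] and a class-index target of shape [batch_size, h, w] with dtype `torch.long`. The sizes and the randomly generated one-hot target are hypothetical and only there for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
batch_size, num_classes, h, w = 2, 3, 4, 4

# Model output (logits): [batch_size, c, h, w]
logits = torch.randn(batch_size, num_classes, h, w)

# Build a one-hot target with the same shape as the logits.
class_idx = torch.randint(0, num_classes, (batch_size, h, w))   # [batch_size, h, w]
one_hot_target = nn.functional.one_hot(class_idx, num_classes)   # [batch_size, h, w, c]
one_hot_target = one_hot_target.permute(0, 3, 1, 2).float()      # [batch_size, c, h, w]

# Collapse the channel dimension back to class indices: [batch_size, h, w].
# torch.argmax returns int64 (torch.long), which is what CrossEntropyLoss expects.
target = torch.argmax(one_hot_target, 1)

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)  # scalar loss averaged over all pixels
print(target.shape, target.dtype, loss.item())
```

Because each pixel in a one-hot target has exactly one 1 along dim 1, `argmax` recovers the original class index for every pixel.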