RuntimeError when training with FocalLoss from smp.losses

Hi all,
I am new to semantic segmentation models.
I used the smp (segmentation_models_pytorch) library for semantic segmentation and built the model and loss as follows:

```python
import segmentation_models_pytorch as smp

model = smp.Unet(encoder_name='resnet34',
                 encoder_depth=5,
                 encoder_weights='imagenet',
                 decoder_use_batchnorm=True,
                 decoder_channels=(256, 128, 64, 32, 16),
                 decoder_attention_type=None,
                 in_channels=3,
                 classes=5,
                 activation=None,
                 aux_params=None)

loss = smp.losses.FocalLoss(mode='multiclass', gamma=2.0)
loss.__name__ = 'FocalLoss'  # name used for the loss entry in the training logs
```
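Since training runs once the metrics list is removed, the loss itself accepts these shapes: raw logits of shape N×C×H×W (the model uses `activation=None`) and an integer class-index mask of shape N×H×W. To illustrate the shapes involved, here is the standard softmax-based multiclass focal loss written in plain PyTorch. `multiclass_focal_loss` is a hypothetical helper, and it is not claimed to match how smp implements `FocalLoss` internally.

```python
import torch
import torch.nn.functional as F


def multiclass_focal_loss(logits: torch.Tensor, target: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Softmax focal loss (hypothetical helper, not the smp implementation).

    logits: (N, C, H, W) raw, unnormalized scores (model built with activation=None)
    target: (N, H, W) int64 class indices in [0, C)
    """
    log_prob = F.log_softmax(logits, dim=1)                      # (N, C, H, W)
    # Pick the log-probability of the true class at every pixel.
    log_pt = log_prob.gather(1, target.unsqueeze(1)).squeeze(1)  # (N, H, W)
    pt = log_pt.exp()
    # Down-weight well-classified pixels by the factor (1 - p_t)^gamma.
    return (-((1.0 - pt) ** gamma) * log_pt).mean()


logits = torch.randn(8, 5, 32, 32)         # same layout as the Unet output, smaller H and W
target = torch.randint(0, 5, (8, 32, 32))  # class-index mask, no channel dimension
print(multiclass_focal_loss(logits, target).item())
```

With `gamma=0` the expression reduces to ordinary cross-entropy, which makes a handy sanity check.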

The target mask has size 8x512x512 (each pixel holds the index of its class), and the image batch has size 8x3x512x512.

After running the following code to train the model:

```python
train_epoch = smp.utils.train.TrainEpoch(
    model=model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=device,
    verbose=True,
)
train_logs = train_epoch.run(traindataloader)
```

I got this error:
```
…/ Deeplearning\lib\site-packages\segmentation_models_pytorch\utils\functional.py", line 34, in iou
    intersection = torch.sum(gt * pr)
RuntimeError: The size of tensor a (8) must match the size of tensor b (5) at non-singleton dimension 1
```

Why do the shapes of gt and pr not match?
How can I overcome this error?
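The message comes from broadcasting: PyTorch aligns shapes starting from the last dimension, so the 3-D `gt` (8×512×512) is treated as 1×8×512×512 against the 4-D `pr` (8×5×512×512), and dimension 1 then compares 8 with 5. A small reproduction with reduced, arbitrarily chosen sizes (batch 2, 3 classes, 4×4 pixels):

```python
import torch

gt = torch.randint(0, 3, (2, 4, 4)).float()  # (N, H, W) class-index mask
pr = torch.rand(2, 3, 4, 4)                  # (N, C, H, W) per-class scores

error_message = None
try:
    # gt is broadcast as (1, 2, 4, 4) against (2, 3, 4, 4): dim 1 is 2 vs 3
    torch.sum(gt * pr)
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```

If the batch size happened to equal the number of classes (e.g. 5 and 5), the same product would broadcast silently and return a meaningless number instead of raising, so it is worth fixing the shapes rather than working around the error.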

When I remove the metrics list, the error goes away, but then I can't see any metrics during training.
The error comes from: `intersection = torch.sum(gt * pr)`
The size of gt is 8x512x512.
The size of pr is 8x5x512x512.
Here 8 is the batch size and 5 is the number of classes.
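One way to make the shapes agree is to expand the N×H×W index mask into an N×C×H×W one-hot tensor, so that `gt * pr` multiplies matching class channels. Below is a sketch in plain PyTorch; `to_one_hot` is a hypothetical helper and the 64×64 size only keeps the example small. Note that the multiclass FocalLoss is fed the index mask, and the training loop appears to pass the same target to both the loss and the metrics, so converting inside the metric computation is probably safer than converting in the dataset.

```python
import torch
import torch.nn.functional as F


def to_one_hot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
    """(N, H, W) class indices -> (N, C, H, W) float one-hot (hypothetical helper)."""
    # F.one_hot appends the class dimension last, (N, H, W, C); move it to dim 1.
    return F.one_hot(mask.long(), num_classes).permute(0, 3, 1, 2).float()


num_classes = 5
gt = torch.randint(0, num_classes, (8, 64, 64))  # same layout as the real 8x512x512 mask
logits = torch.randn(8, num_classes, 64, 64)     # Unet output with activation=None

gt_1h = to_one_hot(gt, num_classes)              # (8, 5, 64, 64), matches pr
pr = torch.softmax(logits, dim=1)                # per-class probabilities, same shape

intersection = torch.sum(gt_1h * pr)             # shapes match, no broadcasting across dims
union = torch.sum(gt_1h) + torch.sum(pr) - intersection
iou = (intersection / union).item()
print(gt_1h.shape, iou)
```

This pools all classes into a single soft IoU value; per-class IoU and mIoU are easier to obtain from a confusion matrix.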

How can I solve this issue?

I'm not familiar with the smp repository and its metrics, but based on the error message I would suggest checking the inputs to this particular metric function and the shapes it expects (e.g. you might be missing an argmax or another operation that changes the shape).
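Following the argmax suggestion: reducing the N×C×H×W prediction to an N×H×W label map makes it directly comparable with the index mask. A minimal sketch computing pixel accuracy this way; `pixel_accuracy` is a hypothetical helper, not an smp function.

```python
import torch


def pixel_accuracy(logits: torch.Tensor, gt: torch.Tensor) -> float:
    """logits: (N, C, H, W) scores; gt: (N, H, W) class indices."""
    pred = logits.argmax(dim=1)  # (N, C, H, W) -> (N, H, W), same shape as gt
    assert pred.shape == gt.shape, (pred.shape, gt.shape)
    # Fraction of pixels whose predicted class equals the ground-truth class.
    return (pred == gt).float().mean().item()


logits = torch.randn(8, 5, 64, 64)
gt = torch.randint(0, 5, (8, 64, 64))
print(pixel_accuracy(logits, gt))
```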


Thank you for replying.
Actually, I want to calculate semantic segmentation metrics such as mIoU, F1 score, per-class IoU, and the confusion matrix, so I used an already-defined metric function, but it doesn't give the correct values.
I have a multiclass problem (5 classes) and I want to calculate the metrics above.
Any suggestions?
Thank you in advance.
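A common approach (not specific to smp) is to accumulate a C×C confusion matrix over the whole epoch from argmax predictions and derive all the requested metrics from it: per-class IoU = TP/(TP+FP+FN), mIoU as their mean, and per-class F1 = 2TP/(2TP+FP+FN). The sketch below assumes every mask value is a valid class index in [0, C) (no ignore label), uses hypothetical helper names, and the random tensors only stand in for real model outputs and masks.

```python
import torch


def confusion_matrix(pred: torch.Tensor, gt: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Rows = ground-truth class, columns = predicted class (hypothetical helper).

    pred, gt: int64 class-index tensors of the same shape, e.g. (N, H, W).
    Assumes every value lies in [0, num_classes); no ignore index is handled.
    """
    idx = gt.reshape(-1) * num_classes + pred.reshape(-1)
    return torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def metrics_from_confusion(cm: torch.Tensor) -> dict:
    cm = cm.double()
    tp = cm.diag()
    fp = cm.sum(dim=0) - tp            # predicted as class c, ground truth is another class
    fn = cm.sum(dim=1) - tp            # ground truth is class c, predicted as another class
    iou = tp / (tp + fp + fn)          # NaN for a class absent from both gt and pred
    f1 = 2 * tp / (2 * tp + fp + fn)   # per-class F1 (equal to the Dice score)
    valid = ~torch.isnan(iou)
    return {
        "iou_per_class": iou,
        "miou": iou[valid].mean().item(),
        "f1_per_class": f1,
        "mean_f1": f1[valid].mean().item(),
    }


# Accumulate over the epoch instead of averaging per-batch scores.
num_classes = 5
cm = torch.zeros(num_classes, num_classes, dtype=torch.long)
for _ in range(3):  # stand-in for iterating over a validation DataLoader
    logits = torch.randn(8, num_classes, 32, 32)        # model output (N, C, H, W)
    masks = torch.randint(0, num_classes, (8, 32, 32))  # target (N, H, W)
    cm += confusion_matrix(logits.argmax(dim=1), masks, num_classes)

print(cm)
print(metrics_from_confusion(cm))
```

Classes that never appear in either the ground truth or the predictions get NaN here and are left out of the means; other conventions (e.g. counting them as 0 or 1) also exist.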