As the title suggests, I have a modified MNIST dataset with multiple digits per image (0-3 digits), and each digit can be 0-9.

I have modelled it as a multi-label classification problem and am using cross entropy on 9*3 labels for each sample.

Is this the right way to go about it?

My training accuracy is going down, but my validation accuracy is monotonically increasing.

I think I am supposed to compute the cross entropy over the 10 classes for each digit separately and average the results, but I am not able to formulate it.

`nn.CrossEntropyLoss` is used for multi-class classification use cases, and I'm unsure how your model output and target look exactly. For multi-label classification use cases you would commonly use `nn.BCEWithLogitsLoss`, adding another "no-class" class so that your output and target would have the shape `[batch_size, nb_total_classes=11*3]`, or, with another dimension, `[batch_size, num_digits, num_classes=11]`, where `num_classes = 10 digit classes + 1 "background/no-class"` class. Let me know if this works.
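A minimal sketch of that multi-label setup (all names and shapes here are illustrative assumptions, not your actual model): the target becomes a multi-hot vector with one active class per digit position, and `nn.BCEWithLogitsLoss` is applied over all `11*3` outputs at once.

```python
import torch
import torch.nn as nn

# Assumed setup: 3 digit positions, 11 classes each (10 digits + 1 "no-class").
batch_size, num_digits, num_classes = 4, 3, 11

logits = torch.randn(batch_size, num_digits * num_classes)  # stand-in model output

# Build a multi-hot target: exactly one active class per digit position.
labels = torch.randint(0, num_classes, (batch_size, num_digits))
target = torch.zeros(batch_size, num_digits * num_classes)
for d in range(num_digits):
    target[torch.arange(batch_size), d * num_classes + labels[:, d]] = 1.0

criterion = nn.BCEWithLogitsLoss()
loss = criterion(logits, target)  # scalar loss over all 33 binary outputs
```

Note that this treats every one of the 33 outputs as an independent binary decision, which is why it needs the extra "no-class" label for empty positions.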

Since there were 3 labels to be predicted, and each could be 0-9 independently of the others, I used the average of the cross entropy per label as the loss function, and it worked.

I don’t think binary cross entropy is suitable for this case.