I use a BERT model for multi-class text classification (6 classes), with batch_size=256.

Predicted output for a single post = [0.6, 0.4, 0.8, 0.2, 0.3, 0.1]; shape for a batch = (256, 6).

True output = 2 (a class index) for a single post; shape for a batch = (256,).

I want to use dice loss, so I found this code:

```
import numpy as np
import torch

smooth = 10

def dice_loss(y_pred, y_true):
    # soft Dice: overlap between predicted scores and targets
    product = np.multiply(y_pred, y_true)
    intersection = np.sum(product)
    coefficient = (2. * intersection + smooth) / (np.sum(y_pred) + np.sum(y_true) + smooth)
    loss = 1. - coefficient
    # or "-coefficient"
    return torch.tensor(loss, requires_grad=True)
```
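To sanity-check what this computes on my shapes, here is a pure-NumPy sketch of a single post (my assumption: the integer label first has to be one-hot encoded to line up with the 6 prediction scores; `smooth = 10` as above):

```python
import numpy as np

smooth = 10

def dice_loss(y_pred, y_true):
    # soft Dice: overlap between predicted scores and the target vector
    intersection = np.sum(np.multiply(y_pred, y_true))
    coefficient = (2. * intersection + smooth) / (np.sum(y_pred) + np.sum(y_true) + smooth)
    return 1. - coefficient

# single post: 6 class scores; true class index 2 as a one-hot row
y_pred = np.array([0.6, 0.4, 0.8, 0.2, 0.3, 0.1])
y_true = np.eye(6)[2]   # [0, 0, 1, 0, 0, 0]

loss = dice_loss(y_pred, y_true)
# intersection = 0.8, so loss = 1 - (2*0.8 + 10) / (2.4 + 1 + 10)
```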

but it seems to be written for binary classification, not multi-class.

So how do I modify it to work for multi-class classification?