I use a BERT model for multi-class text classification (6 classes), with batch_size=256.

Predicted output for a single post = [0.6, 0.4, 0.8, 0.2, 0.3, 0.1]; shape for a batch = (256, 6).

True output = 2 (a class index) for a single post; shape for a batch = (256,).

I want to use dice loss, so I found this code:

```
import numpy as np
import torch

smooth = 10

def dice_loss(y_pred, y_true):
    # soft Dice: overlap between predicted scores and targets
    product = np.multiply(y_pred, y_true)
    intersection = np.sum(product)
    coefficient = (2. * intersection + smooth) / (np.sum(y_pred) + np.sum(y_true) + smooth)
    loss = 1. - coefficient
    # or "-coefficient"
    return torch.tensor(loss, requires_grad=True)
```
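To sanity-check what this computes on my shapes, here is a pure-NumPy sketch of a single post (my assumption: the integer label first has to be one-hot encoded to line up with the 6 prediction scores; `smooth = 10` as above):

```python
import numpy as np

smooth = 10

def dice_loss(y_pred, y_true):
    # soft Dice: overlap between predicted scores and the target vector
    intersection = np.sum(np.multiply(y_pred, y_true))
    coefficient = (2. * intersection + smooth) / (np.sum(y_pred) + np.sum(y_true) + smooth)
    return 1. - coefficient

# single post: 6 class scores; true class index 2 as a one-hot row
y_pred = np.array([0.6, 0.4, 0.8, 0.2, 0.3, 0.1])
y_true = np.eye(6)[2]   # [0, 0, 1, 0, 0, 0]

loss = dice_loss(y_pred, y_true)
# intersection = 0.8, so loss = 1 - (2*0.8 + 10) / (2.4 + 1 + 10)
```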

but it seems to be written for binary classification, not multi-class.

So how do I modify it to work for multi-class classification?