Dice loss for multiclass text classification 1D

I use a BERT model for multiclass text classification (6 classes) with batch_size=256.
The predicted output for a single post is [0.6, 0.4, 0.8, 0.2, 0.3, 0.1], so the batch has shape (256, 6).
The true output for a single post is a class index, e.g. 2, so the batch has shape (256,).
I want to use dice loss, so I found this code:
import torch

smooth = 10

def dice_loss(y_pred, y_true):
    # soft intersection: element-wise overlap between prediction and target
    product = y_pred * y_true
    intersection = product.sum()
    coefficient = (2. * intersection + smooth) / (y_pred.sum() + y_true.sum() + smooth)
    loss = 1. - coefficient  # or "-coefficient"
    # computed with torch ops end to end so gradients flow; wrapping a NumPy
    # result in torch.tensor(requires_grad=True) would cut the computation graph
    return loss

but it seems to be for binary classification, not multiclass.
How can I modify it to work for multiclass classification?

This implementation is for image segmentation with 2 spatial dimensions,
but you should be able to use it by reshaping your model's output and target to sizes NxCxHxW and NxHxW, where H = W = 1.
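Alternatively, the same idea can be applied directly to (N, C) outputs without reshaping. Here is a sketch of one way to do it, computing a per-class soft Dice coefficient over the batch and averaging; the function name `multiclass_dice_loss` and the softmax/one-hot steps are my assumptions, not from the original post:

```python
import torch
import torch.nn.functional as F

smooth = 10.0  # same smoothing constant as in the question

def multiclass_dice_loss(logits, targets):
    """Soft Dice loss averaged over classes.

    logits:  (N, C) raw model outputs, e.g. (256, 6)
    targets: (N,) integer class indices in 0..C-1
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)                   # (N, C) class probabilities
    one_hot = F.one_hot(targets, num_classes).float()  # (N, C) one-hot targets
    # per-class soft intersection and sums, reduced over the batch dimension
    intersection = (probs * one_hot).sum(dim=0)            # (C,)
    denominator = probs.sum(dim=0) + one_hot.sum(dim=0)    # (C,)
    dice = (2. * intersection + smooth) / (denominator + smooth)  # (C,)
    return 1. - dice.mean()

# usage with the shapes from the question
logits = torch.randn(256, 6, requires_grad=True)
targets = torch.randint(0, 6, (256,))
loss = multiclass_dice_loss(logits, targets)
loss.backward()
```

Averaging the per-class Dice (rather than summing everything into one scalar first) keeps rare classes from being drowned out by frequent ones, which is usually the reason for choosing Dice loss over cross-entropy in the first place.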