# Here is a function that computes the per-class weight vector for torch.nn.CrossEntropyLoss from a batch of integer class targets.
def scaled_weights(targets, *, dtype=None):
    """Return per-class weights for ``torch.nn.CrossEntropyLoss(weight=...)``.

    Each class ``c`` gets ``N / (C * count_c)``, where ``N`` is the batch size,
    ``C`` is the number of classes and ``count_c`` is the column sum of
    ``targets`` for that class. Rarer classes therefore receive larger weights.

    A class that never occurs in the batch would cause a division by zero, so
    it is given ``1 / C`` instead. Such a class contributes nothing to the loss
    for this batch, so its weight only needs to be finite. An empty batch
    yields ``1 / C`` for every class instead of NaN.

    :param targets: tensor of shape ``(batch, classes)`` holding one-hot
        encoded labels, e.g. ``F.one_hot(labels, num_classes)``.
    :param dtype: optional dtype for the computation and the result. Pass the
        dtype of the model outputs (e.g. ``torch.float64``) so the weight tensor
        matches what ``CrossEntropyLoss`` expects. When omitted, torch type
        promotion applies (float32 by default for integer one-hot targets).
    :return: 1-D tensor of length ``classes``, on the same device as ``targets``.
    """
    n_samples, n_classes = targets.size(0), targets.size(1)
    # Total occurrences of each class across the batch.
    counts = targets.sum(0)
    if dtype is not None:
        counts = counts.to(dtype)
    # For a non-empty batch this equals n_samples, which is the plain formula.
    # For an empty batch it is 1, which avoids 0/0 = NaN.
    numerator = max(n_samples, 1)
    # Absent classes get count := numerator, i.e. weight numerator/numerator/C = 1/C.
    counts = counts.masked_fill(counts == 0, numerator)
    return numerator / counts / n_classes
import torch
import torch.nn.functional as F

# Demo: build random logits and integer targets, derive class weights from the
# targets, and compute a weighted cross-entropy loss.
num_classes = 10
batch_size = 100

# Dummy model outputs (logits) and integer class targets in [0, num_classes).
model_outputs = torch.randn((batch_size, num_classes), requires_grad=True)
targets = torch.randint(num_classes, (batch_size,))

# scaled_weights expects one-hot targets of shape (batch, classes).
scaled_weight = scaled_weights(F.one_hot(targets, num_classes))

# Weighted loss: CrossEntropyLoss itself takes the integer targets, not one-hot.
loss_funct = torch.nn.CrossEntropyLoss(weight=scaled_weight)
loss = loss_funct(model_outputs, targets)