Adding the weight parameter to PyTorch's cross-entropy loss causes a dtype RuntimeError

I’m currently using PyTorch to train a neural network. The dataset I’m using is a binary classification dataset that is heavily imbalanced toward the 0 class.

I decided to try and use the weight parameter of PyTorch’s cross-entropy loss. I calculated the weights via sklearn.utils.class_weight.compute_class_weight and got weight values of [0.58479532, 3.44827586].
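For reference, the weights were computed roughly as follows; the label array below is just a stand-in with a similar class ratio (about 85/15), not my actual data:

import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight

y_train = np.array([0] * 855 + [1] * 145)  # stand-in labels, ~85/15 split
class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(y_train), y=y_train)
# array([0.58479532, 3.44827586]) -- a float64 NumPy array
class_weights = torch.tensor(class_weights)  # the tensor inherits float64 from NumPy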

When I added this class_weights tensor as the weight parameter of my loss (i.e., criterion = nn.CrossEntropyLoss(weight=class_weights)), I suddenly get a RuntimeError: expected scalar type Float but found Double. The outputs and labels that I feed into the loss are float32 and int64, respectively. The loss was working fine before, but as soon as I add the weight parameter I get this error. Casting my data via outputs.float() doesn’t seem to help either.

Does anybody know why this error may be occurring and how I might fix this? Thanks.

Most likely you are passing class_weights as float64, since that is the default dtype in NumPy, and this raises the error, as seen here:

import torch
import torch.nn as nn

x = torch.randn(2, 2, requires_grad=True)  # float32 logits for 2 samples, 2 classes
y = torch.randint(0, 2, (2,))              # int64 class labels

criterion = nn.CrossEntropyLoss()
loss = criterion(x, y) # works

weight = torch.tensor([0.1, 0.2], dtype=torch.float32) # float32 is the default so specifying it is just for clarity
criterion = nn.CrossEntropyLoss(weight=weight)
loss = criterion(x, y) # works

weight = weight.double()  # cast to float64, mismatching the float32 logits
criterion = nn.CrossEntropyLoss(weight=weight)
loss = criterion(x, y)
# RuntimeError: expected scalar type Float but found Double

Cast the weight tensor to float32 and it should work.
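For example, either of the following (assuming class_weights is the float64 array or tensor produced by compute_class_weight) avoids the dtype mismatch:

weight = torch.tensor(class_weights, dtype=torch.float32)  # from a NumPy array
# or, if class_weights is already a float64 tensor:
weight = class_weights.float()

criterion = nn.CrossEntropyLoss(weight=weight)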