CrossEntropyLoss with float16

Hello there,
I’m trying to reduce the memory used by my U-Net so that I can increase the batch size and speed up convergence. It’s been very tricky so far, but one of the biggest savings came from using float16 instead of float32.

However, I’m having trouble with the cross entropy loss function: I’m getting NaNs from the very first iteration.

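```python
import torch
import torch.nn as nn

def loss_func(result, target) -> torch.Tensor:
    # Class weights in float16 to match the half-precision network output.
    class_weights = torch.tensor(
        [0.1, 1.0, 1.0, 1.0, 1.0], dtype=torch.float16, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    loss = criterion(result, target)  # result: float16 logits, target: long mask
    return loss
```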

`result` and `target` contain a number of zeros, but no NaNs, Infs, or other invalid values. The only explanation I can think of is that the target mask, which is of type long, might not work with the float16 input and U-Net.
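Finite inputs do not rule out overflow. float16 tops out at 65504 (`torch.finfo(torch.float16).max`), so any intermediate value that exceeds it becomes inf, and a later inf/inf turns it into NaN. The sketch below illustrates this under assumptions that are not in the original post: a hypothetical batch of 8 images at 256x256, and some accumulated quantity (here a sum of per-pixel weights) that ends up stored in float16. It only illustrates the float16 range limit; it is not a claim about how PyTorch's cross entropy kernel accumulates internally.

```python
import torch

# Largest finite value representable in float16.
FP16_MAX = torch.finfo(torch.float16).max  # 65504.0

def to_fp16_and_back(x: float) -> float:
    """Round a Python float through float16 and return it as a Python float."""
    return torch.tensor(x, dtype=torch.float32).half().float().item()

# Hypothetical segmentation batch: 8 images of 256x256 pixels, weight 1.0 each.
n_pixels = 8 * 256 * 256            # 524288 pixels
total_weight = float(n_pixels)      # a sum of per-pixel weights, fine in float32

# The same sum stored in float16 exceeds FP16_MAX and overflows to inf.
fp16_total = to_fp16_and_back(total_weight)
print(FP16_MAX, fp16_total)

# Normalizing an overflowed sum by an overflowed weight gives NaN.
print(fp16_total / fp16_total)
```

Values well inside the range, such as 1000.0, survive the round trip unchanged, which is why the same computation can look fine on small inputs and fail on a full batch.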

Changing to float32 seems to fix things, so is there something I’ve missed, or is multi-class U-Net cross entropy loss just not compatible with half precision?
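A minimal sketch of one workaround, assuming the network itself stays in float16 and only the loss is moved back to float32: upcast the logits with `.float()` and keep the class weights in float32. The long target can stay as it is, since `CrossEntropyLoss` expects class-index targets to be integer (long) tensors regardless of the input's floating dtype. The name `loss_func_fp32` and the `CLASS_WEIGHTS` constant are hypothetical.

```python
import torch
import torch.nn as nn

# Hypothetical constant: same weights as the original loss_func, kept in float32.
CLASS_WEIGHTS = [0.1, 1.0, 1.0, 1.0, 1.0]

def loss_func_fp32(result: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Weighted cross entropy computed in float32 for a float16 network output.

    result: (N, C, H, W) logits in any floating dtype (e.g. float16)
    target: (N, H, W) class indices, dtype torch.long
    """
    class_weights = torch.tensor(
        CLASS_WEIGHTS, dtype=torch.float32, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    # Upcast the logits so log_softmax and the weighted mean run in float32;
    # gradients still flow back to the float16 activations through .float().
    return criterion(result.float(), target)
```

If the model parameters themselves are float16, gradients can still underflow during backward; that is the problem the gradient scaling in `torch.cuda.amp` is meant to address.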

Cheers
Ben


nn.CrossEntropyLoss is considered unsafe in float16, which is why autocast runs it in float32 if you are using the mixed-precision training utilities via torch.cuda.amp.autocast, as described here.
Since you are seeing invalid outputs, my best guess is that you are indeed hitting an overflow or some other numerical issue.
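A minimal sketch of the mixed-precision route, assuming a standard training loop: keep the model and the class weights in float32 (no `model.half()`) and let autocast pick per-op precision, so on CUDA the convolutions run in float16 while `cross_entropy` is autocast to float32. `torch.autocast(device_type="cuda")` is the device-generic form of `torch.cuda.amp.autocast`, and `GradScaler` applies loss scaling so small float16 gradients do not underflow. The one-layer `Conv2d` stands in for the U-Net; `train_step`, the shapes, and the learning rate are hypothetical. Without a GPU the sketch falls back to plain float32 on the CPU with autocast and scaling disabled.

```python
import torch
import torch.nn as nn

def train_step(model, images, masks, criterion, optimizer, scaler, use_amp):
    """One mixed-precision training step (autocast + gradient scaling)."""
    optimizer.zero_grad(set_to_none=True)
    # Under CUDA autocast, conv layers run in float16 while cross_entropy
    # is autocast to float32; with use_amp=False everything stays float32.
    with torch.autocast(device_type=images.device.type, enabled=use_amp):
        logits = model(images)
        loss = criterion(logits, masks)
    # Scale the loss before backward so small float16 gradients don't underflow;
    # scaler.step unscales them and skips the update if any are inf/NaN.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()

if __name__ == "__main__":
    use_amp = torch.cuda.is_available()
    device = torch.device("cuda" if use_amp else "cpu")

    # Hypothetical stand-in for the U-Net: 3 input channels -> 5 class logits.
    # Parameters stay float32; autocast handles the float16 compute.
    model = nn.Conv2d(3, 5, kernel_size=3, padding=1).to(device)
    class_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0], device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    images = torch.randn(2, 3, 64, 64, device=device)
    masks = torch.randint(0, 5, (2, 64, 64), device=device)  # long class indices
    print(train_step(model, images, masks, criterion, optimizer, scaler, use_amp))
```

Keeping the parameters in float32 is the main difference from calling `.half()` on the whole network: the memory savings come from float16 activations, while the weights, the optimizer update, and the loss stay in float32.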

Thanks for pointing out that page - that explains it.