I’m trying to reduce the memory used by my U-Net, in order to increase the batch size and speed up convergence. It’s been very tricky so far, but one of the biggest savings was switching from float32 to float16.
However, I’m having trouble with the cross-entropy loss function - I’m getting NaNs from the very first iteration.
```python
import torch
import torch.nn as nn

def loss_func(result, target) -> torch.Tensor:
    class_weights = torch.tensor(
        [0.1, 1.0, 1.0, 1.0, 1.0],
        dtype=torch.float16, device=result.device)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    loss = criterion(result, target)
    return loss
```
result and target have a number of 0 values, but no NaNs or Infs or anything like that. The only thing I can think of is that the target mask, being of type long, might not play well with the float16 input and U-Net weights.
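For reference, these are the kind of sanity checks I mean - the shapes here are illustrative for a 5-class U-Net, with (N, C, H, W) logits and an (N, H, W) long mask:

```python
import torch

# Illustrative shapes for a 5-class segmentation problem
result = torch.randn(2, 5, 8, 8, dtype=torch.float32)        # logits
target = torch.randint(0, 5, (2, 8, 8), dtype=torch.long)    # class-index mask

# Both tensors are finite, even though many entries can be 0
assert torch.isfinite(result).all(), "logits contain NaN/Inf"
# Class indices must be valid for CrossEntropyLoss
assert target.min() >= 0 and target.max() < 5, "target index out of range"
```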
Changing everything to float32 seems to fix things, so I’m wondering if there’s something I’ve missed, or is multi-class U-Net cross-entropy loss just not compatible with half precision?
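In case it’s relevant, the alternative I’m considering is PyTorch’s mixed-precision autocast, which runs numerically sensitive ops like the log-softmax inside cross-entropy in float32 while the rest of the network stays in reduced precision. A minimal sketch (using CPU + bfloat16 here for portability; on GPU it would be `device_type="cuda"` with `dtype=torch.float16`):

```python
import torch
import torch.nn as nn

logits = torch.randn(2, 5, 8, 8)          # (N, C, H, W) U-Net-style output
target = torch.randint(0, 5, (2, 8, 8))   # (N, H, W) class-index mask

# Keep the class weights in float32 - autocast handles any casting needed
class_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Inside autocast, cross-entropy is kept in float32 precision
    loss = criterion(logits, target)
```

The upside over casting the whole model to half manually is that autocast picks per-op precision, so the loss never sees float16 arithmetic at all.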