Cross Entropy Loss outputting NaN

I am trying to train a model with the Kitti road segmentation dataset: http://www.cvlibs.net/datasets/kitti/eval_road.php.

I am getting NaN from the CrossEntropyLoss module, and it already returns NaN in the first mini-batch. I have already checked my input tensors for NaNs and Infs.
The tensor shapes I am passing to the loss function are (b_size, n_class, h, w) for the prediction and (b_size, h, w) for the target.
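Roughly, the sanity check I ran looks like this (a minimal sketch; the shapes and tensors below are just placeholders for my actual batch):

import torch

b_size, n_class, h, w = 4, 2, 375, 1242            # placeholder shapes
prediction = torch.randn(b_size, n_class, h, w, device='cuda')
target = torch.randint(0, n_class, (b_size, h, w), device='cuda')

# check both tensors for NaN and Inf values
print(torch.isnan(prediction).any(), torch.isinf(prediction).any())
print(torch.isnan(target.float()).any(), torch.isinf(target.float()).any())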

When I try to reshape the tensor in the following way:

loss = criterion(prediction.permute(0,2,3,1).contiguous().view(-1, n_class), target.view(-1))

the NaNs disappear.
Does anyone know what might be happening?
I am afraid something is still wrong, because I am still having trouble getting the network to converge.
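For reference, a minimal sketch of the two calls I am comparing (criterion and the shapes are placeholders; in float32 both calls return the same value):

import torch

criterion = torch.nn.CrossEntropyLoss()
b_size, n_class, h, w = 4, 2, 375, 1242
prediction = torch.randn(b_size, n_class, h, w)
target = torch.randint(0, n_class, (b_size, h, w))

# 4D call: (b_size, n_class, h, w) logits with (b_size, h, w) targets
loss_4d = criterion(prediction, target)

# flattened call: the reshaping described above
loss_flat = criterion(prediction.permute(0, 2, 3, 1).contiguous().view(-1, n_class),
                      target.view(-1))

print(loss_4d, loss_flat)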

Could you upload the prediction and target tensors so that we could try to reproduce this issue, please?
Also, could you post the output of python -m torch.utils.collect_env?

Thanks for replying @ptrblck.
Here are the prediction and target tensors:
https://drive.google.com/file/d/1XjKMt4kBv8Ipzm0ttWlrp906p8LXdTy1/view?usp=sharing
Here is the environment information:

Collecting environment information...
PyTorch version: 1.8.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 2060
Nvidia driver version: 460.73.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.8.0+cu111
[pip3] torchaudio==0.8.0
[pip3] torchvision==0.9.0+cu111
[conda] Could not collect

Thanks for the data. It seems you are using a manual float16 approach, which could easily overflow.
If I wrap the loss calculation in torch.cuda.amp.autocast(), the result is valid.
What’s your use case for the manual cast instead of autocast?
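A rough sketch of what I did (shown here with placeholder float16 tensors standing in for the uploaded prediction/target):

import torch

criterion = torch.nn.CrossEntropyLoss()

# stand-ins for the uploaded float16 prediction and long target
prediction = torch.randn(4, 2, 375, 1242, device='cuda', dtype=torch.float16)
target = torch.randint(0, 2, (4, 375, 1242), device='cuda')

loss_fp16 = criterion(prediction, target)      # pure float16 path, can overflow to NaN/Inf

with torch.cuda.amp.autocast():
    loss_amp = criterion(prediction, target)   # autocast runs the loss in float32

print(loss_fp16, loss_amp)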

I am training with half precision because the hardware I will deploy the model on works in half precision.
I was doing the cast manually because I thought it was the right way.
Sorry about that.
One thing is still unanswered, though: why did the NaNs disappear when I changed the shape of the tensors?

So, to summarize, I shouldn’t train my model with half precision this way:

model = Model()
model.to('cuda').half()

# keep the BatchNorm layers in float32 for numerical stability
for layer in model.modules():
    if isinstance(layer, torch.nn.BatchNorm2d):
        layer.float()

for i, (x, y) in enumerate(trainLoader):
    x = x.to('cuda').half()
    y = y.to('cuda').half()

Where should I put the with torch.cuda.amp.autocast() block?

Depending on the shapes, different methods will be called, as seen here.
In particular, you could compare the methods used in nll_loss_out_frame and nll_loss2d_forward_out_frame, which might use a different order of accumulation (the latter seems to over-/underflow in float16), which is why amp promotes this loss to float32.
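As a toy illustration of the range issue (not the actual kernel code, just why reducing in float16 is fragile):

import torch

# float16 can only represent values up to ~65504, so reducing many
# per-pixel loss terms easily exceeds that range and becomes inf.
per_pixel = torch.full((500_000,), 1.0, dtype=torch.float16)

print(per_pixel.sum())          # inf: the total exceeds the float16 range
print(per_pixel.float().sum())  # 500000.0 when accumulated in float32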

You can wrap the model execution in it as described in the examples:

from torch import optim
from torch.cuda.amp import autocast, GradScaler

# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same dtype autocast chose for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

You could try to manually cast the model to float16, but since the model can easily create invalid outputs, you would also have to verify that all operations work correctly, e.g. with a check like the one below.
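A minimal sketch of such a check compares the float16 outputs against a float32 copy of the model (the tiny Sequential below is only a stand-in for your real segmentation network):

import copy
import torch
import torch.nn as nn

# stand-in model; replace with your actual network
model_fp32 = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                           nn.Conv2d(8, 2, 1)).cuda().eval()
model_fp16 = copy.deepcopy(model_fp32).half()

x = torch.randn(2, 3, 375, 1242, device='cuda')

with torch.no_grad():
    out32 = model_fp32(x)
    out16 = model_fp16(x.half())

print(torch.isfinite(out16).all())            # did anything overflow to inf/NaN?
print((out32 - out16.float()).abs().max())    # worst-case deviation from float32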