RuntimeError: CUDA error: device-side assert triggered... when training in parallel with 4 classes

I am using this code: pytorch-cifar100/train.py (master branch of weiaicunzai/pytorch-cifar100 on GitHub)

I get this error when I increase the image size of the dataset:

/home/user/.conda/envs/bm/lib/python3.7/site-packages/torch/jit/_trace.py:152: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
  if a.grad is not None:
checkpoint_path:  checkpoint/resnet50/128x128_16_0.1_Hair_Color_Wednesday_07_April_2021_23h_21m_26s
Training Epoch: 1 [16/13670]    Loss: 1.5864    Acc: 1.0000     LR: 0.000000
epoch 1 training time consumed: 241.82s, Elapsed time: 0:04:01.824466
/opt/conda/conda-bld/pytorch_1607370128159/work/aten/src/THCUNN/ClassNLLCriterion.cu:108: cunn_ClassNLLCriterion_updateOutput_kernel: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
/opt/conda/conda-bld/pytorch_1607370128159/work/aten/src/THCUNN/ClassNLLCriterion.cu:108: cunn_ClassNLLCriterion_updateOutput_kernel: block: [0,0,0], thread: [13,0,0] Assertion `t >= 0 && t < n_classes` failed.
Traceback (most recent call last):
  File "train.py", line 289, in <module>
    acc, total_time = eval_training(epoch, total_time)
  File "/home/user/.conda/envs/bm/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "train.py", line 124, in eval_training
    test_loss += loss.item()
RuntimeError: CUDA error: device-side assert triggered

I train a classifier with 4 classes.

Could you check the target indices and make sure they are in the expected range [0, nb_classes-1]?
Based on the error message, the criterion raises the exception because some target values are outside of the valid range.
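One way to do that check is to scan every batch of labels and report anything that would fail the assertion `t >= 0 && t < n_classes` from the log, where n_classes is the size of the model's output (class) dimension. The sketch below is plain Python and uses a hypothetical helper name, `find_out_of_range_targets`. It assumes the loader yields `(images, labels)` pairs, as the training script's loaders do, so real batches can be passed in as `labels.tolist()`.

```python
from typing import Iterable, List, Sequence, Tuple


def find_out_of_range_targets(
    batches: Iterable[Sequence[int]], num_classes: int
) -> List[Tuple[int, int, int]]:
    """Return (batch_index, position_in_batch, label) for every label
    outside [0, num_classes - 1].

    This mirrors the CUDA assertion `t >= 0 && t < n_classes` raised by
    the NLL / cross-entropy kernel.
    """
    bad = []
    for batch_idx, targets in enumerate(batches):
        for pos, t in enumerate(targets):
            if not (0 <= t < num_classes):
                bad.append((batch_idx, pos, t))
    return bad


# Hypothetical label batches. With a real DataLoader (assumed to yield
# (images, labels) pairs) you could pass:
#   (labels.tolist() for _, labels in loader)
example_batches = [[0, 1, 2, 3], [3, 2, 4, 0], [-1, 1]]
print(find_out_of_range_targets(example_batches, num_classes=4))
# prints [(1, 2, 4), (2, 0, -1)]
```

In the traceback, the assertion fires inside `eval_training`, not during the training loop. That makes it worth running the check on the test loader as well as the training loader, with `num_classes` set to the output size of the network's last layer.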

My classes are 0, 1, 2, 3, so the targets should be correct…
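The intended encoding is 0 to 3, but the check that matters is whether every split the script actually loads follows it. A small per-split summary makes the distinction visible. The sketch below uses a hypothetical helper, `summarize_labels`, and made-up label lists: the test split is deliberately encoded 1 to 4 purely to show what a mismatch would look like, not because that is known to be the cause here.

```python
from collections import Counter
from typing import Dict, Iterable


def summarize_labels(splits: Dict[str, Iterable[int]], num_classes: int) -> Dict[str, dict]:
    """For each named split, count how often each label value occurs and
    list the values that fall outside [0, num_classes - 1]."""
    report = {}
    for name, labels in splits.items():
        counts = Counter(labels)
        out_of_range = sorted(v for v in counts if not (0 <= v < num_classes))
        report[name] = {"counts": dict(sorted(counts.items())), "out_of_range": out_of_range}
    return report


# Hypothetical labels: the "test" split is (wrongly) encoded 1..4 instead of 0..3.
splits = {
    "train": [0, 1, 2, 3, 3, 1],
    "test": [1, 2, 3, 4],
}
for name, info in summarize_labels(splits, num_classes=4).items():
    print(name, info)
```

If every split reports an empty `out_of_range` list, the labels themselves are fine. The next thing to compare is `num_classes` against the output size of the model's final layer, because that size is what the kernel uses as `n_classes`.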